Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py +8 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py +160 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +135 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +253 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +231 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +369 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +210 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +291 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py +9 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py +483 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py +348 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py +89 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py +20 -0
- .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py +161 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py +459 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py +133 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py +500 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py +447 -0
- .venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py +631 -0
- .venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py +94 -0
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
| 4 |
+
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
|
| 8 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (474 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc
ADDED
|
Binary file (6.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from functools import cached_property
|
| 5 |
+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
|
| 6 |
+
|
| 7 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 8 |
+
DeltaMessage)
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 11 |
+
from vllm.utils import import_from_path, is_list_of
|
| 12 |
+
|
| 13 |
+
logger = init_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ReasoningParser:
|
| 17 |
+
"""
|
| 18 |
+
Abstract reasoning parser class that should not be used directly.
|
| 19 |
+
Provided and methods should be used in derived classes.
|
| 20 |
+
|
| 21 |
+
It is used to extract reasoning content from the model output.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, tokenizer: AnyTokenizer):
|
| 25 |
+
self.model_tokenizer = tokenizer
|
| 26 |
+
|
| 27 |
+
@cached_property
|
| 28 |
+
def vocab(self) -> Dict[str, int]:
|
| 29 |
+
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
| 30 |
+
# whereas all tokenizers have .get_vocab()
|
| 31 |
+
return self.model_tokenizer.get_vocab()
|
| 32 |
+
|
| 33 |
+
def extract_reasoning_content(
|
| 34 |
+
self, model_output: str, request: ChatCompletionRequest
|
| 35 |
+
) -> Tuple[Optional[str], Optional[str]]:
|
| 36 |
+
"""
|
| 37 |
+
Extract reasoning content from a complete model-generated string.
|
| 38 |
+
|
| 39 |
+
Used for non-streaming responses where we have the entire model response
|
| 40 |
+
available before sending to the client.
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
model_output: str
|
| 44 |
+
The model-generated string to extract reasoning content from.
|
| 45 |
+
|
| 46 |
+
request: ChatCompletionRequest
|
| 47 |
+
The request object that was used to generate the model_output.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Tuple[Optional[str], Optional[str]]
|
| 51 |
+
A tuple containing the reasoning content and the content.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
raise NotImplementedError(
|
| 55 |
+
"AbstractReasoningParser.extract_reasoning_calls "
|
| 56 |
+
"has not been implemented!")
|
| 57 |
+
|
| 58 |
+
def extract_reasoning_content_streaming(
|
| 59 |
+
self,
|
| 60 |
+
previous_text: str,
|
| 61 |
+
current_text: str,
|
| 62 |
+
delta_text: str,
|
| 63 |
+
previous_token_ids: Sequence[int],
|
| 64 |
+
current_token_ids: Sequence[int],
|
| 65 |
+
delta_token_ids: Sequence[int],
|
| 66 |
+
) -> Union[DeltaMessage, None]:
|
| 67 |
+
"""
|
| 68 |
+
Instance method that should be implemented for extracting reasoning
|
| 69 |
+
from an incomplete response; for use when handling reasoning calls and
|
| 70 |
+
streaming. Has to be an instance method because it requires state -
|
| 71 |
+
the current tokens/diffs, but also the information about what has
|
| 72 |
+
previously been parsed and extracted (see constructor)
|
| 73 |
+
"""
|
| 74 |
+
raise NotImplementedError(
|
| 75 |
+
"AbstractReasoningParser.extract_reasoning_content_streaming "
|
| 76 |
+
"has not been implemented!")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ReasoningParserManager:
|
| 80 |
+
reasoning_parsers: Dict[str, Type] = {}
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def get_reasoning_parser(cls, name) -> Type:
|
| 84 |
+
"""
|
| 85 |
+
Get reasoning parser by name which is registered by `register_module`.
|
| 86 |
+
|
| 87 |
+
Raise a KeyError exception if the name is not registered.
|
| 88 |
+
"""
|
| 89 |
+
if name in cls.reasoning_parsers:
|
| 90 |
+
return cls.reasoning_parsers[name]
|
| 91 |
+
|
| 92 |
+
raise KeyError(f"reasoning helper: '{name}' not found in "
|
| 93 |
+
"reasoning_parsers")
|
| 94 |
+
|
| 95 |
+
@classmethod
|
| 96 |
+
def _register_module(cls,
|
| 97 |
+
module: Type,
|
| 98 |
+
module_name: Optional[Union[str, List[str]]] = None,
|
| 99 |
+
force: bool = True) -> None:
|
| 100 |
+
if not issubclass(module, ReasoningParser):
|
| 101 |
+
raise TypeError("module must be subclass of ReasoningParser, "
|
| 102 |
+
f"but got {type(module)}")
|
| 103 |
+
if module_name is None:
|
| 104 |
+
module_name = module.__name__
|
| 105 |
+
if isinstance(module_name, str):
|
| 106 |
+
module_name = [module_name]
|
| 107 |
+
for name in module_name:
|
| 108 |
+
if not force and name in cls.reasoning_parsers:
|
| 109 |
+
existed_module = cls.reasoning_parsers[name]
|
| 110 |
+
raise KeyError(f"{name} is already registered "
|
| 111 |
+
f"at {existed_module.__module__}")
|
| 112 |
+
cls.reasoning_parsers[name] = module
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def register_module(
|
| 116 |
+
cls,
|
| 117 |
+
name: Optional[Union[str, List[str]]] = None,
|
| 118 |
+
force: bool = True,
|
| 119 |
+
module: Union[Type, None] = None) -> Union[type, Callable]:
|
| 120 |
+
"""
|
| 121 |
+
Register module with the given name or name list. it can be used as a
|
| 122 |
+
decoder(with module as None) or normal function(with module as not
|
| 123 |
+
None).
|
| 124 |
+
"""
|
| 125 |
+
if not isinstance(force, bool):
|
| 126 |
+
raise TypeError(f"force must be a boolean, but got {type(force)}")
|
| 127 |
+
|
| 128 |
+
# raise the error ahead of time
|
| 129 |
+
if not (name is None or isinstance(name, str)
|
| 130 |
+
or is_list_of(name, str)):
|
| 131 |
+
raise TypeError(
|
| 132 |
+
"name must be None, an instance of str, or a sequence of str, "
|
| 133 |
+
f"but got {type(name)}")
|
| 134 |
+
|
| 135 |
+
# use it as a normal method: x.register_module(module=SomeClass)
|
| 136 |
+
if module is not None:
|
| 137 |
+
cls._register_module(module=module, module_name=name, force=force)
|
| 138 |
+
return module
|
| 139 |
+
|
| 140 |
+
# use it as a decorator: @x.register_module()
|
| 141 |
+
def _register(module):
|
| 142 |
+
cls._register_module(module=module, module_name=name, force=force)
|
| 143 |
+
return module
|
| 144 |
+
|
| 145 |
+
return _register
|
| 146 |
+
|
| 147 |
+
@classmethod
|
| 148 |
+
def import_reasoning_parser(cls, plugin_path: str) -> None:
|
| 149 |
+
"""
|
| 150 |
+
Import a user-defined reasoning parser by the path
|
| 151 |
+
of the reasoning parser define file.
|
| 152 |
+
"""
|
| 153 |
+
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
import_from_path(module_name, plugin_path)
|
| 157 |
+
except Exception:
|
| 158 |
+
logger.exception("Failed to load module '%s' from %s.",
|
| 159 |
+
module_name, plugin_path)
|
| 160 |
+
return
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
from transformers import PreTrainedTokenizerBase
|
| 7 |
+
|
| 8 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 9 |
+
DeltaMessage)
|
| 10 |
+
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
|
| 11 |
+
ReasoningParser, ReasoningParserManager)
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
|
| 14 |
+
logger = init_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@ReasoningParserManager.register_module("deepseek_r1")
|
| 18 |
+
class DeepSeekR1ReasoningParser(ReasoningParser):
|
| 19 |
+
"""
|
| 20 |
+
Reasoning parser for DeepSeek R1 model.
|
| 21 |
+
|
| 22 |
+
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
|
| 23 |
+
text. This parser extracts the reasoning content from the model output.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
| 27 |
+
super().__init__(tokenizer)
|
| 28 |
+
self.think_start_token = "<think>"
|
| 29 |
+
self.think_end_token = "</think>"
|
| 30 |
+
|
| 31 |
+
self.reasoning_regex = re.compile(
|
| 32 |
+
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
|
| 33 |
+
|
| 34 |
+
if not self.model_tokenizer:
|
| 35 |
+
raise ValueError(
|
| 36 |
+
"The model tokenizer must be passed to the ReasoningParser "
|
| 37 |
+
"constructor during construction.")
|
| 38 |
+
|
| 39 |
+
self.think_start_token_id = self.vocab.get(self.think_start_token)
|
| 40 |
+
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
| 41 |
+
if (self.think_start_token_id is None
|
| 42 |
+
or self.think_end_token_id is None):
|
| 43 |
+
raise RuntimeError(
|
| 44 |
+
"DeepSeek R1 reasoning parser could not locate think start/end "
|
| 45 |
+
"tokens in the tokenizer!")
|
| 46 |
+
|
| 47 |
+
def extract_reasoning_content_streaming(
|
| 48 |
+
self,
|
| 49 |
+
previous_text: str,
|
| 50 |
+
current_text: str,
|
| 51 |
+
delta_text: str,
|
| 52 |
+
previous_token_ids: Sequence[int],
|
| 53 |
+
current_token_ids: Sequence[int],
|
| 54 |
+
delta_token_ids: Sequence[int],
|
| 55 |
+
) -> Union[DeltaMessage, None]:
|
| 56 |
+
"""
|
| 57 |
+
Extract reasoning content from a delta message.
|
| 58 |
+
Handles streaming output where previous + delta = current.
|
| 59 |
+
Uses token IDs for faster processing.
|
| 60 |
+
For text <think>abc</think>xyz:
|
| 61 |
+
- 'abc' goes to reasoning_content
|
| 62 |
+
- 'xyz' goes to content
|
| 63 |
+
"""
|
| 64 |
+
# Skip single special tokens
|
| 65 |
+
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
| 66 |
+
self.think_start_token_id, self.think_end_token_id
|
| 67 |
+
]):
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
if self.think_start_token_id in previous_token_ids:
|
| 71 |
+
if self.think_end_token_id in delta_token_ids:
|
| 72 |
+
# <think> in previous, </think> in delta,
|
| 73 |
+
# extract reasoning content
|
| 74 |
+
end_index = delta_text.find(self.think_end_token)
|
| 75 |
+
reasoning_content = delta_text[:end_index]
|
| 76 |
+
content = delta_text[end_index + len(self.think_end_token):]
|
| 77 |
+
return DeltaMessage(reasoning_content=reasoning_content,
|
| 78 |
+
content=content if content else None)
|
| 79 |
+
elif self.think_end_token_id in previous_token_ids:
|
| 80 |
+
# <think> in previous, </think> in previous,
|
| 81 |
+
# reasoning content continues
|
| 82 |
+
return DeltaMessage(content=delta_text)
|
| 83 |
+
else:
|
| 84 |
+
# <think> in previous, no </think> in previous or delta,
|
| 85 |
+
# reasoning content continues
|
| 86 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 87 |
+
elif self.think_start_token_id in delta_token_ids:
|
| 88 |
+
logger.info(delta_text)
|
| 89 |
+
if self.think_end_token_id in delta_token_ids:
|
| 90 |
+
# <think> in delta, </think> in delta, extract reasoning content
|
| 91 |
+
start_index = delta_text.find(self.think_start_token)
|
| 92 |
+
end_index = delta_text.find(self.think_end_token)
|
| 93 |
+
reasoning_content = delta_text[start_index +
|
| 94 |
+
len(self.think_start_token
|
| 95 |
+
):end_index]
|
| 96 |
+
content = delta_text[end_index + len(self.think_end_token):]
|
| 97 |
+
return DeltaMessage(reasoning_content=reasoning_content,
|
| 98 |
+
content=content if content else None)
|
| 99 |
+
else:
|
| 100 |
+
# <think> in delta, no </think> in delta,
|
| 101 |
+
# reasoning content continues
|
| 102 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 103 |
+
else:
|
| 104 |
+
# No <think> in previous or delta, reasoning content continues.
|
| 105 |
+
return DeltaMessage(content=delta_text)
|
| 106 |
+
|
| 107 |
+
def extract_reasoning_content(
|
| 108 |
+
self, model_output: str, request: ChatCompletionRequest
|
| 109 |
+
) -> Tuple[Optional[str], Optional[str]]:
|
| 110 |
+
|
| 111 |
+
# Check if the model output contains the <think> tokens.
|
| 112 |
+
if (self.think_start_token not in model_output
|
| 113 |
+
or self.think_end_token not in model_output):
|
| 114 |
+
return None, model_output
|
| 115 |
+
else:
|
| 116 |
+
# Use a regex to find the reasoning content
|
| 117 |
+
reasoning_content = self.reasoning_regex.findall(model_output)[0]
|
| 118 |
+
|
| 119 |
+
# Remove the reasoning content from the model output
|
| 120 |
+
# Although deepseek's <think> token is always at the
|
| 121 |
+
# beginning of the line, we cannot guarantee that the
|
| 122 |
+
# other models will follow this convention.
|
| 123 |
+
# Therefore, we need to add :start_index.
|
| 124 |
+
start_index = model_output.find(self.think_start_token)
|
| 125 |
+
if start_index != -1:
|
| 126 |
+
end_index = start_index + len(
|
| 127 |
+
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
|
| 128 |
+
)
|
| 129 |
+
model_output = model_output[:start_index] + \
|
| 130 |
+
model_output[end_index:]
|
| 131 |
+
|
| 132 |
+
if len(model_output) == 0:
|
| 133 |
+
return reasoning_content, None
|
| 134 |
+
|
| 135 |
+
return reasoning_content, model_output
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from json import JSONDecoder
|
| 6 |
+
from typing import Dict, Sequence, Union
|
| 7 |
+
|
| 8 |
+
import partial_json_parser
|
| 9 |
+
from partial_json_parser.core.options import Allow
|
| 10 |
+
|
| 11 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 12 |
+
DeltaFunctionCall, DeltaMessage,
|
| 13 |
+
DeltaToolCall,
|
| 14 |
+
ExtractedToolCallInformation,
|
| 15 |
+
FunctionCall, ToolCall)
|
| 16 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 17 |
+
ToolParser, ToolParserManager)
|
| 18 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
| 19 |
+
find_common_prefix,
|
| 20 |
+
is_complete_json,
|
| 21 |
+
partial_json_loads)
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 24 |
+
from vllm.utils import random_uuid
|
| 25 |
+
|
| 26 |
+
logger = init_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@ToolParserManager.register_module("granite-20b-fc")
|
| 30 |
+
class Granite20bFCToolParser(ToolParser):
|
| 31 |
+
"""
|
| 32 |
+
Tool call parser for the granite-20b-functioncalling model intended
|
| 33 |
+
for use with the examples/tool_chat_template_granite20b_fc.jinja
|
| 34 |
+
template.
|
| 35 |
+
|
| 36 |
+
Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
|
| 37 |
+
are all set
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, tokenizer: AnyTokenizer):
|
| 41 |
+
super().__init__(tokenizer)
|
| 42 |
+
|
| 43 |
+
self.bot_token = "<function_call>"
|
| 44 |
+
self.tool_start_token = self.bot_token
|
| 45 |
+
self.tool_call_regex = re.compile(r"<function_call>\s*")
|
| 46 |
+
|
| 47 |
+
def extract_tool_calls(
|
| 48 |
+
self, model_output: str,
|
| 49 |
+
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
| 50 |
+
if self.tool_start_token not in model_output:
|
| 51 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 52 |
+
tool_calls=[],
|
| 53 |
+
content=model_output)
|
| 54 |
+
|
| 55 |
+
dec = JSONDecoder()
|
| 56 |
+
try:
|
| 57 |
+
matches = list(self.tool_call_regex.finditer(model_output))
|
| 58 |
+
logger.debug("Found %d tool call matches", len(matches))
|
| 59 |
+
|
| 60 |
+
raw_function_calls = []
|
| 61 |
+
|
| 62 |
+
for i, match in enumerate(matches):
|
| 63 |
+
# position after the <function_call> tag
|
| 64 |
+
start_of_json = match.end()
|
| 65 |
+
# end_index == the start of the next function call
|
| 66 |
+
# (if exists)
|
| 67 |
+
next_function_call_start = (matches[i + 1].start() if i +
|
| 68 |
+
1 < len(matches) else None)
|
| 69 |
+
|
| 70 |
+
raw_function_calls.append(
|
| 71 |
+
dec.raw_decode(
|
| 72 |
+
model_output[start_of_json:next_function_call_start])
|
| 73 |
+
[0])
|
| 74 |
+
|
| 75 |
+
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
| 76 |
+
tool_calls = [
|
| 77 |
+
ToolCall(
|
| 78 |
+
type="function",
|
| 79 |
+
function=FunctionCall(
|
| 80 |
+
name=function_call["name"],
|
| 81 |
+
# function call args are JSON but as a string
|
| 82 |
+
arguments=json.dumps(function_call["arguments"]),
|
| 83 |
+
),
|
| 84 |
+
) for function_call in raw_function_calls
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
content = model_output[:model_output.find(self.bot_token)]
|
| 88 |
+
return ExtractedToolCallInformation(
|
| 89 |
+
tools_called=True,
|
| 90 |
+
tool_calls=tool_calls,
|
| 91 |
+
content=content if content else None,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error("Error in extracting tool call from response %s", e)
|
| 96 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 97 |
+
tool_calls=[],
|
| 98 |
+
content=model_output)
|
| 99 |
+
|
| 100 |
+
def extract_tool_calls_streaming(
|
| 101 |
+
self,
|
| 102 |
+
previous_text: str,
|
| 103 |
+
current_text: str,
|
| 104 |
+
delta_text: str,
|
| 105 |
+
previous_token_ids: Sequence[int],
|
| 106 |
+
current_token_ids: Sequence[int],
|
| 107 |
+
delta_token_ids: Sequence[int],
|
| 108 |
+
request: ChatCompletionRequest,
|
| 109 |
+
) -> Union[DeltaMessage, None]:
|
| 110 |
+
|
| 111 |
+
if len(current_text) < len(
|
| 112 |
+
self.bot_token) and self.bot_token.startswith(current_text):
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
if not current_text.startswith(self.bot_token):
|
| 116 |
+
return DeltaMessage(content=delta_text)
|
| 117 |
+
|
| 118 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 119 |
+
# sent yet, don't allow sending
|
| 120 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 121 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 122 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 123 |
+
else Allow.ALL & ~Allow.STR
|
| 124 |
+
try:
|
| 125 |
+
tool_call_arr = []
|
| 126 |
+
is_complete = []
|
| 127 |
+
try:
|
| 128 |
+
start_idx = len(self.bot_token)
|
| 129 |
+
start_idx = consume_space(start_idx, current_text)
|
| 130 |
+
|
| 131 |
+
while start_idx < len(current_text):
|
| 132 |
+
(obj,
|
| 133 |
+
end_idx) = partial_json_loads(current_text[start_idx:],
|
| 134 |
+
flags)
|
| 135 |
+
is_complete.append(
|
| 136 |
+
is_complete_json(current_text[start_idx:start_idx +
|
| 137 |
+
end_idx]))
|
| 138 |
+
start_idx += end_idx
|
| 139 |
+
start_idx = consume_space(start_idx, current_text)
|
| 140 |
+
start_idx += len(self.bot_token)
|
| 141 |
+
start_idx = consume_space(start_idx, current_text)
|
| 142 |
+
tool_call_arr.append(obj)
|
| 143 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 144 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
# select as the current tool call the one we're on the state at
|
| 148 |
+
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
| 149 |
+
if len(tool_call_arr) > 0 else {}
|
| 150 |
+
|
| 151 |
+
# case -- if no tokens have been streamed for the tool, e.g.
|
| 152 |
+
# only the array brackets, stream nothing
|
| 153 |
+
if len(tool_call_arr) == 0:
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
+
# case: we are starting a new tool in the array
|
| 157 |
+
# -> array has > 0 length AND length has moved past cursor
|
| 158 |
+
elif (len(tool_call_arr) > 0
|
| 159 |
+
and len(tool_call_arr) > self.current_tool_id + 1):
|
| 160 |
+
|
| 161 |
+
# if we're moving on to a new call, first make sure we
|
| 162 |
+
# haven't missed anything in the previous one that was
|
| 163 |
+
# auto-generated due to JSON completions, but wasn't
|
| 164 |
+
# streamed to the client yet.
|
| 165 |
+
if self.current_tool_id >= 0:
|
| 166 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 167 |
+
if cur_arguments:
|
| 168 |
+
cur_args_json = json.dumps(cur_arguments)
|
| 169 |
+
sent = len(
|
| 170 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 171 |
+
argument_diff = cur_args_json[sent:]
|
| 172 |
+
|
| 173 |
+
logger.debug("got arguments diff: %s", argument_diff)
|
| 174 |
+
delta = DeltaMessage(tool_calls=[
|
| 175 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 176 |
+
function=DeltaFunctionCall(
|
| 177 |
+
arguments=argument_diff).
|
| 178 |
+
model_dump(exclude_none=True))
|
| 179 |
+
])
|
| 180 |
+
self.streamed_args_for_tool[
|
| 181 |
+
self.current_tool_id] += argument_diff
|
| 182 |
+
else:
|
| 183 |
+
delta = None
|
| 184 |
+
else:
|
| 185 |
+
delta = None
|
| 186 |
+
# re-set stuff pertaining to progress in the current tool
|
| 187 |
+
self.current_tool_id = len(tool_call_arr) - 1
|
| 188 |
+
self.current_tool_name_sent = False
|
| 189 |
+
self.streamed_args_for_tool.append("")
|
| 190 |
+
logger.debug("starting on new tool %d", self.current_tool_id)
|
| 191 |
+
return delta
|
| 192 |
+
|
| 193 |
+
# if the current tool name hasn't been sent, send if available
|
| 194 |
+
# - otherwise send nothing
|
| 195 |
+
elif not self.current_tool_name_sent:
|
| 196 |
+
function_name = current_tool_call.get("name")
|
| 197 |
+
if function_name:
|
| 198 |
+
|
| 199 |
+
delta = DeltaMessage(tool_calls=[
|
| 200 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 201 |
+
type="function",
|
| 202 |
+
id=f"chatcmpl-tool-{random_uuid()}",
|
| 203 |
+
function=DeltaFunctionCall(
|
| 204 |
+
name=function_name).model_dump(
|
| 205 |
+
exclude_none=True))
|
| 206 |
+
])
|
| 207 |
+
self.current_tool_name_sent = True
|
| 208 |
+
else:
|
| 209 |
+
delta = None
|
| 210 |
+
|
| 211 |
+
# now we know we're on the same tool call and we're streaming
|
| 212 |
+
# arguments
|
| 213 |
+
else:
|
| 214 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 215 |
+
delta = None
|
| 216 |
+
|
| 217 |
+
if cur_arguments:
|
| 218 |
+
sent = len(
|
| 219 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 220 |
+
cur_args_json = json.dumps(cur_arguments)
|
| 221 |
+
prev_arguments = self.prev_tool_call_arr[
|
| 222 |
+
self.current_tool_id].get("arguments")
|
| 223 |
+
|
| 224 |
+
argument_diff = None
|
| 225 |
+
if is_complete[self.current_tool_id]:
|
| 226 |
+
argument_diff = cur_args_json[sent:]
|
| 227 |
+
elif prev_arguments:
|
| 228 |
+
prev_args_json = json.dumps(prev_arguments)
|
| 229 |
+
if cur_args_json != prev_args_json:
|
| 230 |
+
|
| 231 |
+
prefix = find_common_prefix(
|
| 232 |
+
prev_args_json, cur_args_json)
|
| 233 |
+
argument_diff = prefix[sent:]
|
| 234 |
+
|
| 235 |
+
if argument_diff is not None:
|
| 236 |
+
delta = DeltaMessage(tool_calls=[
|
| 237 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 238 |
+
function=DeltaFunctionCall(
|
| 239 |
+
arguments=argument_diff).
|
| 240 |
+
model_dump(exclude_none=True))
|
| 241 |
+
])
|
| 242 |
+
self.streamed_args_for_tool[
|
| 243 |
+
self.current_tool_id] += argument_diff
|
| 244 |
+
|
| 245 |
+
self.prev_tool_call_arr = tool_call_arr
|
| 246 |
+
return delta
|
| 247 |
+
|
| 248 |
+
except Exception as e:
|
| 249 |
+
logger.error("Error trying to handle streaming tool call: %s", e)
|
| 250 |
+
logger.debug(
|
| 251 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 252 |
+
"error")
|
| 253 |
+
return None
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import Dict, Sequence, Union
|
| 5 |
+
|
| 6 |
+
import partial_json_parser
|
| 7 |
+
from partial_json_parser.core.options import Allow
|
| 8 |
+
|
| 9 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 10 |
+
DeltaFunctionCall, DeltaMessage,
|
| 11 |
+
DeltaToolCall,
|
| 12 |
+
ExtractedToolCallInformation,
|
| 13 |
+
FunctionCall, ToolCall)
|
| 14 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 15 |
+
ToolParser, ToolParserManager)
|
| 16 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
| 17 |
+
find_common_prefix,
|
| 18 |
+
is_complete_json,
|
| 19 |
+
partial_json_loads)
|
| 20 |
+
from vllm.logger import init_logger
|
| 21 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 22 |
+
from vllm.utils import random_uuid
|
| 23 |
+
|
| 24 |
+
logger = init_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@ToolParserManager.register_module("granite")
class GraniteToolParser(ToolParser):
    """
    Tool call parser for the granite 3.0 models. Intended
    for use with the examples/tool_chat_template_granite.jinja
    template.

    Used when --enable-auto-tool-choice --tool-call-parser granite
    are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # for granite 3.0, the token `<|tool_call|>`
        self.bot_token = "<|tool_call|>"
        # for granite 3.1, the string `<tool_call>`
        self.bot_string = "<tool_call>"

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Parse tool calls from a complete (non-streaming) model response.

        Granite emits tool calls as a single JSON list, optionally preceded
        by a version-dependent marker (`<|tool_call|>` for 3.0,
        `<tool_call>` for 3.1). Output that does not start with such a list
        is returned unchanged as plain content.
        """
        # Strip the (optional) tool-call marker before attempting to parse.
        stripped = model_output.strip()\
            .removeprefix(self.bot_token)\
            .removeprefix(self.bot_string)\
            .lstrip()
        if not stripped or stripped[0] != '[':
            # Not a JSON array of tool calls -> treat as plain content.
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            raw_function_calls = json.loads(stripped)
            # NOTE(review): only a list is accepted here; the error message
            # below says "dict or list" but a dict would also raise.
            if not isinstance(raw_function_calls, list):
                raise Exception(
                    f"Expected dict or list, got {type(raw_function_calls)}")

            logger.debug("Extracted %d tool calls", len(raw_function_calls))
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"]),
                    ),
                ) for function_call in raw_function_calls
            ]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )

        except Exception as e:
            logger.error("Error in extracting tool call from response %s", e)
            # On any parse failure, fall back to returning the whole output
            # as plain content so the client still receives something.
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse tool calls while the model is streaming.

        Re-parses the accumulated `current_text` as partial JSON on every
        chunk and emits only the not-yet-streamed difference (tool name
        first, then argument fragments). Returns None when there is nothing
        new to send for this chunk.
        """
        # Skip leading whitespace and either form of the tool-call marker
        # to find where the JSON array should begin.
        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_token):
            start_idx = consume_space(start_idx + len(self.bot_token),
                                      current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string),
                                      current_text)
        if not current_text or start_idx >= len(current_text)\
                or current_text[start_idx] != '[':
            # No tool-call array (yet) -> pass the chunk through as content.
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = None
            is_complete = None
            try:
                tool_calls, end_idx = partial_json_loads(
                    current_text[start_idx:], flags)
                if type(tool_calls) is list:
                    tool_call_arr = tool_calls
                else:
                    return DeltaMessage(content=delta_text)

                # Mark every parsed call complete, then un-mark the last one
                # if the JSON for it is still being generated.
                is_complete = [True] * len(tool_calls)
                if not is_complete_json(
                        current_text[start_idx:start_idx + end_idx]):
                    is_complete[-1] = False
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if not tool_call_arr:
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            delta = None
            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            if len(tool_call_arr) > self.current_tool_id + 1:

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # JSON for this call is finished: flush everything
                        # that has not been streamed yet.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:

                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only stream the stable common prefix so that
                            # partial-JSON auto-completion artifacts are not
                            # sent prematurely.
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception as e:
            logger.error("Error trying to handle streaming tool call: %s", e)
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, List, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import partial_json_parser
|
| 8 |
+
from partial_json_parser.core.options import Allow
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaFunctionCall, DeltaMessage,
|
| 12 |
+
DeltaToolCall,
|
| 13 |
+
ExtractedToolCallInformation,
|
| 14 |
+
FunctionCall, ToolCall)
|
| 15 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 16 |
+
ToolParser, ToolParserManager)
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
| 19 |
+
from vllm.utils import random_uuid
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
    """
    Tool call parser for Hermes 2 Pro style models, which wrap each tool
    call in <tool_call> ... </tool_call> tags containing a JSON object.

    Used when --enable-auto-tool-choice --tool-call-parser hermes are set.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error(
                "Detected Mistral tokenizer when using a Hermes model")
            # Unwrap to the underlying HF tokenizer so vocab lookups work.
            self.model_tokenizer = self.model_tokenizer.tokenizer

        # Streaming state: which tool we are on, whether its name has been
        # emitted, and how much of each tool's arguments was streamed.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Two alternatives: a closed <tool_call>..</tool_call> pair, or an
        # opening tag followed by the rest of the string (still generating).
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Hermes 2 Pro Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse tool calls from a complete (non-streaming) model response.

        Any text before the first <tool_call> tag is preserved as content;
        each tagged JSON object becomes one ToolCall.
        """

        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:

            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = (
                    self.tool_call_regex.findall(model_output))

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [
                    json.loads(match[0] if match[0] else match[1])
                    for match in function_call_tuples
                ]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"],
                                                 ensure_ascii=False)))
                    for function_call in raw_function_calls
                ]

                content = model_output[:model_output.
                                       find(self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None)

            except Exception:
                logger.exception(
                    "Error in extracting tool call from response.")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse tool calls while the model is streaming.

        Tracks parser position by counting <tool_call>/</tool_call> token
        ids, partially parses the JSON inside the current tag pair, and
        emits only the not-yet-streamed difference. Returns None when this
        chunk adds nothing streamable.
        """

        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        # tool-call start token anywhere in what has been generated so far?
        if self.tool_call_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        try:

            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(
                self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(
                self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count
                    and self.tool_call_end_token not in delta_text):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = full_text.split(
                    self.tool_call_start_token)[-1].split(
                        self.tool_call_end_token)[0].rstrip()
                delta_text = delta_text.split(
                    self.tool_call_end_token)[0].rstrip()
                text_portion = delta_text.split(
                    self.tool_call_end_token)[-1].lstrip()

            # case: if tool open & close tag counts don't match, we're doing
            # imaginary "else" block here
            # something with tools with this diff.
            # flags for partial JSON parting. exported constants from
            # "Allow" are handled via BIT MASK
            flags = Allow.ALL if self.current_tool_name_sent \
                else Allow.ALL & ~Allow.STR

            # case -- we're starting a new tool call
            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):

                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count >= prev_tool_end_count):
                if (self.prev_tool_call_arr is None
                        or len(self.prev_tool_call_arr) == 0):
                    logger.debug(
                        "attempting to close tool call, but no tool call")
                    return None
                diff = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if diff:
                    # NOTE(review): `diff is str` compares identity against
                    # the str *type* and is always False, so the unescape
                    # branch never runs — likely meant isinstance(diff, str).
                    # Left as-is to preserve behavior.
                    diff = diff.encode('utf-8').decode(
                        'unicode_escape') if diff is str else diff
                    if ('"}' not in delta_text):
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s", diff)
                    self.streamed_args_for_tool[self.current_tool_id] \
                        += diff
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            try:

                current_tool_call = partial_json_parser.loads(
                    tool_call_portion or "{}",
                    flags) if tool_call_portion else None
                logger.debug("Parsed tool call %s", current_tool_call)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None
            except json.decoder.JSONDecodeError:
                logger.debug("unable to parse JSON")
                return None

            # case - we haven't sent the tool name yet. If it's available, send
            # it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                if (current_tool_call is None):
                    return None
                function_name: Union[str, None] = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                else:
                    return None
            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = DeltaMessage(content=delta_text) \
                    if text_portion is not None else None
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug("Trying to parse current tool call with ID %s",
                         self.current_tool_id)

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-parsed
            # JSON to the current partially-parsed JSON
            prev_arguments = (
                self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but non are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset "
                             "mid-call. skipping streaming anything.")
                delta = None

            # case -- we now have the first info about arguments available from
            # autocompleting the JSON
            elif cur_arguments and not prev_arguments:

                cur_arguments_json = json.dumps(cur_arguments,
                                                ensure_ascii=False)
                logger.debug("finding %s in %s", delta_text,
                             cur_arguments_json)

                # get the location where previous args differ from current
                if (delta_text not in cur_arguments_json[:-2]):
                    return None
                args_delta_start_loc = cur_arguments_json[:-2]. \
                    rindex(delta_text) + \
                    len(delta_text)

                # use that to find the actual delta
                arguments_delta = cur_arguments_json[:args_delta_start_loc]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += arguments_delta

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:
                # Trim a trailing '}' so the closing brace is only sent once
                # the call is actually finished.
                if isinstance(delta_text, str) and len(delta_text.rstrip(
                )) >= 1 and delta_text.rstrip()[-1] == '}':
                    delta_text = delta_text.rstrip()[:-1]

                logger.debug("got diff %s", delta_text)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=delta_text).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += delta_text

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = \
                    current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import Dict, Sequence, Union
|
| 5 |
+
|
| 6 |
+
import partial_json_parser
|
| 7 |
+
from partial_json_parser.core.options import Allow
|
| 8 |
+
|
| 9 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 10 |
+
DeltaFunctionCall, DeltaMessage,
|
| 11 |
+
DeltaToolCall,
|
| 12 |
+
ExtractedToolCallInformation,
|
| 13 |
+
FunctionCall, ToolCall)
|
| 14 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 15 |
+
ToolParser, ToolParserManager)
|
| 16 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 17 |
+
extract_intermediate_diff)
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 20 |
+
from vllm.utils import random_uuid
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@ToolParserManager.register_module(["internlm"])
|
| 26 |
+
class Internlm2ToolParser(ToolParser):
|
| 27 |
+
|
| 28 |
+
    def __init__(self, tokenizer: AnyTokenizer):
        """Initialize the InternLM2 tool parser."""
        super().__init__(tokenizer)
        # Cursor into current_text during streaming: everything before this
        # offset has already been emitted to the client as plain content.
        self.position = 0
|
| 31 |
+
|
| 32 |
+
def adjust_request(
|
| 33 |
+
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
| 34 |
+
if request.tools and request.tool_choice != 'none':
|
| 35 |
+
# do not skip special tokens because internlm use the special
|
| 36 |
+
# tokens to indicated the start and end of the tool calls
|
| 37 |
+
# information.
|
| 38 |
+
request.skip_special_tokens = False
|
| 39 |
+
return request
|
| 40 |
+
|
| 41 |
+
def get_argments(self, obj):
|
| 42 |
+
if "parameters" in obj:
|
| 43 |
+
return obj.get("parameters")
|
| 44 |
+
elif "arguments" in obj:
|
| 45 |
+
return obj.get("arguments")
|
| 46 |
+
return None
|
| 47 |
+
|
| 48 |
+
def extract_tool_calls_streaming(
|
| 49 |
+
self,
|
| 50 |
+
previous_text: str,
|
| 51 |
+
current_text: str,
|
| 52 |
+
delta_text: str,
|
| 53 |
+
previous_token_ids: Sequence[int],
|
| 54 |
+
current_token_ids: Sequence[int],
|
| 55 |
+
delta_token_ids: Sequence[int],
|
| 56 |
+
request: ChatCompletionRequest,
|
| 57 |
+
) -> Union[DeltaMessage, None]:
|
| 58 |
+
if '<|action_start|>' not in current_text:
|
| 59 |
+
self.position = len(current_text)
|
| 60 |
+
return DeltaMessage(content=delta_text)
|
| 61 |
+
# if the tool call is sended, return a empty delta message
|
| 62 |
+
# to make sure the finish_reason will be send correctly.
|
| 63 |
+
if self.current_tool_id > 0:
|
| 64 |
+
return DeltaMessage(content='')
|
| 65 |
+
|
| 66 |
+
last_pos = self.position
|
| 67 |
+
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
new_delta = current_text[last_pos:]
|
| 71 |
+
text, action = new_delta.split('<|action_start|><|plugin|>')
|
| 72 |
+
|
| 73 |
+
if len(text) > 0:
|
| 74 |
+
self.position = self.position + len(text)
|
| 75 |
+
return DeltaMessage(content=text)
|
| 76 |
+
|
| 77 |
+
action = action.strip()
|
| 78 |
+
action = action.split('<|action_end|>'.strip())[0]
|
| 79 |
+
|
| 80 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 81 |
+
# sent yet, don't allow sending
|
| 82 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 83 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 84 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 85 |
+
else Allow.ALL & ~Allow.STR
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
parsable_arr = action
|
| 89 |
+
|
| 90 |
+
# tool calls are generated in an object in inernlm2
|
| 91 |
+
# it's not support parallel tool calls
|
| 92 |
+
try:
|
| 93 |
+
tool_call_arr: Dict = partial_json_parser.loads(
|
| 94 |
+
parsable_arr, flags)
|
| 95 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 96 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
# if the current tool name hasn't been sent, send if available
|
| 100 |
+
# - otherwise send nothing
|
| 101 |
+
if not self.current_tool_name_sent:
|
| 102 |
+
function_name = tool_call_arr.get("name")
|
| 103 |
+
if function_name:
|
| 104 |
+
self.current_tool_id = self.current_tool_id + 1
|
| 105 |
+
delta = DeltaMessage(tool_calls=[
|
| 106 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 107 |
+
type="function",
|
| 108 |
+
id=f"chatcmpl-tool-{random_uuid()}",
|
| 109 |
+
function=DeltaFunctionCall(
|
| 110 |
+
name=function_name).model_dump(
|
| 111 |
+
exclude_none=True))
|
| 112 |
+
])
|
| 113 |
+
self.current_tool_name_sent = True
|
| 114 |
+
self.streamed_args_for_tool.append("")
|
| 115 |
+
else:
|
| 116 |
+
delta = None
|
| 117 |
+
# now we know we're on the same tool call and we're streaming
|
| 118 |
+
# arguments
|
| 119 |
+
else:
|
| 120 |
+
prev_arguments = self.get_argments(
|
| 121 |
+
self.prev_tool_call_arr[self.current_tool_id])
|
| 122 |
+
cur_arguments = self.get_argments(tool_call_arr)
|
| 123 |
+
|
| 124 |
+
# not arguments generated
|
| 125 |
+
if not cur_arguments and not prev_arguments:
|
| 126 |
+
delta = None
|
| 127 |
+
# will never happen
|
| 128 |
+
elif not cur_arguments and prev_arguments:
|
| 129 |
+
logger.error(
|
| 130 |
+
"INVARIANT - impossible to have arguments reset "
|
| 131 |
+
"mid-arguments")
|
| 132 |
+
delta = None
|
| 133 |
+
# first time to get parameters
|
| 134 |
+
elif cur_arguments and not prev_arguments:
|
| 135 |
+
cur_arguments_json = json.dumps(cur_arguments)
|
| 136 |
+
|
| 137 |
+
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
| 138 |
+
index(delta_text) +
|
| 139 |
+
len(delta_text)]
|
| 140 |
+
delta = DeltaMessage(tool_calls=[
|
| 141 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 142 |
+
function=DeltaFunctionCall(
|
| 143 |
+
arguments=arguments_delta).
|
| 144 |
+
model_dump(exclude_none=True))
|
| 145 |
+
])
|
| 146 |
+
self.streamed_args_for_tool[
|
| 147 |
+
self.current_tool_id] += arguments_delta
|
| 148 |
+
# both prev and cur parameters, send the increase parameters
|
| 149 |
+
elif cur_arguments and prev_arguments:
|
| 150 |
+
cur_args_json = json.dumps(cur_arguments)
|
| 151 |
+
prev_args_json = json.dumps(prev_arguments)
|
| 152 |
+
|
| 153 |
+
argument_diff = extract_intermediate_diff(
|
| 154 |
+
cur_args_json, prev_args_json)
|
| 155 |
+
|
| 156 |
+
delta = DeltaMessage(tool_calls=[
|
| 157 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 158 |
+
function=DeltaFunctionCall(
|
| 159 |
+
arguments=argument_diff).model_dump(
|
| 160 |
+
exclude_none=True))
|
| 161 |
+
])
|
| 162 |
+
self.streamed_args_for_tool[
|
| 163 |
+
self.current_tool_id] += argument_diff
|
| 164 |
+
|
| 165 |
+
# check to see if the name is defined and has been sent. if so,
|
| 166 |
+
# stream the name - otherwise keep waiting
|
| 167 |
+
# finish by setting old and returning None as base case
|
| 168 |
+
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
|
| 169 |
+
self.prev_tool_call_arr = [tool_call_arr]
|
| 170 |
+
return delta
|
| 171 |
+
except Exception:
|
| 172 |
+
logger.exception("Error trying to handle streaming tool call.")
|
| 173 |
+
logger.debug(
|
| 174 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 175 |
+
"error")
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
def extract_tool_calls(
|
| 179 |
+
self,
|
| 180 |
+
model_output: str,
|
| 181 |
+
request: ChatCompletionRequest,
|
| 182 |
+
) -> ExtractedToolCallInformation:
|
| 183 |
+
text = model_output
|
| 184 |
+
tools = request.tools
|
| 185 |
+
if '<|action_start|><|plugin|>' in text:
|
| 186 |
+
text, action = text.split('<|action_start|><|plugin|>')
|
| 187 |
+
action = action.split('<|action_end|>'.strip())[0]
|
| 188 |
+
action = action[action.find('{'):]
|
| 189 |
+
action_dict = json.loads(action)
|
| 190 |
+
name, parameters = action_dict['name'], json.dumps(
|
| 191 |
+
action_dict.get('parameters', action_dict.get('arguments',
|
| 192 |
+
{})))
|
| 193 |
+
|
| 194 |
+
if not tools or name not in [t.function.name for t in tools]:
|
| 195 |
+
ExtractedToolCallInformation(tools_called=False,
|
| 196 |
+
tool_calls=[],
|
| 197 |
+
content=text)
|
| 198 |
+
|
| 199 |
+
tool_calls = [
|
| 200 |
+
ToolCall(
|
| 201 |
+
function=FunctionCall(name=name, arguments=parameters))
|
| 202 |
+
]
|
| 203 |
+
return ExtractedToolCallInformation(
|
| 204 |
+
tools_called=True,
|
| 205 |
+
tool_calls=tool_calls,
|
| 206 |
+
content=text if len(text) > 0 else None)
|
| 207 |
+
|
| 208 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 209 |
+
tool_calls=[],
|
| 210 |
+
content=text)
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import ast
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Sequence, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedTokenizerBase
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaFunctionCall, DeltaMessage,
|
| 12 |
+
DeltaToolCall,
|
| 13 |
+
ExtractedToolCallInformation,
|
| 14 |
+
FunctionCall, ToolCall)
|
| 15 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 16 |
+
ToolParser, ToolParserManager)
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
|
| 19 |
+
logger = init_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class _UnexpectedAstError(Exception):
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@ToolParserManager.register_module("pythonic")
class PythonicToolParser(ToolParser):
    """
    Tool call parser for models that produce tool calls in a pythonic style,
    such as Llama 3.2 models, i.e. a Python-literal list of calls like
    ``[get_weather(city='SF'), get_time(tz='UTC')]``.

    Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
    """
    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    #    This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a non-empty bracketed list of name(kw=value, ...) calls.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL)

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        # Index of the first call not yet fully streamed to the client.
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Any output that does not match the expected list-of-calls shape is
        returned unchanged as plain content.
        """

        if not (self.TOOL_CALL_REGEX.match(model_output)):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # Parse the output as Python and require a single list literal
            # whose elements are all function calls.
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None)
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse the partially-generated call list and emit
        tool-call deltas for the not-yet-streamed portion."""

        if not current_text.startswith("["):
            # Not a tool-call list: pass the text through as content.
            return DeltaMessage(content=delta_text)

        try:
            # Autocomplete the partial Python (closing brackets/quotes) so
            # it can be parsed; added_text records what was appended.
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                # Not enough text yet to form something parseable.
                return None
            valid_text, added_text = valid_and_added_text

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                # Calls before current_tool_index were fully streamed.
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # A call is complete if it is not the last one, or if its
                # closing ")]" came from the model rather than from
                # autocompletion.
                new_call_complete = index < len(
                    tool_calls) - 1 or ")]" not in added_text
                if new_call_complete:
                    self.current_tool_index += 1

                # Withhold autocompleted characters (minus the final ")]")
                # so they are not streamed as if the model produced them.
                withheld_suffix = (added_text[:-2]
                                   if not new_call_complete else "")
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(self.streamed_args_for_tool[index],
                                            new_call, index, withheld_suffix)

                if delta is not None:
                    tool_deltas.append(delta)
                    if (delta.function is not None
                            and delta.function.arguments is not None):
                        # Track what has been streamed for diffing next time.
                        self.streamed_args_for_tool[
                            index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content='')
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _get_parameter_value(val: ast.expr) -> Any:
|
| 179 |
+
if isinstance(val, ast.Constant):
|
| 180 |
+
return val.value
|
| 181 |
+
elif isinstance(val, ast.Dict):
|
| 182 |
+
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
| 183 |
+
raise _UnexpectedAstError(
|
| 184 |
+
"Dict tool call arguments must have literal keys")
|
| 185 |
+
return {
|
| 186 |
+
k.value: _get_parameter_value(v) # type: ignore
|
| 187 |
+
for k, v in zip(val.keys, val.values)
|
| 188 |
+
}
|
| 189 |
+
elif isinstance(val, ast.List):
|
| 190 |
+
return [_get_parameter_value(v) for v in val.elts]
|
| 191 |
+
else:
|
| 192 |
+
raise _UnexpectedAstError("Tool call arguments must be literals")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert one ast.Call node into a ToolCall with JSON-encoded args."""
    func = call.func
    if not isinstance(func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    kwargs = {kw.arg: _get_parameter_value(kw.value) for kw in call.keywords}
    return ToolCall(type="function",
                    function=FunctionCall(name=func.id,
                                          arguments=json.dumps(kwargs)))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _make_valid_python(text: str) -> Union[Tuple[str, str], None]:
|
| 208 |
+
bracket_stack = []
|
| 209 |
+
for index, char in enumerate(text):
|
| 210 |
+
if char in {"[", "(", "{"}:
|
| 211 |
+
bracket_stack.append(char)
|
| 212 |
+
elif char == "]":
|
| 213 |
+
if not bracket_stack or bracket_stack.pop() != "[":
|
| 214 |
+
raise _UnexpectedAstError("Mismatched square brackets")
|
| 215 |
+
elif char == ")":
|
| 216 |
+
if not bracket_stack or bracket_stack.pop() != "(":
|
| 217 |
+
raise _UnexpectedAstError("Mismatched parentheses")
|
| 218 |
+
elif char == "}":
|
| 219 |
+
if not bracket_stack or bracket_stack.pop() != "{":
|
| 220 |
+
raise _UnexpectedAstError("Mismatched curly braces")
|
| 221 |
+
elif char in {"'", '"'}:
|
| 222 |
+
if bracket_stack and bracket_stack[-1] == char:
|
| 223 |
+
if index > 0 and text[index - 1] == "\\":
|
| 224 |
+
# Treat an escaped quote as a regular character
|
| 225 |
+
pass
|
| 226 |
+
else:
|
| 227 |
+
bracket_stack.pop()
|
| 228 |
+
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
| 229 |
+
# Double quote within a single quote string or vice versa.
|
| 230 |
+
pass
|
| 231 |
+
else:
|
| 232 |
+
bracket_stack.append(char)
|
| 233 |
+
|
| 234 |
+
text = text.rstrip()
|
| 235 |
+
if text.endswith("=") or text.endswith(":"):
|
| 236 |
+
# Since we have no type information for this property/parameter value,
|
| 237 |
+
# we can't fill in a valid value.
|
| 238 |
+
return None
|
| 239 |
+
if bracket_stack and bracket_stack[-1] == "{":
|
| 240 |
+
trailing_dict_text = text[:text.rfind("{")]
|
| 241 |
+
num_keys = trailing_dict_text.count(":")
|
| 242 |
+
num_values = trailing_dict_text.count(",")
|
| 243 |
+
if num_keys <= num_values:
|
| 244 |
+
return None # Incomplete property name within parameter value
|
| 245 |
+
if bracket_stack and bracket_stack[-1] == "(":
|
| 246 |
+
trailing_params_text = text[:text.rfind("(")]
|
| 247 |
+
num_full_param_names = trailing_params_text.count("=")
|
| 248 |
+
num_full_param_values = trailing_params_text.count(",")
|
| 249 |
+
if num_full_param_names <= num_full_param_values:
|
| 250 |
+
return None # Incomplete parameter name
|
| 251 |
+
if text.endswith(","):
|
| 252 |
+
text = text[:-1]
|
| 253 |
+
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
|
| 254 |
+
"[") and not text.endswith(")"):
|
| 255 |
+
return None # Incomplete function name
|
| 256 |
+
|
| 257 |
+
added_text = ""
|
| 258 |
+
for char in reversed(bracket_stack):
|
| 259 |
+
if char == "[":
|
| 260 |
+
added_text += "]"
|
| 261 |
+
elif char == "(":
|
| 262 |
+
added_text += ")"
|
| 263 |
+
elif char == "{":
|
| 264 |
+
added_text += "}"
|
| 265 |
+
elif char == "'":
|
| 266 |
+
added_text += "'"
|
| 267 |
+
elif char == '"':
|
| 268 |
+
added_text += '"'
|
| 269 |
+
|
| 270 |
+
return text + added_text, added_text
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
                        index: int,
                        withheld_suffix: str) -> Union[DeltaToolCall, None]:
    """Build the incremental DeltaToolCall for one parsed call.

    Returns None when there is nothing new to stream for this call.
    """
    args = new_call.function.arguments
    if withheld_suffix:
        # Trim autocompleted characters that must not be streamed yet.
        assert args.endswith(withheld_suffix)
        args = args[:-len(withheld_suffix)]
    if not previously_sent_args:
        # First chunk for this call: include its id and function name.
        return DeltaToolCall(id=new_call.id,
                             index=index,
                             function=DeltaFunctionCall(
                                 name=new_call.function.name,
                                 arguments=args,
                             ))

    arg_diff = args[len(previously_sent_args):]
    if not arg_diff:
        return None
    return DeltaToolCall(id="",
                         index=index,
                         function=DeltaFunctionCall(arguments=arg_diff))
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
| 4 |
+
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"PunicaWrapperBase",
|
| 8 |
+
"get_punica_wrapper",
|
| 9 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (431 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc
ADDED
|
Binary file (4.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc
ADDED
|
Binary file (1.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (7.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Based on:
|
| 4 |
+
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
| 5 |
+
Punica: Multi-Tenant LoRA Serving.
|
| 6 |
+
https://arxiv.org/abs/2310.18547
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from .utils import compute_meta, convert_mapping
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
# avoid circuit import
|
| 18 |
+
from vllm.lora.layers import LoRAMapping
|
| 19 |
+
from vllm.lora.models import LongContextLoRAContext
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PunicaWrapperABC(ABC):
    """
    PunicaWrapper ABC.

    Interface every Punica (multi-LoRA) backend wrapper must implement:
    metadata updates plus the shrink/expand GEMM entry points used by the
    LoRA layers.
    """

    @abstractmethod
    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
        **kwargs,
    ) -> None:
        """
        Update the lora-related metadata (index tensors and batch info)
        from the given token-to-LoRA mapping.
        """
        raise NotImplementedError

    @abstractmethod
    def add_shrink(
        self,
        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: Tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        """
        Performs GEMM for multiple slices of lora_a.
        Results are accumulated into y, scaled by `scale`.
        """

        raise NotImplementedError

    @abstractmethod
    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: Tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
        output_slices: Tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
        `output_slices` gives the width of each slice written into y.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA,
        and this layer only requires the expand operation.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.
        `buffer`, when given, holds intermediate shrink results.
        """

        raise NotImplementedError

    @abstractmethod
    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
        """
        raise NotImplementedError
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class PunicaWrapperBase(PunicaWrapperABC):
|
| 125 |
+
"""
|
| 126 |
+
PunicaWrapperBase is designed to manage and provide metadata for the punica
|
| 127 |
+
kernel. The main function is to maintain the state information for
|
| 128 |
+
Multi-LoRA, and to provide the interface for the punica.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        """Preallocate worst-case index buffers; only a valid prefix of
        each (tracked via self.indices_len) is used per batch."""
        # Per-token LoRA index for the base layers.
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        # LoRA indices used at the sampler stage.
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        # Row 0/1 layout filled by convert_mapping — see
        # _update_base_metadata.
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors.
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices,long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5
        # these attributes are the information required for sgmv kernel
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.device: torch.device = device
        # Batch-level metadata; refreshed by _update_prefill_metada.
        self.max_length: int = 0
        self.token_nums: int = 0
        # -1 means "no batch processed yet".
        self.batch_size: int = -1
        self.is_prefill = False
        # True when no sequence in the current batch uses a LoRA adapter.
        self.no_lora = False
|
| 170 |
+
|
| 171 |
+
    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        """Convert the LoRA mapping into the index tensors the punica
        kernels consume, copying them into the preallocated buffers
        in place (only the leading `indices_len` entries are valid)."""
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            self.device,
            long_lora_context,
        )
        # In-place prefix copies avoid reallocating the buffers each step.
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            # No long-context LoRA in this batch: clear stale offsets.
            self._long_lora_indices.zero_()
        # Record how many leading entries of each buffer are valid.
        self.indices_len[:] = indices_len
|
| 209 |
+
|
| 210 |
+
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
|
| 211 |
+
|
| 212 |
+
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
|
| 213 |
+
batch_size, max_length, token_nums,
|
| 214 |
+
no_lora) = compute_meta(token_lora_tensor)
|
| 215 |
+
|
| 216 |
+
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
|
| 217 |
+
b_seq_start_tensor)
|
| 218 |
+
self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
|
| 219 |
+
self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
|
| 220 |
+
lora_indices_tensor)
|
| 221 |
+
self.batch_size = batch_size
|
| 222 |
+
self.max_length = max_length
|
| 223 |
+
self.token_nums = token_nums
|
| 224 |
+
self.no_lora = no_lora
|
| 225 |
+
|
| 226 |
+
def _apply_bias(
|
| 227 |
+
self,
|
| 228 |
+
indices: torch.Tensor,
|
| 229 |
+
output: torch.Tensor,
|
| 230 |
+
output_slices: Tuple[int, ...],
|
| 231 |
+
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
|
| 232 |
+
):
|
| 233 |
+
"""Applies bias to output
|
| 234 |
+
|
| 235 |
+
Input shapes:
|
| 236 |
+
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
|
| 237 |
+
indices: (batch_size)
|
| 238 |
+
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
| 239 |
+
output_slices: n-1 element tuple of (slice_size...),
|
| 240 |
+
where n is number of slices
|
| 241 |
+
"""
|
| 242 |
+
org_output = output
|
| 243 |
+
output = output.view(-1, output.shape[-1])
|
| 244 |
+
indices = indices.view(-1)
|
| 245 |
+
|
| 246 |
+
offset_left = 0
|
| 247 |
+
for slice_idx, slice in enumerate(output_slices):
|
| 248 |
+
bias = lora_bias_stacked[slice_idx]
|
| 249 |
+
if bias is not None:
|
| 250 |
+
bias = bias.view(-1, bias.shape[-1])
|
| 251 |
+
bias = bias[indices]
|
| 252 |
+
bias[indices == -1] = 0
|
| 253 |
+
output[:, offset_left:offset_left + slice] += bias
|
| 254 |
+
offset_left += slice
|
| 255 |
+
|
| 256 |
+
return output.view_as(org_output)
|
| 257 |
+
|
| 258 |
+
@property
|
| 259 |
+
def prefill_metadata(
|
| 260 |
+
self
|
| 261 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
| 262 |
+
"""
|
| 263 |
+
This property provides a convenient way to access the necessary
|
| 264 |
+
metadata for prefill-related kernel computations.
|
| 265 |
+
1. seq_start_locs: Tensor of sequence start positions.
|
| 266 |
+
2. seq_lengths: Tensor of sequence lengths.
|
| 267 |
+
3. lora_indices_per_batch: Tensor of lora indices, and an index of
|
| 268 |
+
-1 means no lora should be applied.
|
| 269 |
+
4. batch_size: Batch size after clustering identical lora indices.
|
| 270 |
+
5. max_length: The maximum sequence length in the batch.
|
| 271 |
+
6. token_nums: The token numbers in the batch.
|
| 272 |
+
"""
|
| 273 |
+
return (self._seq_start_locs[:self.batch_size],
|
| 274 |
+
self._seq_lengths[:self.batch_size],
|
| 275 |
+
self._lora_indices_per_batch[:self.batch_size],
|
| 276 |
+
self.batch_size, self.max_length, self.token_nums)
|
| 277 |
+
|
| 278 |
+
@property
|
| 279 |
+
def token_lora_indices(self) -> torch.Tensor:
|
| 280 |
+
"""
|
| 281 |
+
This property provides the lora indices corresponding to each token
|
| 282 |
+
in the batch. An index of -1 means no lora should be applied.
|
| 283 |
+
"""
|
| 284 |
+
token_lora_len = self.indices_len[0]
|
| 285 |
+
return self._token_lora_indices[:token_lora_len]
|
| 286 |
+
|
| 287 |
+
@property
|
| 288 |
+
def sampler_indices(self) -> torch.Tensor:
|
| 289 |
+
"""
|
| 290 |
+
This property is used to access the lora indices specifically for
|
| 291 |
+
LogitsProcessorWithLoRA.
|
| 292 |
+
"""
|
| 293 |
+
sampler_indices_len = self.indices_len[1]
|
| 294 |
+
return self._sampler_indices[:sampler_indices_len]
|
| 295 |
+
|
| 296 |
+
@property
|
| 297 |
+
def sampler_indices_padded(self) -> torch.Tensor:
|
| 298 |
+
"""
|
| 299 |
+
This property provides access to padded sampler indices.
|
| 300 |
+
"""
|
| 301 |
+
indices_padded_len = self.indices_len[2]
|
| 302 |
+
return self._sampler_indices_padded[:indices_padded_len]
|
| 303 |
+
|
| 304 |
+
@property
|
| 305 |
+
def embeddings_indices(self) -> torch.Tensor:
|
| 306 |
+
"""
|
| 307 |
+
This property provides access to the indices used for lora embeddings,
|
| 308 |
+
specifically for VocabParallelEmbeddingWithLoRA.
|
| 309 |
+
"""
|
| 310 |
+
embeddings_indices_len = self.indices_len[3]
|
| 311 |
+
return self._embeddings_indices[:, :embeddings_indices_len]
|
| 312 |
+
|
| 313 |
+
@property
|
| 314 |
+
def long_lora_indices(self) -> torch.Tensor:
|
| 315 |
+
"""
|
| 316 |
+
This property provides access to the indices used for long context
|
| 317 |
+
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
|
| 318 |
+
"""
|
| 319 |
+
long_lora_len = self.indices_len[4]
|
| 320 |
+
return self._long_lora_indices[:long_lora_len]
|
| 321 |
+
|
| 322 |
+
def update_metadata(
|
| 323 |
+
self,
|
| 324 |
+
mapping: "LoRAMapping",
|
| 325 |
+
lora_index_to_id: List[Optional[int]],
|
| 326 |
+
max_loras: int,
|
| 327 |
+
vocab_size: int,
|
| 328 |
+
extra_vocab_size: int,
|
| 329 |
+
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
| 330 |
+
**kwargs):
|
| 331 |
+
|
| 332 |
+
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
|
| 333 |
+
vocab_size, extra_vocab_size,
|
| 334 |
+
long_lora_context)
|
| 335 |
+
if mapping.is_prefill:
|
| 336 |
+
# Update metadata required for prefill-related operators.
|
| 337 |
+
self._update_prefill_metada(self.token_lora_indices)
|
| 338 |
+
self.is_prefill = True
|
| 339 |
+
else:
|
| 340 |
+
self.is_prefill = False
|
| 341 |
+
|
| 342 |
+
@abstractmethod
|
| 343 |
+
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 344 |
+
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 345 |
+
scale: float, **kwargs) -> None:
|
| 346 |
+
"""
|
| 347 |
+
Performs GEMM for multiple slices of lora_a.
|
| 348 |
+
|
| 349 |
+
Semantics:
|
| 350 |
+
for i in range(len(lora_a_stacked)):
|
| 351 |
+
y[i] += (x @ lora_a_stacked[i]) * scale
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
| 355 |
+
x (torch.Tensor): Input tensor
|
| 356 |
+
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
| 357 |
+
scale (float): Scaling factor for the operation
|
| 358 |
+
|
| 359 |
+
"""
|
| 360 |
+
# TODO: implement it based on torch ops
|
| 361 |
+
raise NotImplementedError
|
| 362 |
+
|
| 363 |
+
@abstractmethod
|
| 364 |
+
def add_expand(self,
|
| 365 |
+
y: torch.Tensor,
|
| 366 |
+
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 367 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 368 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 369 |
+
output_slices: Tuple[int, ...],
|
| 370 |
+
offset_start: int = 0,
|
| 371 |
+
add_inputs=True,
|
| 372 |
+
**kwargs) -> None:
|
| 373 |
+
"""
|
| 374 |
+
Performs GEMM and bias addition for multiple slices of lora_b.
|
| 375 |
+
|
| 376 |
+
Semantics:
|
| 377 |
+
offset = offset_start
|
| 378 |
+
for i in range(len(lora_b_stacked)):
|
| 379 |
+
slice = output_slices[i]
|
| 380 |
+
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
| 381 |
+
lora_bias_stacked[i]
|
| 382 |
+
offset += slice
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
y (torch.Tensor): Output tensor.
|
| 386 |
+
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
| 387 |
+
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
| 388 |
+
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
| 389 |
+
bias's weight
|
| 390 |
+
output_slices (Tuple[int, ...]): Every slice's size
|
| 391 |
+
offset_start (int): The starting position of y, defaults to 0
|
| 392 |
+
add_inputs (bool): Defaults to True.
|
| 393 |
+
|
| 394 |
+
"""
|
| 395 |
+
# TODO: implement it based on torch ops
|
| 396 |
+
raise NotImplementedError
|
| 397 |
+
|
| 398 |
+
@abstractmethod
|
| 399 |
+
def add_lora_embedding(self,
|
| 400 |
+
y: torch.Tensor,
|
| 401 |
+
x: torch.Tensor,
|
| 402 |
+
lora_b_stacked: torch.Tensor,
|
| 403 |
+
add_inputs: bool = True,
|
| 404 |
+
**kwargs) -> None:
|
| 405 |
+
"""
|
| 406 |
+
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
| 407 |
+
and this layer only requires the expand operation.
|
| 408 |
+
Semantics:
|
| 409 |
+
y += x @ lora_b_stacked
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
y (torch.Tensor): Output tensor.
|
| 413 |
+
x (torch.Tensor): Input tensor.
|
| 414 |
+
lora_b_stacked (torch.Tensor): lora_b's weights.
|
| 415 |
+
add_inputs (bool): Default to True.
|
| 416 |
+
"""
|
| 417 |
+
# TODO: implement it based on torch ops
|
| 418 |
+
raise NotImplementedError
|
| 419 |
+
|
| 420 |
+
@abstractmethod
|
| 421 |
+
def add_lora_linear(self,
|
| 422 |
+
y: torch.Tensor,
|
| 423 |
+
x: torch.Tensor,
|
| 424 |
+
lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 425 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 426 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 427 |
+
scale: float,
|
| 428 |
+
output_slices: Tuple[int, ...],
|
| 429 |
+
*,
|
| 430 |
+
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
| 431 |
+
**kwargs) -> None:
|
| 432 |
+
"""
|
| 433 |
+
Applicable to linear-related lora.
|
| 434 |
+
|
| 435 |
+
Semantics:
|
| 436 |
+
for i in range(len(lora_a_stacked)):
|
| 437 |
+
y[i] += (
|
| 438 |
+
x[i].unsqueeze(0)
|
| 439 |
+
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
| 440 |
+
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
| 441 |
+
* scale
|
| 442 |
+
).squeeze(0)+lora_bias_stacked[i]
|
| 443 |
+
|
| 444 |
+
Args:
|
| 445 |
+
y (torch.Tensor): Output tensor. Will be changed in-place.
|
| 446 |
+
x (torch.Tensor): Input tensor
|
| 447 |
+
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
| 448 |
+
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
| 449 |
+
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
| 450 |
+
scale (float): Scaling factor.
|
| 451 |
+
output_slices (Tuple[int, ...]): Every slice's size.
|
| 452 |
+
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
| 453 |
+
"""
|
| 454 |
+
# TODO: implement it based on torch ops
|
| 455 |
+
raise NotImplementedError
|
| 456 |
+
|
| 457 |
+
@abstractmethod
|
| 458 |
+
def add_lora_logits(self,
|
| 459 |
+
y: torch.Tensor,
|
| 460 |
+
x: torch.Tensor,
|
| 461 |
+
lora_a_stacked: torch.Tensor,
|
| 462 |
+
lora_b_stacked: torch.Tensor,
|
| 463 |
+
scale,
|
| 464 |
+
*,
|
| 465 |
+
buffer: Optional[torch.Tensor] = None,
|
| 466 |
+
**kwargs) -> None:
|
| 467 |
+
"""
|
| 468 |
+
Applies lora specifically for LogitsProcessorWithLoRA.
|
| 469 |
+
|
| 470 |
+
Semantics:
|
| 471 |
+
buffer = (x @ lora_a_stacked) * scale
|
| 472 |
+
y += buffer @ lora_b_stacked
|
| 473 |
+
|
| 474 |
+
Args:
|
| 475 |
+
y (torch.Tensor): Output tensor.
|
| 476 |
+
x (torch.Tensor): Input tensor.
|
| 477 |
+
lora_a_stacked (torch.Tensor): lora_a's weights.
|
| 478 |
+
lora_b_stacked (torch.Tensor):lora_b's weights.
|
| 479 |
+
scale (float): Scaling factor.
|
| 480 |
+
buffer (Optional[torch.Tensor]):Default to None.
|
| 481 |
+
"""
|
| 482 |
+
# TODO: implement it based on torch ops
|
| 483 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Callable, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
| 8 |
+
bgmv_shrink, sgmv_expand,
|
| 9 |
+
sgmv_expand_slice, sgmv_shrink)
|
| 10 |
+
|
| 11 |
+
from .punica_base import PunicaWrapperBase
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# The platforms that are compatible with the PyTorch-native implementation can
|
| 15 |
+
# inherit this class
|
| 16 |
+
class PunicaWrapperCPU(PunicaWrapperBase):
|
| 17 |
+
"""
|
| 18 |
+
PunicaWrapperCPU is designed to manage and provide metadata for the punica
|
| 19 |
+
kernel. The main function is to maintain the state information for
|
| 20 |
+
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
| 24 |
+
device: Union[torch.device, str], **kwargs):
|
| 25 |
+
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
| 26 |
+
device)
|
| 27 |
+
|
| 28 |
+
def _shrink_prefill(
|
| 29 |
+
self,
|
| 30 |
+
y: torch.Tensor,
|
| 31 |
+
x: torch.Tensor,
|
| 32 |
+
w_t_all: torch.Tensor,
|
| 33 |
+
scale: float,
|
| 34 |
+
):
|
| 35 |
+
#No LoRA request, so return directly
|
| 36 |
+
if self.no_lora:
|
| 37 |
+
return
|
| 38 |
+
sgmv_shrink(
|
| 39 |
+
x,
|
| 40 |
+
w_t_all,
|
| 41 |
+
y,
|
| 42 |
+
*self.prefill_metadata,
|
| 43 |
+
scale,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def _shrink_decode(
|
| 47 |
+
self,
|
| 48 |
+
y: torch.Tensor,
|
| 49 |
+
x: torch.Tensor,
|
| 50 |
+
w_t_all: torch.Tensor,
|
| 51 |
+
scale: float,
|
| 52 |
+
):
|
| 53 |
+
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
|
| 54 |
+
|
| 55 |
+
def _expand_prefill(
|
| 56 |
+
self,
|
| 57 |
+
y: torch.Tensor,
|
| 58 |
+
x: torch.Tensor,
|
| 59 |
+
w_t_all: torch.Tensor,
|
| 60 |
+
add_inputs: bool,
|
| 61 |
+
):
|
| 62 |
+
#No LoRA request, so return directly
|
| 63 |
+
if self.no_lora:
|
| 64 |
+
return
|
| 65 |
+
sgmv_expand(
|
| 66 |
+
x,
|
| 67 |
+
w_t_all,
|
| 68 |
+
y,
|
| 69 |
+
*self.prefill_metadata,
|
| 70 |
+
add_inputs,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def _expand_decode(
|
| 74 |
+
self,
|
| 75 |
+
y: torch.Tensor,
|
| 76 |
+
x: torch.Tensor,
|
| 77 |
+
w_t_all: torch.Tensor,
|
| 78 |
+
add_inputs: bool,
|
| 79 |
+
):
|
| 80 |
+
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
|
| 81 |
+
|
| 82 |
+
def _expand_slice_prefill(
|
| 83 |
+
self,
|
| 84 |
+
y: torch.Tensor,
|
| 85 |
+
x: torch.Tensor,
|
| 86 |
+
w_t_all: torch.Tensor,
|
| 87 |
+
y_offset: int,
|
| 88 |
+
y_slice_size: int,
|
| 89 |
+
add_inputs: bool,
|
| 90 |
+
):
|
| 91 |
+
#No LoRA request, so return directly
|
| 92 |
+
if self.no_lora:
|
| 93 |
+
return
|
| 94 |
+
sgmv_expand_slice(
|
| 95 |
+
x,
|
| 96 |
+
w_t_all,
|
| 97 |
+
y,
|
| 98 |
+
*self.prefill_metadata,
|
| 99 |
+
y_offset,
|
| 100 |
+
y_slice_size,
|
| 101 |
+
add_inputs,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def _expand_slice_decode(
|
| 105 |
+
self,
|
| 106 |
+
y: torch.Tensor,
|
| 107 |
+
x: torch.Tensor,
|
| 108 |
+
w_t_all: torch.Tensor,
|
| 109 |
+
y_offset: int,
|
| 110 |
+
y_slice_size: int,
|
| 111 |
+
add_inputs: bool,
|
| 112 |
+
):
|
| 113 |
+
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
|
| 114 |
+
y_slice_size, add_inputs)
|
| 115 |
+
|
| 116 |
+
def _apply_expand(
|
| 117 |
+
self,
|
| 118 |
+
y: torch.Tensor,
|
| 119 |
+
x: torch.Tensor,
|
| 120 |
+
w_t_all: torch.Tensor,
|
| 121 |
+
y_offset: int,
|
| 122 |
+
y_slice_size: int,
|
| 123 |
+
add_inputs: bool = True,
|
| 124 |
+
):
|
| 125 |
+
"""
|
| 126 |
+
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
| 127 |
+
computation, which is suitable for the
|
| 128 |
+
GEMM of lora'b.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
expand_slice_fun: Callable = (self._expand_slice_prefill
|
| 132 |
+
if self.is_prefill else
|
| 133 |
+
self._expand_slice_decode)
|
| 134 |
+
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
| 135 |
+
|
| 136 |
+
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
| 137 |
+
w_t_all: torch.Tensor, scale: float):
|
| 138 |
+
"""
|
| 139 |
+
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
| 140 |
+
GEMM of lora'a.
|
| 141 |
+
When `is_prefill is` true, it indicates that it is currently the
|
| 142 |
+
prefill stage, and the `_shrink_prefill` function should be called.
|
| 143 |
+
Otherwise, it is the decode stage, and the _shrink_decode function
|
| 144 |
+
should be called.
|
| 145 |
+
"""
|
| 146 |
+
y_org = y
|
| 147 |
+
y = y.view(-1, y.shape[-1])
|
| 148 |
+
shrink_fun: Callable = (self._shrink_prefill
|
| 149 |
+
if self.is_prefill else self._shrink_decode)
|
| 150 |
+
shrink_fun(y, x, w_t_all, scale)
|
| 151 |
+
y = y.view_as(y_org)
|
| 152 |
+
|
| 153 |
+
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 154 |
+
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 155 |
+
scale: float, **kwargs):
|
| 156 |
+
"""
|
| 157 |
+
Performs GEMM for multiple slices of lora_a.
|
| 158 |
+
When `is_prefill is` true, it indicates that it is currently the
|
| 159 |
+
prefill stage, and the `_shrink_prefill` function should be called.
|
| 160 |
+
Otherwise, it is the decode stage, and the _shrink_decode function
|
| 161 |
+
should be called.
|
| 162 |
+
|
| 163 |
+
Semantics:
|
| 164 |
+
for i in range(len(lora_a_stacked)):
|
| 165 |
+
y[i] += (x @ lora_a_stacked[i]) * scale
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
| 169 |
+
x (torch.Tensor): Input tensor
|
| 170 |
+
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
| 171 |
+
scale (float): Scaling factor for the operation
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
x = x.view(-1, x.shape[-1])
|
| 175 |
+
# TODO fuse these kernels
|
| 176 |
+
for slice_idx in range(len(lora_a_stacked)):
|
| 177 |
+
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
|
| 178 |
+
scale)
|
| 179 |
+
|
| 180 |
+
def add_expand(self,
|
| 181 |
+
y: torch.Tensor,
|
| 182 |
+
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 183 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 184 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 185 |
+
output_slices: Tuple[int, ...],
|
| 186 |
+
offset_start: int = 0,
|
| 187 |
+
add_inputs=True,
|
| 188 |
+
**kwargs) -> None:
|
| 189 |
+
"""
|
| 190 |
+
Performs GEMM and bias addition for multiple slices of lora_b.
|
| 191 |
+
|
| 192 |
+
Semantics:
|
| 193 |
+
for i in range(len(lora_b_stacked)):
|
| 194 |
+
slice = output_slices[i]
|
| 195 |
+
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
| 196 |
+
lora_bias_stacked[i]
|
| 197 |
+
offset += slice
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
y (torch.Tensor): Output tensor.
|
| 201 |
+
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
| 202 |
+
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
| 203 |
+
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
| 204 |
+
bias's weight
|
| 205 |
+
output_slices (Tuple[int, ...]): Every slice's size
|
| 206 |
+
add_inputs (bool): Defaults to True.
|
| 207 |
+
"""
|
| 208 |
+
y_org = y
|
| 209 |
+
y = y.view(-1, y.shape[-1])
|
| 210 |
+
offset_left = offset_start
|
| 211 |
+
if lora_bias_stacked is not None:
|
| 212 |
+
self._apply_bias(self.token_lora_indices, y, output_slices,
|
| 213 |
+
lora_bias_stacked)
|
| 214 |
+
for slice_idx in range(len(lora_b_stacked)):
|
| 215 |
+
self._apply_expand(
|
| 216 |
+
y,
|
| 217 |
+
x[slice_idx],
|
| 218 |
+
lora_b_stacked[slice_idx],
|
| 219 |
+
offset_left,
|
| 220 |
+
output_slices[slice_idx],
|
| 221 |
+
add_inputs=add_inputs,
|
| 222 |
+
)
|
| 223 |
+
offset_left += output_slices[slice_idx]
|
| 224 |
+
y = y.view_as(y_org)
|
| 225 |
+
|
| 226 |
+
def add_lora_embedding(self,
|
| 227 |
+
y: torch.Tensor,
|
| 228 |
+
x: torch.Tensor,
|
| 229 |
+
lora_b_stacked: torch.Tensor,
|
| 230 |
+
add_inputs: bool = True,
|
| 231 |
+
**kwargs) -> None:
|
| 232 |
+
"""
|
| 233 |
+
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
| 234 |
+
|
| 235 |
+
Semantics:
|
| 236 |
+
y += x @ lora_b_stacked
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
y (torch.Tensor): Output tensor.
|
| 240 |
+
x (torch.Tensor): Input tensor.
|
| 241 |
+
lora_b_stacked (torch.Tensor): lora_b's weights.
|
| 242 |
+
add_inputs (bool): Default to True.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
# Embedding layer only need expand op
|
| 246 |
+
expand_fun: Callable = (self._expand_prefill
|
| 247 |
+
if self.is_prefill else self._expand_decode)
|
| 248 |
+
expand_fun(y, x, lora_b_stacked, add_inputs)
|
| 249 |
+
|
| 250 |
+
def add_lora_linear(self,
|
| 251 |
+
y: torch.Tensor,
|
| 252 |
+
x: torch.Tensor,
|
| 253 |
+
lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 254 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 255 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 256 |
+
scale: float,
|
| 257 |
+
output_slices: Tuple[int, ...],
|
| 258 |
+
*,
|
| 259 |
+
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
| 260 |
+
**kwargs) -> None:
|
| 261 |
+
"""
|
| 262 |
+
Applicable to linear-related lora.
|
| 263 |
+
|
| 264 |
+
Semantics:
|
| 265 |
+
for i in range(len(lora_a_stacked)):
|
| 266 |
+
y[i] += (
|
| 267 |
+
x[i].unsqueeze(0)
|
| 268 |
+
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
| 269 |
+
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
| 270 |
+
* scale
|
| 271 |
+
).squeeze(0)+lora_bias_stacked[i]
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
y (torch.Tensor): Output tensor. Will be changed in-place.
|
| 275 |
+
x (torch.Tensor): Input tensor
|
| 276 |
+
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
| 277 |
+
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
| 278 |
+
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
| 279 |
+
scale (float): Scaling factor.
|
| 280 |
+
output_slices (Tuple[int, ...]): Every slice's size.
|
| 281 |
+
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
| 285 |
+
if lora_bias_stacked is not None:
|
| 286 |
+
assert len(lora_bias_stacked) == len(output_slices)
|
| 287 |
+
y = self._apply_bias(self.token_lora_indices, y, output_slices,
|
| 288 |
+
lora_bias_stacked)
|
| 289 |
+
|
| 290 |
+
if buffer is None:
|
| 291 |
+
r = lora_b_stacked[0].size(-1)
|
| 292 |
+
# We set the buffer to be float32 by default, consistent with the
|
| 293 |
+
# triton op
|
| 294 |
+
buffer = tuple(
|
| 295 |
+
torch.zeros(
|
| 296 |
+
(x.size(0), r), dtype=torch.float32, device=x.device)
|
| 297 |
+
for _ in range(len(output_slices)))
|
| 298 |
+
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
| 299 |
+
self.add_expand(y,
|
| 300 |
+
buffer,
|
| 301 |
+
lora_b_stacked,
|
| 302 |
+
None,
|
| 303 |
+
output_slices,
|
| 304 |
+
add_inputs=True,
|
| 305 |
+
**kwargs)
|
| 306 |
+
|
| 307 |
+
def add_lora_logits(self,
|
| 308 |
+
y: torch.Tensor,
|
| 309 |
+
x: torch.Tensor,
|
| 310 |
+
lora_a_stacked: torch.Tensor,
|
| 311 |
+
lora_b_stacked: torch.Tensor,
|
| 312 |
+
scale,
|
| 313 |
+
*,
|
| 314 |
+
buffer: Optional[torch.Tensor] = None,
|
| 315 |
+
**kwargs) -> None:
|
| 316 |
+
"""
|
| 317 |
+
Applies lora specifically for LogitsProcessorWithLoRA.
|
| 318 |
+
|
| 319 |
+
Semantics:
|
| 320 |
+
buffer = (x @ lora_a_stacked) * scale
|
| 321 |
+
y += buffer @ lora_b_stacked
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
y (torch.Tensor): Output tensor.
|
| 325 |
+
x (torch.Tensor): Input tensor.
|
| 326 |
+
lora_a_stacked (torch.Tensor): lora_a's weights.
|
| 327 |
+
lora_b_stacked (torch.Tensor):lora_b's weights.
|
| 328 |
+
scale (float): Scaling factor.
|
| 329 |
+
buffer (Optional[torch.Tensor]):Default to None.
|
| 330 |
+
"""
|
| 331 |
+
y_org = y
|
| 332 |
+
y = y.view(-1, y.shape[-1])
|
| 333 |
+
x = x.view(-1, x.shape[-1])
|
| 334 |
+
r = lora_b_stacked.size(-1)
|
| 335 |
+
if buffer is None:
|
| 336 |
+
# We set the buffer to be float32 by default, consistent with the
|
| 337 |
+
# triton op
|
| 338 |
+
buffer = torch.zeros((x.size(0), r),
|
| 339 |
+
dtype=torch.float32,
|
| 340 |
+
device=x.device)
|
| 341 |
+
# LogitsProcessorWithLoRA always using bgmv.
|
| 342 |
+
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
|
| 343 |
+
bgmv_expand(buffer,
|
| 344 |
+
lora_b_stacked,
|
| 345 |
+
y,
|
| 346 |
+
self.sampler_indices,
|
| 347 |
+
add_inputs=True)
|
| 348 |
+
y = y.view_as(y_org)
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple, Union, final
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
|
| 7 |
+
dispatch_bgmv_linear)
|
| 8 |
+
|
| 9 |
+
from .punica_base import PunicaWrapperBase
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@final
|
| 13 |
+
class PunicaWrapperHPU(PunicaWrapperBase):
|
| 14 |
+
|
| 15 |
+
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
| 16 |
+
device: Union[torch.device, str], **kwargs):
|
| 17 |
+
# Increasing max_num_batched_tokens by 3x to handle increase in
|
| 18 |
+
# tensor size due to padding.
|
| 19 |
+
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
|
| 20 |
+
max_batches, device)
|
| 21 |
+
|
| 22 |
+
def add_lora_embedding(self,
|
| 23 |
+
y: torch.Tensor,
|
| 24 |
+
x: torch.Tensor,
|
| 25 |
+
lora_b_stacked: torch.Tensor,
|
| 26 |
+
add_inputs: bool = True,
|
| 27 |
+
**kwargs) -> None:
|
| 28 |
+
dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
|
| 29 |
+
|
| 30 |
+
def add_lora_linear(self,
|
| 31 |
+
y: torch.Tensor,
|
| 32 |
+
x: torch.Tensor,
|
| 33 |
+
lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 34 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 35 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 36 |
+
scale: float,
|
| 37 |
+
output_slices: Tuple[int, ...],
|
| 38 |
+
*,
|
| 39 |
+
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
| 40 |
+
**kwargs) -> None:
|
| 41 |
+
y_org = y
|
| 42 |
+
x = x.view(-1, x.shape[-1])
|
| 43 |
+
y = y.view(-1, y.shape[-1])
|
| 44 |
+
offset_left = 0
|
| 45 |
+
|
| 46 |
+
for slice_idx in range(len(output_slices)):
|
| 47 |
+
dispatch_bgmv_linear(
|
| 48 |
+
y[:, offset_left:offset_left + output_slices[slice_idx]], x,
|
| 49 |
+
lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
|
| 50 |
+
offset_left += output_slices[slice_idx]
|
| 51 |
+
y = y.view_as(y_org)
|
| 52 |
+
|
| 53 |
+
def add_lora_logits(self,
|
| 54 |
+
y: torch.Tensor,
|
| 55 |
+
x: torch.Tensor,
|
| 56 |
+
lora_a_stacked: torch.Tensor,
|
| 57 |
+
lora_b_stacked: torch.Tensor,
|
| 58 |
+
scale,
|
| 59 |
+
*,
|
| 60 |
+
buffer: Optional[torch.Tensor] = None,
|
| 61 |
+
**kwargs) -> None:
|
| 62 |
+
y_org = y
|
| 63 |
+
y = y.view(-1, y.shape[-1])
|
| 64 |
+
x = x.view(-1, x.shape[-1])
|
| 65 |
+
dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
|
| 66 |
+
y = y.view_as(y_org)
|
| 67 |
+
|
| 68 |
+
def add_shrink(
|
| 69 |
+
self,
|
| 70 |
+
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 71 |
+
x: torch.Tensor,
|
| 72 |
+
lora_a_stacked: Tuple[torch.Tensor, ...],
|
| 73 |
+
scale: float,
|
| 74 |
+
**kwargs,
|
| 75 |
+
) -> None:
|
| 76 |
+
raise NotImplementedError
|
| 77 |
+
|
| 78 |
+
def add_expand(
|
| 79 |
+
self,
|
| 80 |
+
y: torch.Tensor,
|
| 81 |
+
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
| 82 |
+
lora_b_stacked: Tuple[torch.Tensor, ...],
|
| 83 |
+
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
| 84 |
+
output_slices: Tuple[int, ...],
|
| 85 |
+
offset_start: int = 0,
|
| 86 |
+
add_inputs=True,
|
| 87 |
+
**kwargs,
|
| 88 |
+
) -> None:
|
| 89 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.logger import init_logger
|
| 4 |
+
from vllm.platforms import current_platform
|
| 5 |
+
from vllm.utils import resolve_obj_by_qualname
|
| 6 |
+
|
| 7 |
+
from .punica_base import PunicaWrapperBase
|
| 8 |
+
|
| 9 |
+
logger = init_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
| 13 |
+
punica_wrapper_qualname = current_platform.get_punica_wrapper()
|
| 14 |
+
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
|
| 15 |
+
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
|
| 16 |
+
assert punica_wrapper is not None, \
|
| 17 |
+
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
|
| 18 |
+
logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
|
| 19 |
+
".")
|
| 20 |
+
return punica_wrapper
|
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
# avoid circuit import
|
| 9 |
+
from vllm.lora.layers import LoRAMapping
|
| 10 |
+
from vllm.lora.models import LongContextLoRAContext
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Get the information required for the sgmv kernel. With the features:
    1. If consecutive requests in the batch use the same LoRA, this function
    will combine them into a single request, improving sgmv kernel inference
    performance.
    2. At the beginning of each prefill stage inference, recalculations are
    needed based on the input, but only once.
    """
    # Collapse consecutive runs of the same LoRA id into one segment each.
    segment_ids, segment_lengths = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    # Exclusive prefix sum: starting token offset of every segment.
    segment_starts = torch.zeros_like(segment_lengths)
    segment_starts[1:].copy_(torch.cumsum(segment_lengths, dim=0)[:-1])
    longest_segment = segment_lengths.max().item()
    total_tokens = segment_lengths.sum().item()
    num_segments = segment_ids.size(0)
    # -1 means no lora should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance.
    no_lora = bool(num_segments == 1 and segment_ids == -1)
    return (segment_starts, segment_lengths, segment_ids, num_segments,
            longest_segment, total_tokens, no_lora)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# TODO see if this can be vectorized
|
| 44 |
+
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context lora in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indicies. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indicies, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    """
    # Per-token LoRA *ids*; converted to LoRA *indices* in the loop below.
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device=device,
                                        dtype=torch.long)
    # Per-request mapping: LoRA id > 0 -> its slot index, otherwise -1
    # (no LoRA applied).
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        # Embedding index 0 (not -1) for "no LoRA" so it can still be used
        # as a gather index; the scale below makes row 0 a no-op offset.
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            # Default offset 0 when this LoRA id has no long-context entry.
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    # Rows: 0 = raw LoRA ids, 1 = LoRA indices, 2 = embedding indices,
    # (optional) 3 = long-LoRA offsets.
    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         dtype=torch.long,
                                         device=device)
    # Row 0: offsets into LoRA-added embeddings; row 1: offsets into the
    # lora_a embedding table (shifted past the base vocab).
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    # Flatten to unique positions: row i of a [batch, batch]-strided layout.
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contain length of indices tensors. Used to index into each tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None.
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
|
.venv/lib/python3.11/site-packages/vllm/v1/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc
ADDED
|
Binary file (4.76 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc
ADDED
|
Binary file (8.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc
ADDED
|
Binary file (842 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (199 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Attention layer with FlashAttention."""
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 12 |
+
AttentionMetadata, AttentionType)
|
| 13 |
+
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.platforms import current_platform
|
| 16 |
+
from vllm.utils import cdiv
|
| 17 |
+
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
| 18 |
+
flash_attn_varlen_func,
|
| 19 |
+
is_fa_version_supported)
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FlashAttentionBackend(AttentionBackend):
    """V1 FlashAttention backend descriptor (names, classes, cache shape)."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Head sizes the kernels accept: multiples of 32 up to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """KV-cache layout: (K/V, blocks, block, kv heads, head size)."""
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # Delegates to the module-level heuristic below.
        return use_cascade_attention(*args, **kwargs)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class FlashAttentionMetadata:
    """Per-batch metadata consumed by FlashAttentionImpl.forward."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Cumulative query offsets; passed as cu_seqlens_q to
    # flash_attn_varlen_func.
    query_start_loc: torch.Tensor
    max_seq_len: int
    # Per-request KV lengths; passed as seqused_k to flash_attn_varlen_func.
    seq_lens: torch.Tensor
    # Paged-KV block table; passed as block_table to flash_attn_varlen_func.
    block_table: torch.Tensor
    # Per-token KV-cache slots consumed by reshape_and_cache_flash.
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class FlashAttentionImpl(AttentionImpl):
    """Decoder-only attention implementation backed by vllm_flash_attn."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and pick a flash-attn (FA) version.

        Raises:
            ValueError: if block-sparse params are given or the head size
                is unsupported.
            NotImplementedError: for non-decoder attention types.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # flash-attn encodes the window as (left, right) token counts;
        # (-1, -1) disables windowing.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        # use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            self.fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            self.fa_version = 2

        # Explicit env override wins over the heuristic above.
        if VLLM_FLASH_ATTN_VERSION is not None:
            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
            self.fa_version = VLLM_FLASH_ATTN_VERSION

        if not is_fa_version_supported(self.fa_version):
            # Fixed message grammar (was "Cannot use FA version %d is not
            # supported due to %s").
            logger.error("Cannot use FA version %d: not supported due to %s",
                         self.fa_version,
                         fa_version_unsupported_reason(self.fa_version))

        assert is_fa_version_supported(self.fa_version)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
            # Regular attention (common case).
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                seqused_k=attn_metadata.seq_lens,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                fa_version=self.fa_version,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.fa_version,
        )
        return output
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # Short shared prefix: probably not worth it. This is the common case,
    # so bail out as early as possible. The 256-token threshold is
    # arbitrary. TODO: Tune this threshold.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window:
        return False
    # Too few requests to amortize the extra pass. The threshold of 8 is
    # arbitrary. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade
    #    attention is likely faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    would_flash_decode = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not would_flash_decode:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention. Compare the two with a
    #    rough CTA-count performance model.
    # NOTE(woosuk): The performance model is very rough and may not be
    # accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_time = cdiv(cascade_ctas, num_sms) * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_time = cdiv(flash_decoding_ctas * num_prefix_tiles,
                               num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: Tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
) -> torch.Tensor:
    """Compute attention as two passes -- one over the shared prefix and one
    over each request's suffix -- then merge the partial results into
    `output` via their log-sum-exp weights.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    # The shared prefix must end on a block boundary so the block table can
    # be split between the two passes below.
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix: one non-causal pass where every query token
    # attends to the common prefix KV (block table row 0).
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
    )

    # Process suffix per query: the usual causal pass over the KV blocks
    # after the shared prefix.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
) -> None:
    """Merge prefix/suffix partial attention results into `output`,
    weighting each by its softmax log-sum-exp.
    """
    n_tokens, n_heads, head_dim = output.shape
    # The kernel reads a power-of-2 span per head and masks the tail.
    padded_dim = triton.next_power_of_2(head_dim)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    launch_grid = (n_tokens, n_heads)
    merge_attn_states_kernel[launch_grid](output, prefix_output, prefix_lse,
                                          suffix_output, suffix_lse, head_dim,
                                          padded_dim)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
):
    # One program per (token, head); grid sizes give the tensor dims.
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    # Log-sum-exp values for this (head, token); subtract the max for
    # numerical stability before exponentiating below.
    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse

    # Load the head vector, padded up to a power of 2 and masked.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc
ADDED
|
Binary file (6.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING, Dict, List, Set, Tuple
|
| 4 |
+
|
| 5 |
+
from vllm.logger import init_logger
|
| 6 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 7 |
+
from vllm.v1.request import Request
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from vllm.config import ModelConfig, SchedulerConfig
|
| 11 |
+
|
| 12 |
+
logger = init_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EncoderCacheManager:
    """Token-budgeted cache of encoder outputs for multimodal inputs.

    Capacity is measured in encoder tokens (``cache_size``). Allocating or
    freeing an input consumes or releases
    ``request.get_num_encoder_tokens(input_id)`` tokens from the budget.
    """

    def __init__(self, cache_size: int):
        # Total capacity and currently available capacity, in encoder tokens.
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        # req_id -> cached input ids
        self.cached: Dict[str, Set[int]] = {}
        # List of [req_id, input_id] freed since the last get_freed_ids() call.
        self.freed: List[Tuple[str, int]] = []

    def has_cache(self, request: "Request", input_id: int) -> bool:
        """Return True if the encoder output of this input id is cached."""
        req_id = request.request_id
        return req_id in self.cached and input_id in self.cached[req_id]

    def can_allocate(self, request: "Request", input_id: int) -> bool:
        """Return True if the free budget can fit this input's encoder tokens."""
        num_tokens = request.get_num_encoder_tokens(input_id)
        return num_tokens <= self.num_free_slots

    def allocate(self, request: "Request", input_id: int) -> None:
        """Reserve budget for this input's encoder output.

        Callers are expected to check ``can_allocate`` first; this method
        does not validate that enough free slots remain.
        """
        req_id = request.request_id
        if req_id not in self.cached:
            self.cached[req_id] = set()
        self.cached[req_id].add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: "Request") -> Set[int]:
        """Return the cached input ids for the request.

        NOTE: When the request has cached entries, this returns the *live*
        internal set, not a copy.
        """
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: "Request", input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return

        self.cached[req_id].discard(input_id)
        if len(self.cached[req_id]) == 0:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: "Request") -> None:
        """Free all cached input ids for the request."""
        # Copy before iterating: get_cached_input_ids returns the live set
        # and free_encoder_input mutates it, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        input_ids = set(self.get_cached_input_ids(request))
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> List[Tuple[str, int]]:
        """Return and clear the accumulated (req_id, input_id) freed pairs."""
        freed = self.freed
        self.freed = []
        return freed
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    # Text-only models need no encoder budget at all.
    if not model_config.is_multimodal_model:
        return 0, 0

    # TODO: handle encoder-decoder models once we support them.
    return _compute_encoder_budget_multimodal(model_config, scheduler_config)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a multimodal model.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    max_tokens_by_modality = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
        model_config)

    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    # The budget must fit at least one item of the largest modality.
    largest_item_tokens = max(max_tokens_by_modality.values())

    compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                         largest_item_tokens)
    cache_size = max(scheduler_config.encoder_cache_size, largest_item_tokens)

    return compute_budget, cache_size
|
.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
from vllm.logger import init_logger
|
| 7 |
+
from vllm.utils import cdiv
|
| 8 |
+
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
| 9 |
+
KVCacheBlock,
|
| 10 |
+
generate_block_hash_extra_keys,
|
| 11 |
+
hash_block_tokens,
|
| 12 |
+
hash_request_tokens)
|
| 13 |
+
from vllm.v1.request import Request, RequestStatus
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class KVCacheManager:
    """Manages the pool of GPU KV-cache blocks for running requests.

    Responsibilities visible in this class:
    - Allocating/freeing blocks per request (``allocate_slots`` / ``free``),
      with optional preallocation to amortize bookkeeping cost.
    - Hash-based prefix caching (``get_computed_blocks``,
      ``_cache_full_blocks``) when ``enable_caching`` is True, backed by a
      free-block queue whose entries double as eviction candidates.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        max_model_len: int,
        sliding_window: Optional[int] = None,
        enable_caching: bool = True,
        num_preallocate_tokens: int = 64,
    ) -> None:
        # Number of tokens stored in each KV-cache block.
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
        # NOTE(review): sliding_window is stored but not otherwise used in
        # this class — presumably consumed by callers; confirm before removal.
        self.sliding_window = sliding_window
        self.enable_caching = enable_caching
        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
        # blocks for each request. For example, when a request reaches the end
        # of its block table, we preallocate N blocks in advance. This way, we
        # reduce the overhead of updating free_block_ids and ref_cnts for each
        # request every step (at the cost of some memory waste).
        # NOTE(woosuk): This is different from the "lookahead" slots since this
        # does not guarantee that the request always has N empty blocks. After
        # the request gets N empty blocks, it starts to use the blocks without
        # further allocation. When it uses up all the N empty blocks, it gets
        # N new empty blocks.
        self.num_preallocate_tokens = num_preallocate_tokens
        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

        # A Block pool of all kv-cache blocks.
        self.block_pool: List[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)

        # {block_hash: {block ID: block}}. A cached block is
        # a full block with a block hash that can be used for prefix caching.
        # The cached block may be used by running requests or in the
        # free_block_queue that could potentially be evicted.
        # NOTE: We currently don't de-duplicate the blocks in the cache,
        # meaning that if a block becomes full and is cached, we don't check
        # if there is already an identical block in the cache. This is because
        # we want to make sure the allocated block IDs won't change so that
        # block tables are append-only.
        self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
            int, KVCacheBlock]] = defaultdict(dict)

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: DefaultDict[str,
                                        List[KVCacheBlock]] = defaultdict(list)

    @property
    def usage(self) -> float:
        """Fraction of GPU blocks currently not in the free queue."""
        return 1.0 - (self.free_block_queue.num_free_blocks /
                      self.num_gpu_blocks)

    def get_computed_blocks(
            self, request: Request) -> Tuple[List[KVCacheBlock], int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple containing:
            - A list of blocks that are computed for the request.
            - The number of computed tokens.
        """
        if not self.enable_caching:
            # Prefix caching is disabled.
            return [], 0

        computed_blocks = []

        # The block hashes for the request may already be computed
        # if the request was preempted and resumed.
        if not request.kv_block_hashes:
            request.set_kv_block_hashes(
                hash_request_tokens(self.block_size, request))
        block_hashes = request.kv_block_hashes

        for block_hash in block_hashes:
            # block_hashes is a chain of block hashes. If a block hash is not
            # in the cached_block_hash_to_id, the following block hashes are
            # not computed yet for sure.
            if cached_block := self._get_cached_block(block_hash):
                computed_blocks.append(cached_block)
            else:
                break

        # NOTE(woosuk): Since incomplete blocks are not eligible for
        # sharing, `num_computed_tokens` is always a multiple of
        # `block_size`.
        num_computed_tokens = len(computed_blocks) * self.block_size
        return computed_blocks, num_computed_tokens

    def allocate_slots(
        self,
        request: Request,
        num_tokens: int,
        new_computed_blocks: Optional[List[KVCacheBlock]] = None
    ) -> Optional[List[KVCacheBlock]]:
        """Add slots for a request with new tokens to append.

        Args:
            request: The request to allocate slots.
            num_tokens: The number of tokens to allocate. Note that this does
                not include the tokens that have already been computed.
            new_computed_blocks: A list of new computed blocks just hitting the
                prefix caching.

        Blocks layout:
        -----------------------------------------------------------------------
        | < computed > | < new computed > |    < new >    | < pre-allocated > |
        -----------------------------------------------------------------------
        |                  < required >                   |
        --------------------------------------------------
        |                    < full >                  |
        ------------------------------------------------
                                          | <new full> |
                                          --------------
        The following *_blocks are illustrated in this layout.

        Returns:
            A list of new allocated blocks, or None if the free pool cannot
            satisfy the allocation.

        Raises:
            ValueError: If `num_tokens` is 0.
        """
        if num_tokens == 0:
            raise ValueError("num_tokens must be greater than 0")

        new_computed_blocks = new_computed_blocks or []

        # The number of computed tokens is the number of computed tokens plus
        # the new prefix caching hits
        num_computed_tokens = (request.num_computed_tokens +
                               len(new_computed_blocks) * self.block_size)
        num_required_blocks = cdiv(num_computed_tokens + num_tokens,
                                   self.block_size)
        req_blocks = self.req_to_blocks[request.request_id]
        num_new_blocks = (num_required_blocks - len(req_blocks) -
                          len(new_computed_blocks))

        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it cannot be counted as a free block
        # when allocating this request.
        num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
                                            if blk.ref_cnt == 0)
        if (num_new_blocks > self.free_block_queue.num_free_blocks -
                num_evictable_computed_blocks):
            # Cannot allocate new blocks
            return None

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self._touch(new_computed_blocks)
        else:
            assert not new_computed_blocks, (
                "Computed blocks should be empty when "
                "prefix caching is disabled")

        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        req_blocks.extend(new_computed_blocks)

        # Start to handle new blocks

        if num_new_blocks <= 0:
            # No new block is needed.
            new_blocks = []
        else:
            # Get new blocks from the free block pool considering
            # preallocated blocks.
            num_new_blocks = min(
                num_new_blocks + self.num_preallocate_blocks,
                self.free_block_queue.num_free_blocks,
                # Should not exceed the maximum number of blocks per request.
                # This is especially because the block table has the shape
                # [..., max_num_blocks_per_req].
                # TODO(woosuk): Check and reject requests if
                # num_prompt_tokens + max_tokens > max_model_len.
                self.max_num_blocks_per_req - len(req_blocks),
            )
            assert num_new_blocks > 0

            # Concatenate the computed block IDs and the new block IDs.
            new_blocks = self._get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)

        if not self.enable_caching:
            return new_blocks

        # NOTE(rickyx): We are assuming the `num_tokens` are actual
        # tokens rather than lookahead slots (e.g. for speculative decoding).
        # TODO(rickyx): When supporting speculative decoding, we will need to
        # differentiate between them so that we can know how many blocks are
        # full after appending the actual tokens.
        num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
        num_computed_full_blocks = num_computed_tokens // self.block_size
        new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
        if new_full_blocks:
            self._cache_full_blocks(
                request=request,
                blk_start_idx=num_computed_full_blocks,
                # The new full blocks are the full blocks that are not computed.
                full_blocks=new_full_blocks,
                prev_block=(req_blocks[num_computed_full_blocks - 1]
                            if num_computed_full_blocks > 0 else None))

        return new_blocks

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.
        When caching is enabled, we free the blocks in reverse order so that
        the tail blocks are evicted first.

        Args:
            request: The request to free the blocks.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        blocks = self.req_to_blocks.pop(request.request_id, [])
        ordered_blocks: Iterable[KVCacheBlock] = blocks
        if self.enable_caching:
            # Free blocks in reverse order so that the tail blocks are
            # freed first.
            ordered_blocks = reversed(blocks)

        for block in ordered_blocks:
            block.decr_ref()
            # Only blocks with no remaining users go back to the free queue.
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalid prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise (i.e. some blocks are still in use).
        """
        num_used_blocks = (self.num_gpu_blocks -
                           self.free_block_queue.num_free_blocks)
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.block_pool:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> int:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current step.
        This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step. As of 1/1/2025, the scheduler does not
        allow this case, but it is possible in the future, as we allow more
        flexible scheduling.

        This can result in an edge case where the number of common prefix blocks
        is 0, even though all scheduled requests share a common prefix. This
        occurs because there may be unscheduled RUNNING requests that do not
        share the common prefix. Currently, this case cannot be easily detected,
        so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            int: The number of common prefix blocks.
        """
        assert request.status == RequestStatus.RUNNING
        blocks = self.req_to_blocks[request.request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == num_running_requests:
                num_common_blocks += 1
            else:
                # Prefix sharing is contiguous from the head; stop at the
                # first non-shared block.
                break
        return num_common_blocks

    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new block.

        Raises:
            ValueError: If the free pool has fewer than `num_blocks` blocks.
        """
        if num_blocks > self.free_block_queue.num_free_blocks:
            raise ValueError(
                f"Cannot get {num_blocks} free blocks from the pool")

        ret: List[KVCacheBlock] = []
        idx = 0
        while idx < num_blocks:
            # First allocate blocks.
            curr_block = self.free_block_queue.popleft()
            assert curr_block.ref_cnt == 0

            # If the block is cached, evict it.
            if self.enable_caching:
                self._maybe_evict_cached_block(curr_block)

            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        return ret

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
            block.reset_hash()
            del self.cached_block_hash_to_block[block_hash][block.block_id]

            # Drop the hash entry entirely once no block carries it.
            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

            return True
        return False

    def _get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
        """Get a cached block by the block hash, or None if cache miss.
        If there are duplicated blocks, we return the first block in the cache.

        Args:
            block_hash: The hash value of the block.

        Returns:
            The cached block if it exists, or None.
        """
        if block_hash in self.cached_block_hash_to_block:
            first_block_id = list(
                self.cached_block_hash_to_block[block_hash].keys())[0]
            return self.cached_block_hash_to_block[block_hash][first_block_id]
        return None

    def _touch(self, blocks: List[KVCacheBlock]) -> None:
        """Touch a block increases its reference count by 1, and may remove
        the block from the free queue. This is used when a block is hit by
        another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
        for block in blocks:
            # ref_cnt=0 means this block is in the free list (i.e. eviction
            # candidate), so remove it.
            if block.ref_cnt == 0:
                self.free_block_queue.remove(block)
            block.incr_ref()

    def _cache_full_blocks(
        self,
        request: Request,
        blk_start_idx: int,
        full_blocks: List[KVCacheBlock],
        prev_block: Optional[KVCacheBlock],
    ) -> None:
        """Cache a list of full blocks for prefix caching.

        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it computes the
        block hashes for the blocks starting from `blk_start_idx` to the end
        of the request's full blocks, updating the metadata for each block
        and caching them in the `cached_block_hash_to_block`.

        Args:
            request: The request to cache the blocks.
            blk_start_idx: The index of the first block in the request's blocks
                to cache.
            full_blocks: The list of blocks to update hash metadata.
            prev_block: The previous block in the chain.
        """
        num_cached_block_hashes = len(request.kv_block_hashes)

        # Update the new blocks with the block hashes through the chain.
        prev_block_hash_value = None
        if prev_block is not None:
            # Previous block must have a block hash because it must be
            # a full, cached block.
            assert prev_block.block_hash is not None
            prev_block_hash_value = prev_block.block_hash.hash_value

        # Find the first uncached block. This case should only happen when
        # speculative decoding is used.
        offset = 0
        for blk in full_blocks:
            if blk.block_hash is None:
                break
            else:
                prev_block_hash_value = blk.block_hash.hash_value
                offset += 1
        else:
            # All blocks are cached.
            return

        for i, blk in enumerate(full_blocks[offset:]):
            blk_idx = blk_start_idx + offset + i
            assert blk.block_hash is None

            if blk_idx < num_cached_block_hashes:
                # The block hash may already be computed in
                # "get_computed_blocks" if the tokens are not generated by
                # this request (either the prompt tokens or the previously
                # generated tokens with preemption). In this case we simply
                # reuse the block hash.
                block_hash = request.kv_block_hashes[blk_idx]
            else:
                # Otherwise compute the block hash and cache it in the request
                # in case it will be preempted in the future.
                start_token_idx = blk_idx * self.block_size
                end_token_idx = (blk_idx + 1) * self.block_size
                block_tokens = request.all_token_ids[
                    start_token_idx:end_token_idx]
                assert len(block_tokens) == self.block_size, (
                    f"Expected {self.block_size} tokens, got "
                    f"{len(block_tokens)} at {blk_idx}th block for request "
                    f"{request.request_id}({request})")

                # Generate extra keys for multi-modal inputs. Note that since
                # we reach to this branch only when the block is completed with
                # generated tokens, we only need to consider the last mm input.
                extra_keys, _ = generate_block_hash_extra_keys(
                    request, start_token_idx, end_token_idx, -1)

                # Compute the hash of the current block.
                block_hash = hash_block_tokens(prev_block_hash_value,
                                               block_tokens, extra_keys)
                request.append_kv_block_hashes(block_hash)

            # Update and add the full block to the cache.
            blk.block_hash = block_hash
            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
            prev_block_hash_value = block_hash.hash_value
|
.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""KV-Cache Utilities."""
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, List, NamedTuple, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from vllm.config import VllmConfig
|
| 8 |
+
from vllm.logger import init_logger
|
| 9 |
+
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
|
| 10 |
+
KVCacheTensor)
|
| 11 |
+
from vllm.v1.request import Request
|
| 12 |
+
|
| 13 |
+
logger = init_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BlockHashType(NamedTuple):
    """Hash value of a block (int), the token IDs in the block, and extra keys.

    We keep a tuple of token IDs and extra keys to reduce the likelihood of
    hash collisions when the hash value is the same. But please note that
    hash collisions can still theoretically occur, albeit with an extremely
    low probability.
    """
    # Hash value of the block in an integer.
    hash_value: int
    # Token IDs in the block (kept alongside the hash so that equal hash
    # values with different contents do not collide as cache keys).
    token_ids: Tuple[int, ...]
    # Extra keys for the block, e.g. multi-modal input hashes. None when the
    # block is text-only.
    extra_keys: Optional[Any] = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class KVCacheBlock:
    """Metadata for a single KV-cache block."""
    # Block ID, ranging from 0 to num_gpu_blocks - 1.
    block_id: int
    # Reference count.
    ref_cnt: int = 0
    # The hash of the block composed of (block hash, tuple of token IDs).
    # It is only available when the block is full.
    _block_hash: Optional[BlockHashType] = None

    # Used to construct a doubly linked list for free blocks.
    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None

    def incr_ref(self):
        """Increase the reference count by one."""
        self.ref_cnt = self.ref_cnt + 1

    def decr_ref(self):
        """Decrease the reference count by one."""
        self.ref_cnt = self.ref_cnt - 1

    @property
    def block_hash(self) -> Optional[BlockHashType]:
        """The block's content hash, or None if the block is not full."""
        return self._block_hash

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashType):
        # A block may only be hashed once between evictions; reset_hash()
        # must be called before assigning a new hash.
        assert self._block_hash is None, (
            "The block already has a hash. This should not happen.")
        self._block_hash = block_hash

    def reset_hash(self):
        """Reset the block hash when the block is evicted."""
        self._block_hash = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class FreeKVCacheBlockQueue:
    """This class organizes a list of KVCacheBlock objects to a doubly linked
    list of free blocks. We implement this class instead of using Python
    builtin deque to support removing a block in the middle of the queue
    in O(1) time. To close the performance gap to the builtin deque which is
    implemented in C++, this class does not allocate any Python objects when
    manipulating the linked list. Instead, this class manipulates the
    prev_free_block and next_free_block attributes of the given blocks.

    The queue is ordered by block ID in the beginning. When a block is
    allocated and then freed, it will be appended back with the eviction
    order:
    1. The least recent used block is at the front (LRU).
    2. If two blocks have the same last accessed time (allocated by the
       same sequence), the one with more hash tokens (the tail of a block
       chain) is at the front.
    Note that we maintain this order by reversing the block order when free
    blocks of a request. This operation is outside of this class.

    Args:
        blocks: A list of KVCacheBlock objects. May be empty, in which case
            the queue starts with no free blocks.
    """

    def __init__(self, blocks: List[KVCacheBlock]) -> None:
        self.num_free_blocks = len(blocks)

        # Initialize the doubly linked list of free blocks. An empty input
        # list is handled explicitly (head/tail stay None) instead of
        # raising IndexError on blocks[0] / blocks[-1].
        self.free_list_head: Optional[KVCacheBlock] = (blocks[0]
                                                       if blocks else None)
        self.free_list_tail: Optional[KVCacheBlock] = (blocks[-1]
                                                       if blocks else None)
        # Chain consecutive blocks together in both directions.
        for prev_blk, next_blk in zip(blocks, blocks[1:]):
            prev_blk.next_free_block = next_blk
            next_blk.prev_free_block = prev_blk

    def popleft(self) -> KVCacheBlock:
        """Pop the first free block and reduce num_free_blocks by 1.

        Returns:
            The first free block.

        Raises:
            ValueError: If the queue is empty.
        """
        if not self.free_list_head:
            raise ValueError("No free blocks available")

        block = self.free_list_head
        self.remove(block)
        return block

    def remove(self, block: KVCacheBlock) -> None:
        """Remove a block in the free list and reduce num_free_blocks by 1.

        Args:
            block: The block to remove. Must currently be in the free list.
        """
        if block.prev_free_block is not None:
            # Link the previous block to the next block.
            block.prev_free_block.next_free_block = block.next_free_block
        if block.next_free_block is not None:
            # Link the next block to the previous block.
            block.next_free_block.prev_free_block = block.prev_free_block

        if block == self.free_list_head:
            # Update the head if the block is the head.
            self.free_list_head = block.next_free_block
        if block == self.free_list_tail:
            # Update the tail if the block is the tail.
            self.free_list_tail = block.prev_free_block

        # Detach the block from the linked list.
        block.prev_free_block = block.next_free_block = None
        self.num_free_blocks -= 1

    def append(self, block: KVCacheBlock) -> None:
        """Put a block back into the free list and increase
        num_free_blocks by 1.

        Args:
            block: The block to append.
        """
        if self.free_list_tail is not None:
            # Link the last block to the new block.
            self.free_list_tail.next_free_block = block
            block.prev_free_block = self.free_list_tail
            self.free_list_tail = block
        else:
            # The free list is empty.
            assert self.free_list_head is None
            self.free_list_head = self.free_list_tail = block

        block.next_free_block = None
        self.num_free_blocks += 1

    def get_all_free_blocks(self) -> List[KVCacheBlock]:
        """Get all free blocks in the free list. Mainly used for testing.

        Returns:
            A list of free blocks in queue (eviction) order.
        """
        ret = []
        curr_block = self.free_list_head
        while curr_block is not None:
            ret.append(curr_block)
            curr_block = curr_block.next_free_block
        return ret
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def generate_block_hash_extra_keys(
        request: Request, start_token_idx: int, end_token_idx: int,
        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
    """Generate extra keys for the block hash. The extra keys can come from
    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
    indicate a mm input contained in the block and its starting offset in
    the block tokens.

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block (exclusive).
        start_mm_idx: The start multi-modal index of the block. May be
            negative to index from the end (e.g., -1 for the last mm input).

    Returns:
        A tuple of extra keys and the next multi-modal index.
        Extra keys are None when the request has no mm inputs, or when all
        mm inputs end before this block starts.
    """

    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
    if not mm_positions:
        # Text-only request: no extra keys are needed.
        return None, start_mm_idx

    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match. This "
            "is likely because you do not enable MM preprocessor hashing. "
            "Please set disable_mm_preprocessor_cache=False.")

    # Note that we assume mm_positions is sorted by offset.
    # We do not need to check all mm inputs if the start token index is out of
    # range. This usually happens in the late prefill phase and decoding phase.
    if mm_positions[-1]["offset"] + mm_positions[-1][
            "length"] < start_token_idx:
        return None, start_mm_idx

    # Support start_mm_idx == -1 to indicate the last mm input.
    if start_mm_idx < 0:
        assert -start_mm_idx <= len(mm_positions)
        start_mm_idx = len(mm_positions) + start_mm_idx

    extra_keys = []
    curr_mm_idx = start_mm_idx
    # Walk the (sorted) mm inputs, collecting the hash of every mm input
    # that overlaps the token range [start_token_idx, end_token_idx).
    while mm_positions and curr_mm_idx < len(mm_positions):
        assert mm_hashes[curr_mm_idx] is not None
        offset = mm_positions[curr_mm_idx]["offset"]
        length = mm_positions[curr_mm_idx]["length"]
        if end_token_idx > offset:
            # NOTE(review): this uses a strict `>`, so a block starting
            # exactly at offset + length is still treated as overlapping the
            # mm input — presumably intentional (token at offset+length-1 is
            # the last mm token); confirm against the mm position convention.
            if start_token_idx > offset + length:
                # This block has passed the current mm input.
                curr_mm_idx += 1
                continue

            # The block contains (part of) the current mm input.
            extra_keys.append(mm_hashes[curr_mm_idx])

            if end_token_idx >= offset + length:
                # If this block contains the end of the current mm input,
                # move to the next mm input as this block may also contain
                # the next mm input.
                curr_mm_idx += 1
            else:
                # Otherwise this block is done with mm inputs.
                break
        else:
            # This block has not reached the current mm input.
            break
    return tuple(extra_keys), curr_mm_idx
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def hash_block_tokens(
        parent_block_hash: Optional[int],
        curr_block_token_ids: Sequence[int],
        extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
    """Compute the prefix-caching hash for one full block, chained to the
    hash of the preceding block(s).

    TODO: Support arbitrary metadata so that we could support more
    features such as LoRA adapter.

    Args:
        parent_block_hash: The hash of the parent block. None
            if this is the first block.
        curr_block_token_ids: A list of token ids in the current
            block. The current block is assumed to be full.
        extra_keys: Extra keys for the block (e.g., mm input hashes).

    Returns:
        A BlockHashType carrying the hash value together with the token ids
        and extra keys; the entire tuple is used as the cache key so that
        equal hash values with different contents do not collide.
    """
    token_tuple = tuple(curr_block_token_ids)
    if not parent_block_hash:
        # Note that we use 'None' as a string here instead of None because
        # as of Python 3.12, hash(None) returns a constant predictable value.
        # This could possibly make it easier to find and exploit hash
        # collisions. 'None' as a string will be hashed differently per
        # process, but consistently within the same process. This is the same
        # as the behavior of None prior to Python 3.12.
        parent_block_hash = hash('None')
    hash_value = hash((parent_block_hash, token_tuple, extra_keys))
    return BlockHashType(hash_value, token_tuple, extra_keys)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def hash_request_tokens(block_size: int,
                        request: Request) -> List[BlockHashType]:
    """Computes hash values of a chain of blocks given a sequence of
    token IDs. The hash value is used for prefix caching.

    Args:
        block_size: The size of each block.
        request: The request object.

    Returns:
        The list of computed hash values, one per *full* block; a trailing
        partial block is never hashed.
    """
    all_tokens = request.all_token_ids
    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match.")

    # TODO: Extend this to support other features such as LoRA.
    use_extra_keys = bool(mm_positions)

    hashes: List[BlockHashType] = []
    prev_hash_value: Optional[int] = None
    mm_idx = 0
    keys = None
    for begin in range(0, len(all_tokens), block_size):
        finish = begin + block_size
        chunk = all_tokens[begin:finish]
        # Only full blocks participate in prefix caching.
        if len(chunk) < block_size:
            break

        # Add extra keys if the block covers a multi-modal input.
        if use_extra_keys:
            keys, mm_idx = generate_block_hash_extra_keys(
                request, begin, finish, mm_idx)

        blk_hash = hash_block_tokens(prev_hash_value, chunk, keys)
        hashes.append(blk_hash)
        # Chain the next block's hash to this one.
        prev_hash_value = blk_hash.hash_value
    return hashes
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                 kv_cache_spec: KVCacheSpec,
                                 available_memory: int):
    """
    Checks whether `available_memory` is enough for the KV cache to hold at
    least one request with the model's max_model_len.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Raises:
        ValueError: If there is not enough memory available for the KV cache.
    """

    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_model_len = vllm_config.model_config.max_model_len
    # Bytes needed to hold max_model_len tokens across every layer's cache.
    needed_memory = sum(
        layer_spec.bytes_for_tokens(max_model_len)
        for layer_spec in kv_cache_spec.values())

    if needed_memory > available_memory:
        # NOTE: fixed grammar in the user-facing message ("models's" ->
        # "model's").
        raise ValueError(
            f"To serve at least one request with the model's max seq len "
            f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
            f"cache is needed, which is larger than the available KV cache "
            f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
            f"increasing `gpu_memory_utilization` or decreasing "
            f"`max_model_len` when initializing the engine.")
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same type of KV
    cache.

    Args:
        kv_cache_spec: The KVCacheSpec of the model

    Returns:
        True if all layers have the same type, False otherwise.
    """
    # Collapse per-layer type IDs into a set; uniform <=> one distinct ID.
    distinct_type_ids = {layer.type_id for layer in kv_cache_spec.values()}
    return len(distinct_type_ids) == 1
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: KVCacheSpec,
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV
    cache. Divide the available memory equally among all layers.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    # A uniform-type spec implies a single page size shared by every layer.
    assert len(page_sizes) == 1
    page_size = page_sizes.pop()

    # Split the memory evenly across layers; clamp at zero for safety.
    num_layers = len(kv_cache_spec)
    num_blocks = max(int(available_memory // page_size // num_layers), 0)

    override = vllm_config.cache_config.num_gpu_blocks_override
    if override is not None:
        logger.info(
            "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d",
            num_blocks, override)
        num_blocks = override

    logger.info("# GPU blocks: %d", num_blocks)
    # How many max-length requests the cache could hold concurrently.
    max_concurrency = (num_blocks * vllm_config.cache_config.block_size /
                       vllm_config.model_config.max_model_len)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                vllm_config.model_config.max_model_len, max_concurrency)

    bytes_per_layer = page_size * num_blocks
    return KVCacheConfig(
        num_blocks=num_blocks,
        tensors={
            layer_name: KVCacheTensor(size=bytes_per_layer)
            for layer_name in kv_cache_spec
        },
        # All layers share one group since their caches are interchangeable.
        groups=[list(kv_cache_spec)],
        kv_cache_spec=kv_cache_spec)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
                        available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model
    TODO: support hybrid models with more than one type of KV cache.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
    # Hybrid models (multiple KV cache types) are not supported yet.
    if not is_kv_cache_type_uniform(kv_cache_spec):
        raise NotImplementedError
    # KV cache of all layers are the same, which is true for most models.
    # Allocate the same amount of memory for each layer.
    return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                             available_memory)
|
.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from collections import deque
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
|
| 6 |
+
Tuple, Union)
|
| 7 |
+
|
| 8 |
+
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
from vllm.sampling_params import SamplingParams
|
| 11 |
+
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
| 12 |
+
compute_encoder_budget)
|
| 13 |
+
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
| 14 |
+
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
| 15 |
+
from vllm.v1.metrics.stats import SchedulerStats
|
| 16 |
+
from vllm.v1.outputs import ModelRunnerOutput
|
| 17 |
+
from vllm.v1.request import Request, RequestStatus
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from vllm.multimodal import MultiModalKwargs
|
| 21 |
+
from vllm.multimodal.base import PlaceholderRange
|
| 22 |
+
|
| 23 |
+
logger = init_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Scheduler:
|
| 27 |
+
|
| 28 |
+
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        """Initialize the V1 scheduler.

        Args:
            scheduler_config: Scheduling limits (max seqs, token budget,
                max model length).
            model_config: Model configuration; used here only to size the
                encoder compute/cache budget.
            cache_config: KV-cache configuration (block size, block count,
                sliding window, prefix caching flag).
            lora_config: Must be None — V1 does not support LoRA yet.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        # TODO: Support LoRA.
        assert lora_config is None, "V1 does not support LoRA yet."

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = \
            self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        # The GPU block count must already be finalized by the engine before
        # the scheduler is constructed.
        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            max_model_len=self.max_model_len,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)
        self.block_size = self.cache_config.block_size

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> CachedRequestData
        self._cached_reqs_data: Dict[str, CachedRequestData] = {}

        # Encoder-related.
        # Calculate encoder cache size if applicable
        # NOTE: For now we use the same budget for both compute and space.
        # This can be changed when we make encoder cache for embedding caching
        # across requests.
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )

        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
        # projector if needed). Currently, we assume that the encoder also
        # has the Transformer architecture (e.g., ViT).
        self.max_num_encoder_input_tokens = encoder_compute_budget
        # NOTE: For the models without encoder (e.g., text-only models),
        # the encoder cache will not be initialized because cache size is 0
        # for these models.
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)
|
| 94 |
+
|
| 95 |
+
def schedule(self) -> "SchedulerOutput":
    """Build the SchedulerOutput for one engine step.

    First assigns tokens to RUNNING requests (preempting from the back of
    self.running when KV cache blocks cannot be allocated), then admits
    WAITING requests while token budget, running slots, and cache space
    remain. Flushes self.finished_req_ids after constructing the output.
    """
    # NOTE(woosuk) on the scheduling algorithm:
    # There's no "decoding phase" nor "prefill phase" in the scheduler.
    # Each request just has the num_computed_tokens and num_tokens,
    # which is equal to len(prompt_token_ids) + len(output_token_ids).
    # At each step, the scheduler tries to assign tokens to the requests
    # so that each request's num_computed_tokens can catch up its
    # num_tokens. This is general enough to cover chunked prefills,
    # prefix caching, and the "jump decoding" optimization in the future.

    scheduled_new_reqs: List[Request] = []
    scheduled_resumed_reqs: List[Request] = []
    scheduled_running_reqs: List[Request] = []
    preempted_reqs: List[Request] = []

    req_to_new_block_ids: Dict[str, List[int]] = {}
    num_scheduled_tokens: Dict[str, int] = {}
    token_budget = self.max_num_scheduled_tokens
    # Encoder-related.
    scheduled_encoder_inputs: Dict[str, List[int]] = {}
    encoder_budget = self.max_num_encoder_input_tokens

    # First, schedule the RUNNING requests.
    req_index = 0
    while req_index < len(self.running) and token_budget > 0:
        request = self.running[req_index]
        num_new_tokens = request.num_tokens - request.num_computed_tokens
        num_new_tokens = min(num_new_tokens, token_budget)
        assert num_new_tokens > 0

        # Schedule encoder inputs.
        encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
            self._try_schedule_encoder_inputs(request,
                                              request.num_computed_tokens,
                                              num_new_tokens,
                                              encoder_budget))
        if num_new_tokens == 0:
            # The request cannot be scheduled because the encoder budget
            # or the encoder cache is exhausted.
            # NOTE(woosuk): Here, by doing `continue` instead of `break`,
            # we do not strictly follow the FCFS scheduling policy and
            # allow the lower-priority requests to be scheduled.
            req_index += 1
            continue

        while True:
            new_blocks = self.kv_cache_manager.allocate_slots(
                request, num_new_tokens)
            if new_blocks is None:
                # The request cannot be scheduled.
                # Preempt the lowest-priority request.
                preempted_req = self.running.pop()
                self.kv_cache_manager.free(preempted_req)
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0

                self.waiting.appendleft(preempted_req)
                preempted_reqs.append(preempted_req)
                if preempted_req == request:
                    # No more request to preempt.
                    can_schedule = False
                    break
            else:
                # The request can be scheduled.
                can_schedule = True
                break
        if not can_schedule:
            break
        assert new_blocks is not None

        # Schedule the request.
        scheduled_running_reqs.append(request)
        req_to_new_block_ids[request.request_id] = [
            b.block_id for b in new_blocks
        ]
        num_scheduled_tokens[request.request_id] = num_new_tokens
        token_budget -= num_new_tokens
        req_index += 1

        # Encoder-related.
        if encoder_inputs_to_schedule:
            scheduled_encoder_inputs[request.request_id] = (
                encoder_inputs_to_schedule)
            # Allocate the encoder cache.
            for i in encoder_inputs_to_schedule:
                self.encoder_cache_manager.allocate(request, i)
            encoder_budget = new_encoder_budget

    # Next, schedule the WAITING requests.
    if not preempted_reqs:
        while self.waiting and token_budget > 0:
            if len(self.running) == self.max_num_running_reqs:
                break

            request = self.waiting[0]
            # Get already-cached tokens.
            computed_blocks, num_computed_tokens = \
                self.kv_cache_manager.get_computed_blocks(request)
            # Number of tokens to be scheduled.
            # We use `request.num_tokens` instead of
            # `request.num_prompt_tokens` to consider the resumed requests,
            # which have output tokens.
            num_new_tokens = request.num_tokens - num_computed_tokens
            if num_new_tokens == 0:
                # This happens when prompt length is divisible by the block
                # size and all blocks are cached. Now we force to recompute
                # the last block. Note that we have to re-compute an entire
                # block because allocate_slots() assumes num_computed_tokens
                # is always a multiple of the block size. This limitation
                # can potentially be removed in the future to slightly
                # improve the performance.
                num_computed_tokens -= self.block_size
                num_new_tokens = self.block_size
                computed_blocks.pop()
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0

            # Schedule encoder inputs.
            (encoder_inputs_to_schedule, num_new_tokens,
             new_encoder_budget) = self._try_schedule_encoder_inputs(
                 request, num_computed_tokens, num_new_tokens,
                 encoder_budget)
            if num_new_tokens == 0:
                # The request cannot be scheduled.
                break

            new_blocks = self.kv_cache_manager.allocate_slots(
                request, num_new_tokens, computed_blocks)
            if new_blocks is None:
                # The request cannot be scheduled.
                break

            self.waiting.popleft()
            self.running.append(request)
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
                raise RuntimeError(
                    f"Invalid request status: {request.status}")

            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in computed_blocks + new_blocks
            ]
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_budget = new_encoder_budget

    # Check if the scheduling constraints are satisfied.
    total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
    assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
    assert token_budget >= 0
    assert len(self.running) <= self.max_num_running_reqs
    # Since some requests in the RUNNING queue may not be scheduled in
    # this step, the total number of scheduled requests can be smaller than
    # len(self.running).
    assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
            len(scheduled_running_reqs) <= len(self.running))

    # Get the longest common prefix among all requests in the running queue.
    # This can be potentially used for cascade attention.
    num_common_prefix_blocks = 0
    if self.running:
        any_request = self.running[0]
        num_common_prefix_blocks = (
            self.kv_cache_manager.get_num_common_prefix_blocks(
                any_request, len(self.running)))

    # Construct the scheduler output.
    new_reqs_data = [
        NewRequestData.from_request(req,
                                    req_to_new_block_ids[req.request_id],
                                    req.num_computed_tokens)
        for req in scheduled_new_reqs
    ]
    resumed_reqs_data = [
        self._make_cached_request_data(
            req,
            req_to_new_block_ids[req.request_id],
            req.num_computed_tokens,
            resumed_from_preemption=True,
        ) for req in scheduled_resumed_reqs
    ]
    running_reqs_data = [
        self._make_cached_request_data(
            req,
            req_to_new_block_ids[req.request_id],
            req.num_computed_tokens,
            resumed_from_preemption=False,
        ) for req in scheduled_running_reqs
    ]
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=new_reqs_data,
        scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_encoder_inputs=scheduled_encoder_inputs,
        num_common_prefix_blocks=num_common_prefix_blocks,
        # finished_req_ids is an existing state in the scheduler,
        # instead of being newly scheduled in this step.
        # It contains the request IDs that are finished in between
        # the previous and the current steps.
        finished_req_ids=self.finished_req_ids,
        free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
    )

    self.finished_req_ids = set()
    return scheduler_output
def _make_cached_request_data(
    self,
    request: Request,
    new_block_ids: List[int],
    num_computed_tokens: int,
    resumed_from_preemption: bool,
) -> "CachedRequestData":
    """Return the CachedRequestData for *request*, reusing a pooled object.

    OPTIMIZATION: one CachedRequestData is kept per request id so that
    each scheduling step mutates an existing object instead of
    allocating a fresh one.
    """
    req_data = self._cached_reqs_data.get(request.request_id)
    if req_data is None:
        # First time this request is cached: build and remember it.
        req_data = CachedRequestData.from_request(request,
                                                  resumed_from_preemption,
                                                  new_block_ids,
                                                  num_computed_tokens)
        self._cached_reqs_data[request.request_id] = req_data
    else:
        # Refresh the pooled object in place with this step's values.
        req_data.resumed_from_preemption = resumed_from_preemption
        req_data.new_block_ids = new_block_ids
        req_data.num_computed_tokens = num_computed_tokens
    return req_data
def _try_schedule_encoder_inputs(
    self,
    request: Request,
    num_computed_tokens: int,
    num_new_tokens: int,
    encoder_budget: int,
) -> Tuple[List[int], int, int]:
    """
    Determine which encoder inputs need to be scheduled in the current
    step, and update `num_new_tokens` and the encoder token budget
    accordingly.

    An encoder input is scheduled when all of the following hold:
    - its placeholder tokens overlap the decoder range being computed,
      i.e. [num_computed_tokens, num_computed_tokens + num_new_tokens);
    - its output is not already computed and stored in the encoder cache;
    - the remaining encoder token budget can cover it; and
    - the encoder cache has room to store it.

    When an encoder input cannot be scheduled due to cache or budget
    limits, `num_new_tokens` is truncated so that only the decoder
    tokens up to just before that input are scheduled (possibly zero).
    """
    if not request.has_encoder_inputs():
        # Text-only request: nothing to schedule.
        return [], num_new_tokens, encoder_budget

    scheduled: List[int] = []
    placeholders = request.mm_positions
    assert placeholders is not None
    assert len(placeholders) > 0
    for input_id, placeholder in enumerate(placeholders):
        start = placeholder["offset"]
        length = placeholder["length"]

        # The encoder output is needed only if the two ranges overlap:
        # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
        # [start, start + length).
        if start >= num_computed_tokens + num_new_tokens:
            # This input (and every later one) lies beyond this step.
            break
        if start + length <= num_computed_tokens:
            # Already absorbed into the decoder's KV cache.
            continue

        if self.encoder_cache_manager.has_cache(request, input_id):
            # Output already computed and cached.
            continue
        if (not self.encoder_cache_manager.can_allocate(request, input_id)
                or length > encoder_budget):
            # The encoder cache is full or the encoder budget is exhausted.
            # NOTE(woosuk): We assume that the encoder input tokens should
            # be processed altogether, as the encoder usually uses
            # bidirectional attention.
            if num_computed_tokens < start:
                # Schedule only the decoder tokens before this input.
                num_new_tokens = start - num_computed_tokens
            else:
                # Prefix caching advanced num_computed_tokens past start
                # even though this encoder output is unavailable, so no
                # token can be scheduled for the request in this step.
                num_new_tokens = 0
            break

        encoder_budget -= length
        scheduled.append(input_id)

    return scheduled, num_new_tokens, encoder_budget
def update_from_output(
    self,
    scheduler_output: "SchedulerOutput",
    model_runner_output: "ModelRunnerOutput",
) -> EngineCoreOutputs:
    """Fold the model runner's results for one step back into the scheduler.

    Advances num_computed_tokens for every request that was scheduled,
    frees encoder cache entries whose outputs are now covered by the
    decoder's KV cache, appends the newly sampled token (when the request
    caught up to num_tokens), applies stop conditions, and returns the
    per-request EngineCoreOutputs for this step.
    """
    # NOTE(woosuk): This method doesn't consider speculative decoding.
    sampled_token_ids = model_runner_output.sampled_token_ids
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    new_running: List[Request] = []
    outputs: List[EngineCoreOutput] = []

    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in self.running:
        req_id = request.request_id
        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
            new_running.append(request)
            continue

        request.num_computed_tokens += num_tokens_scheduled
        # When the request's num_computed_tokens catches up its num_tokens,
        # the request generates output tokens. Otherwise, we ignore the
        # sampler output for the request.
        assert request.num_computed_tokens <= request.num_tokens

        cached_encoder_input_ids = (
            self.encoder_cache_manager.get_cached_input_ids(request))
        # OPTIMIZATION: Avoid list(set) if the set is empty.
        if cached_encoder_input_ids:
            # Copy to a list because freeing mutates the underlying set.
            for input_id in list(cached_encoder_input_ids):
                start_pos = request.mm_positions[input_id]["offset"]
                num_tokens = request.mm_positions[input_id]["length"]
                if start_pos + num_tokens <= request.num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    self.encoder_cache_manager.free_encoder_input(
                        request, input_id)

        if request.num_computed_tokens == request.num_tokens:
            req_index = model_runner_output.req_id_to_index[req_id]
            # NOTE(woosuk): Currently, we assume that each request
            # generates at most one token at each step.
            token_id = sampled_token_ids[req_index]
            request.append_output_token_ids(token_id)
            num_new_tokens = 1
            # TODO: Update the KV cache manager for prefix caching.

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            stopped = self._check_stop(request)
            if stopped:
                self._free_request(request)

            # Add EngineCoreOutput for this Request.
            output = EngineCoreOutput(
                request_id=req_id,
                new_token_ids=request.output_token_ids[-num_new_tokens:],
                finished=request.is_finished(),
                finish_reason=request.get_finished_reason(),
                stop_reason=request.stop_reason)
            outputs.append(output)

            # Breakout of the loop: stopped requests leave self.running.
            if stopped:
                continue

        new_running.append(request)
    self.running = new_running
    return EngineCoreOutputs(
        outputs=outputs,
        scheduler_stats=self.make_stats(),
    )
def _check_stop(self, request: Request) -> bool:
    """Check whether *request* must finish after its latest token.

    Sets request.status (and request.stop_reason for a stop-token match)
    and returns True when a stop condition is met; returns False otherwise.
    """
    hit_length_cap = (request.num_tokens >= self.max_model_len
                      or request.num_output_tokens >= request.max_tokens)
    if hit_length_cap:
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    params = request.sampling_params
    last_token = request.output_token_ids[-1]

    # EOS terminates generation unless the request opted out via ignore_eos.
    if last_token == request.eos_token_id and not params.ignore_eos:
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    # A user-provided stop token also records which token triggered it.
    if last_token in (params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token
        return True

    return False
def add_request(self, request: Request) -> None:
    """Register a new request and place it on the waiting queue."""
    self.requests[request.request_id] = request
    self.waiting.append(request)
def finish_requests(
    self,
    request_ids: Union[str, Iterable[str]],
    finished_status: RequestStatus,
) -> None:
    """Handle a finish signal that originates outside the scheduler.

    For example, the API server aborts a request when its client
    disconnects. Unknown request IDs are ignored.
    """
    assert RequestStatus.is_finished(finished_status)
    # Normalize to a set: a single id or any iterable of ids.
    ids = {request_ids} if isinstance(request_ids, str) else set(request_ids)

    for req_id in ids:
        request = self.requests.get(req_id)
        if request is None:
            # Invalid request ID; nothing to clean up.
            continue

        # A RUNNING request lives in self.running; anything else is waiting.
        queue = (self.running if request.status == RequestStatus.RUNNING
                 else self.waiting)
        queue.remove(request)
        request.status = finished_status
        self._free_request(request)
def _free_request(self, request: Request) -> None:
    """Release every piece of scheduler state held for a finished request."""
    assert request.is_finished()
    req_id = request.request_id
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)
    # The pooled CachedRequestData may or may not exist for this request.
    self._cached_reqs_data.pop(req_id, None)
    self.requests.pop(req_id)
    # Recorded so the workers can be told to drop their cached state.
    self.finished_req_ids.add(req_id)
def get_num_unfinished_requests(self) -> int:
    """Count the requests still waiting or currently running."""
    return len(self.running) + len(self.waiting)
def has_unfinished_requests(self) -> bool:
    """Return True while any request remains to be completed."""
    return bool(self.get_num_unfinished_requests())
def reset_prefix_cache(self) -> bool:
    """Delegate a prefix-cache reset to the KV cache manager."""
    succeeded = self.kv_cache_manager.reset_prefix_cache()
    return succeeded
def make_stats(self) -> SchedulerStats:
    """Snapshot current queue sizes and KV cache usage for metrics."""
    return SchedulerStats(
        num_running_reqs=len(self.running),
        num_waiting_reqs=len(self.waiting),
        gpu_cache_usage=self.kv_cache_manager.usage,
    )
@dataclass
class NewRequestData:
    """Full per-request payload for a request scheduled for the first
    time in a SchedulerOutput (see Scheduler.schedule)."""

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List["MultiModalKwargs"]
    mm_hashes: List[str]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    # KV cache block IDs allocated for the request (includes blocks
    # reused via prefix caching).
    block_ids: List[int]
    # Tokens already computed (e.g., prefix-cache hits) when scheduled.
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        """Build the payload by copying the relevant fields off *request*."""
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            prompt=request.prompt,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
        )
@dataclass
class CachedRequestData:
    """Per-step delta for a request that was scheduled before; instances
    are pooled and mutated in place by the scheduler (see
    Scheduler._make_cached_request_data)."""

    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "CachedRequestData":
        """Build the delta payload for *request*."""
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )
@dataclass
class SchedulerOutput:
    """Result of one Scheduler.schedule() call; passed to the executor's
    execute_model and back into Scheduler.update_from_output."""

    # Requests scheduled for the first time (full data is sent).
    scheduled_new_reqs: List[NewRequestData]
    # Previously seen requests (only the per-step delta is sent).
    scheduled_cached_reqs: List[CachedRequestData]

    # req_id -> number of tokens scheduled for it in this step.
    num_scheduled_tokens: Dict[str, int]
    # Sum of num_scheduled_tokens values; bounded by the token budget.
    total_num_scheduled_tokens: int
    # req_id -> indices of the encoder inputs to run in this step.
    scheduled_encoder_inputs: Dict[str, List[int]]
    # Longest common prefix (in blocks) across the running requests;
    # potentially usable for cascade attention.
    num_common_prefix_blocks: int

    # Request IDs finished between the previous and the current step —
    # existing scheduler state, not work newly scheduled here.
    finished_req_ids: Set[str]
    # (req_id, encoder input index) pairs whose cached encoder outputs
    # can now be freed.
    free_encoder_input_ids: List[Tuple[str, int]]
.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc
ADDED
|
Binary file (4.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
from vllm.config import VllmConfig
|
| 6 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 7 |
+
from vllm.executor.ray_distributed_executor import ( # noqa
|
| 8 |
+
RayDistributedExecutor as RayDistributedExecutorV0)
|
| 9 |
+
from vllm.executor.uniproc_executor import ( # noqa
|
| 10 |
+
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
|
| 11 |
+
from vllm.executor.uniproc_executor import ( # noqa
|
| 12 |
+
UniProcExecutor as UniProcExecutorV0)
|
| 13 |
+
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
| 14 |
+
from vllm.v1.outputs import ModelRunnerOutput
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        """Resolve the concrete Executor class from the configured
        distributed executor backend ("ray", "mp", "uni", or
        "external_launcher")."""
        executor_class: Type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        if distributed_executor_backend is None:
            # If the user does not specify the distributed executor backend,
            # we will choose the backend based on the world size.
            if parallel_config.world_size > 1:
                distributed_executor_backend = "mp"
            else:
                distributed_executor_backend = "uni"

        if distributed_executor_backend == "ray":
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            # Imported lazily to avoid pulling in multiprocessing machinery
            # unless this backend is actually selected.
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> int:  # in bytes
        """Return the smallest available-memory figure reported by any
        worker."""
        output = self.collective_rpc("determine_available_memory")
        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(output)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV cache spec, asserting every worker agrees on it."""
        output = self.collective_rpc("get_kv_cache_spec")
        for x in output:
            assert x == output[0]
        return output[0]

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on all workers and return the first
        worker's output."""
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    def profile(self, is_start: bool = True):
        """Start (is_start=True) or stop profiling on all workers."""
        self.collective_rpc("profile", args=(is_start, ))
class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 single-process executor; all behavior comes from the v0
    UniProcExecutor plus the v1 Executor interface."""
    pass
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 executor for externally launched workers; all behavior comes
    from the v0 variant plus the v1 Executor interface."""
    pass
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """v1 Ray-based distributed executor; all behavior comes from the v0
    RayDistributedExecutor plus the v1 Executor interface."""
    pass