koichi12 commited on
Commit
b0fb87d
·
verified ·
1 Parent(s): 7bc7e14

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py +8 -0
  2. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py +160 -0
  6. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +135 -0
  7. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +253 -0
  8. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +231 -0
  9. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +369 -0
  10. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +210 -0
  11. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +291 -0
  12. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py +9 -0
  13. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py +483 -0
  19. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py +348 -0
  20. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py +89 -0
  21. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py +20 -0
  22. .venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py +161 -0
  23. .venv/lib/python3.11/site-packages/vllm/v1/__init__.py +0 -0
  24. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py +0 -0
  31. .venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py +0 -0
  33. .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py +459 -0
  36. .venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py +0 -0
  37. .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py +133 -0
  43. .venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py +500 -0
  44. .venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py +447 -0
  45. .venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py +631 -0
  46. .venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py +94 -0
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# SPDX-License-Identifier: Apache-2.0
"""Reasoning parsers: extract reasoning content from model output."""

from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser

# Public API of this package.
__all__ = [
    "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
]
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (474 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc ADDED
Binary file (8.53 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc ADDED
Binary file (6.01 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ from functools import cached_property
5
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
6
+
7
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
8
+ DeltaMessage)
9
+ from vllm.logger import init_logger
10
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
11
+ from vllm.utils import import_from_path, is_list_of
12
+
13
+ logger = init_logger(__name__)
14
+
15
+
16
class ReasoningParser:
    """
    Abstract reasoning parser class that should not be used directly.
    Provided methods should be overridden in derived classes.

    It is used to extract reasoning content from the model output.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        # Tokenizer of the served model; derived classes use it (via
        # `self.vocab`) to resolve special reasoning-token ids.
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model
        response available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        Tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.
        """

        # BUGFIX: the message previously referenced a nonexistent
        # "AbstractReasoningParser.extract_reasoning_calls".
        raise NotImplementedError(
            "ReasoningParser.extract_reasoning_content "
            "has not been implemented!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming. Has to be an instance method because it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)
        """
        raise NotImplementedError(
            "ReasoningParser.extract_reasoning_content_streaming "
            "has not been implemented!")
77
+
78
+
79
class ReasoningParserManager:
    """Class-level registry mapping names to ReasoningParser subclasses."""

    # Shared registry: parser name -> ReasoningParser subclass.
    reasoning_parsers: Dict[str, Type] = {}

    @classmethod
    def get_reasoning_parser(cls, name) -> Type:
        """
        Get reasoning parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        try:
            return cls.reasoning_parsers[name]
        except KeyError:
            raise KeyError(f"reasoning helper: '{name}' not found in "
                           "reasoning_parsers") from None

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """Insert `module` into the registry under one or more aliases."""
        # Only ReasoningParser subclasses may be registered.
        if not issubclass(module, ReasoningParser):
            raise TypeError("module must be subclass of ReasoningParser, "
                            f"but got {type(module)}")
        # Normalize the alias argument to a list of strings.
        aliases = [module.__name__] if module_name is None else module_name
        if isinstance(aliases, str):
            aliases = [aliases]
        for alias in aliases:
            # Without force, refuse to clobber an existing registration.
            if not force and alias in cls.reasoning_parsers:
                existed_module = cls.reasoning_parsers[alias]
                raise KeyError(f"{alias} is already registered "
                               f"at {existed_module.__module__}")
            cls.reasoning_parsers[alias] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as
        a decorator (with `module` left as None) or called directly (with
        `module` given).
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # Validate `name` eagerly so misuse fails before any registration.
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                "name must be None, an instance of str, or a sequence of str, "
                f"but got {type(name)}")

        def _do_register(target):
            cls._register_module(module=target, module_name=name, force=force)
            return target

        # Direct call: x.register_module(module=SomeClass)
        if module is not None:
            return _do_register(module)

        # Decorator usage: @x.register_module()
        return _do_register

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser by the path
        of the reasoning parser define file.
        """
        # Module name is the plugin file name without its extension.
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            # Best-effort: log the failure instead of crashing startup.
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import re
4
+ from typing import Optional, Sequence, Tuple, Union
5
+
6
+ from transformers import PreTrainedTokenizerBase
7
+
8
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
9
+ DeltaMessage)
10
+ from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
11
+ ReasoningParser, ReasoningParserManager)
12
+ from vllm.logger import init_logger
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        # Non-greedy so only the first <think>...</think> span is captured;
        # DOTALL lets the reasoning span newlines.
        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens (the markers themselves carry no text).
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # BUGFIX(comment): reasoning already finished in previous
                # chunks; everything from here on is normal content.
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            # NOTE: a leftover `logger.info(delta_text)` debug statement was
            # removed here; it logged raw model text on every chunk.
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta: plain content.
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """Split a complete response into (reasoning_content, content)."""

        # Check if the model output contains the <think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output

        # Use a regex to find the reasoning content
        matches = self.reasoning_regex.findall(model_output)
        if not matches:
            # BUGFIX: both tokens present but out of order (e.g.
            # "</think>...<think>") previously raised IndexError here;
            # treat the whole output as plain content instead.
            return None, model_output
        reasoning_content = matches[0]

        # Remove the reasoning content from the model output
        # Although deepseek's <think> token is always at the
        # beginning of the line, we cannot guarantee that the
        # other models will follow this convention.
        # Therefore, we need to add :start_index.
        start_index = model_output.find(self.think_start_token)
        if start_index != -1:
            end_index = start_index + len(
                f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
            )
            model_output = model_output[:start_index] + \
                model_output[end_index:]

        if len(model_output) == 0:
            return reasoning_content, None

        return reasoning_content, model_output
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from json import JSONDecoder
6
+ from typing import Dict, Sequence, Union
7
+
8
+ import partial_json_parser
9
+ from partial_json_parser.core.options import Allow
10
+
11
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
+ DeltaFunctionCall, DeltaMessage,
13
+ DeltaToolCall,
14
+ ExtractedToolCallInformation,
15
+ FunctionCall, ToolCall)
16
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
17
+ ToolParser, ToolParserManager)
18
+ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
19
+ find_common_prefix,
20
+ is_complete_json,
21
+ partial_json_loads)
22
+ from vllm.logger import init_logger
23
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
24
+ from vllm.utils import random_uuid
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
@ToolParserManager.register_module("granite-20b-fc")
class Granite20bFCToolParser(ToolParser):
    """
    Tool call parser for the granite-20b-functioncalling model intended
    for use with the examples/tool_chat_template_granite20b_fc.jinja
    template.

    Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
    are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Marker the model emits before each JSON function-call object.
        self.bot_token = "<function_call>"
        self.tool_start_token = self.bot_token
        self.tool_call_regex = re.compile(r"<function_call>\s*")

    def extract_tool_calls(
        self, model_output: str,
        request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Parse a complete (non-streaming) response into tool calls.

        Each `<function_call>` tag is expected to be followed by a JSON
        object; any text before the first tag is returned as content.
        On any parse error the whole output is returned as plain content.
        """
        if self.tool_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        dec = JSONDecoder()
        try:
            matches = list(self.tool_call_regex.finditer(model_output))
            logger.debug("Found %d tool call matches", len(matches))

            raw_function_calls = []

            for i, match in enumerate(matches):
                # position after the <function_call> tag
                start_of_json = match.end()
                # end_index == the start of the next function call
                # (if exists)
                next_function_call_start = (matches[i + 1].start() if i +
                                            1 < len(matches) else None)

                # raw_decode tolerates trailing text after the JSON object.
                raw_function_calls.append(
                    dec.raw_decode(
                        model_output[start_of_json:next_function_call_start])
                    [0])

            logger.debug("Extracted %d tool calls", len(raw_function_calls))
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"]),
                    ),
                ) for function_call in raw_function_calls
            ]

            content = model_output[:model_output.find(self.bot_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception as e:
            logger.error("Error in extracting tool call from response %s", e)
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse tool calls while tokens stream in.

        previous + delta == current. Returns the delta to emit to the
        client (name first, then argument fragments), or None when
        nothing new can be streamed yet.
        """

        # Still inside a possible bot-token prefix — wait for more tokens.
        if len(current_text) < len(
                self.bot_token) and self.bot_token.startswith(current_text):
            return None

        if not current_text.startswith(self.bot_token):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = len(self.bot_token)
                start_idx = consume_space(start_idx, current_text)

                # Walk the text: JSON object, whitespace, next tag, repeat.
                while start_idx < len(current_text):
                    (obj,
                     end_idx) = partial_json_loads(current_text[start_idx:],
                                                   flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx:start_idx +
                                                      end_idx]))
                    start_idx += end_idx
                    start_idx = consume_space(start_idx, current_text)
                    # Skip over the next <function_call> tag.
                    # NOTE(review): assumes each object is followed by
                    # another tag or end-of-text — TODO confirm.
                    start_idx += len(self.bot_token)
                    start_idx = consume_space(start_idx, current_text)
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Arguments fully parsed: flush everything unsent.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:

                            # Only stream the stable common prefix; the
                            # tail may still change as JSON completes.
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception as e:
            logger.error("Error trying to handle streaming tool call: %s", e)
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ from typing import Dict, Sequence, Union
5
+
6
+ import partial_json_parser
7
+ from partial_json_parser.core.options import Allow
8
+
9
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
10
+ DeltaFunctionCall, DeltaMessage,
11
+ DeltaToolCall,
12
+ ExtractedToolCallInformation,
13
+ FunctionCall, ToolCall)
14
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
15
+ ToolParser, ToolParserManager)
16
+ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
17
+ find_common_prefix,
18
+ is_complete_json,
19
+ partial_json_loads)
20
+ from vllm.logger import init_logger
21
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
22
+ from vllm.utils import random_uuid
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
@ToolParserManager.register_module("granite")
class GraniteToolParser(ToolParser):
    """
    Tool call parser for the granite 3.0 models. Intended
    for use with the examples/tool_chat_template_granite.jinja
    template.

    Used when --enable-auto-tool-choice --tool-call-parser granite
    are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # for granite 3.0, the token `<|tool_call|>`
        self.bot_token = "<|tool_call|>"
        # for granite 3.1, the string `<tool_call>`
        self.bot_string = "<tool_call>"

    def extract_tool_calls(
        self, model_output: str,
        request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Parse a complete (non-streaming) response into tool calls.

        Expects a JSON array of {"name", "arguments"} objects, optionally
        preceded by the bot token/string. Returns the original text as
        plain content when no such array is found or parsing fails.
        """
        stripped = model_output.strip()\
            .removeprefix(self.bot_token)\
            .removeprefix(self.bot_string)\
            .lstrip()
        if not stripped or stripped[0] != '[':
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            raw_function_calls = json.loads(stripped)
            if not isinstance(raw_function_calls, list):
                raise Exception(
                    f"Expected dict or list, got {type(raw_function_calls)}")

            logger.debug("Extracted %d tool calls", len(raw_function_calls))
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"]),
                    ),
                ) for function_call in raw_function_calls
            ]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )

        except Exception as e:
            logger.error("Error in extracting tool call from response %s", e)
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse tool calls while tokens stream in.

        previous + delta == current. Returns the delta to emit to the
        client (name first, then argument fragments), or None when
        nothing new can be streamed yet.
        """

        # Skip leading whitespace and the optional bot token/string to
        # locate where the JSON array should start.
        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_token):
            start_idx = consume_space(start_idx + len(self.bot_token),
                                      current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string),
                                      current_text)
        if not current_text or start_idx >= len(current_text)\
                or current_text[start_idx] != '[':
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = None
            is_complete = None
            try:
                tool_calls, end_idx = partial_json_loads(
                    current_text[start_idx:], flags)
                if type(tool_calls) is list:
                    tool_call_arr = tool_calls
                else:
                    return DeltaMessage(content=delta_text)

                # Only the last element may still be a partial object.
                is_complete = [True] * len(tool_calls)
                if not is_complete_json(
                        current_text[start_idx:start_idx + end_idx]):
                    is_complete[-1] = False
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if not tool_call_arr:
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            delta = None
            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            if len(tool_call_arr) > self.current_tool_id + 1:

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Arguments fully parsed: flush everything unsent.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only stream the stable common prefix; the
                            # tail may still change as JSON completes.
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception as e:
            logger.error("Error trying to handle streaming tool call: %s", e)
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from typing import Dict, List, Sequence, Union
6
+
7
+ import partial_json_parser
8
+ from partial_json_parser.core.options import Allow
9
+
10
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
11
+ DeltaFunctionCall, DeltaMessage,
12
+ DeltaToolCall,
13
+ ExtractedToolCallInformation,
14
+ FunctionCall, ToolCall)
15
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
16
+ ToolParser, ToolParserManager)
17
+ from vllm.logger import init_logger
18
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
19
+ from vllm.utils import random_uuid
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
+ @ToolParserManager.register_module("hermes")
25
+ class Hermes2ProToolParser(ToolParser):
26
+
27
def __init__(self, tokenizer: AnyTokenizer):
    """Set up per-request streaming state and locate the <tool_call>
    sentinel tokens in the tokenizer vocabulary."""
    super().__init__(tokenizer)

    # Hermes prompts don't work with Mistral's tokenizer wrapper; fall
    # back to the underlying HF tokenizer it wraps.
    if isinstance(self.model_tokenizer, MistralTokenizer):
        logger.error(
            "Detected Mistral tokenizer when using a Hermes model")
        self.model_tokenizer = self.model_tokenizer.tokenizer

    # Streaming state: which tool we're on, whether its name has been
    # emitted, the previously parsed partial calls, and the argument
    # text already streamed for each tool.
    self.current_tool_id: int = -1
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: List[Dict] = []
    self.streamed_args_for_tool: List[str] = [
    ]  # map what has been streamed for each tool so far to a list

    # Sentinels delimiting a tool call in the model output.
    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"

    # Matches either a closed <tool_call>...</tool_call> span (group 1)
    # or an unterminated trailing <tool_call>... span (group 2).
    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
    self.scratch_pad_regex = re.compile(
        r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # Resolve the sentinel token ids; both must exist in the vocab for
    # token-level bookkeeping during streaming.
    self.tool_call_start_token_id = self.vocab.get(
        self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
    if None in (self.tool_call_start_token_id,
                self.tool_call_end_token_id):
        raise RuntimeError(
            "Hermes 2 Pro Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!")
61
+
62
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse <tool_call> blocks out of a complete (non-streaming)
    response, returning the tool calls plus any leading plain content."""

    # Fast path: no tool-call sentinel means the whole output is content.
    if self.tool_call_start_token not in model_output:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    try:
        # Each findall hit is a 2-tuple: the first group is a closed
        # <tool_call>...</tool_call> span, the second an unterminated
        # one at end-of-string; exactly one is non-empty per match.
        matches = self.tool_call_regex.findall(model_output)
        raw_function_calls = [
            json.loads(closed or dangling) for closed, dangling in matches
        ]

        tool_calls = []
        for function_call in raw_function_calls:
            # OpenAI-style tool calls carry arguments as a JSON string.
            args_json = json.dumps(function_call["arguments"],
                                   ensure_ascii=False)
            tool_calls.append(
                ToolCall(type="function",
                         function=FunctionCall(name=function_call["name"],
                                               arguments=args_json)))

        # Anything before the first sentinel is ordinary content.
        content = model_output[:model_output.
                               find(self.tool_call_start_token)]
        return ExtractedToolCallInformation(tools_called=True,
                                            tool_calls=tool_calls,
                                            content=content or None)

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
114
+
115
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Incrementally convert the latest decode step into a DeltaMessage.

    Tracks where we are relative to <tool_call>/</tool_call> sentinels by
    counting their token ids, partially parses the in-progress JSON with
    partial_json_parser, and emits the tool name first, then argument
    fragments. Returns None when there is nothing safe to emit yet.
    Mutates self.current_tool_id, self.current_tool_name_sent,
    self.prev_tool_call_arr and self.streamed_args_for_tool.
    """

    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)
    # No tool-call start token anywhere yet: plain content stream.
    if self.tool_call_start_token_id not in current_token_ids:
        logger.debug("No tool call tokens found!")
        return DeltaMessage(content=delta_text)

    try:

        # Figure out where we are in the parsing by counting tool call
        # start & end tags in the previous vs. current token streams.
        prev_tool_start_count = previous_token_ids.count(
            self.tool_call_start_token_id)
        prev_tool_end_count = previous_token_ids.count(
            self.tool_call_end_token_id)
        cur_tool_start_count = current_token_ids.count(
            self.tool_call_start_token_id)
        cur_tool_end_count = current_token_ids.count(
            self.tool_call_end_token_id)
        tool_call_portion = None
        text_portion = None

        # Case: all previously opened tool calls are closed and this
        # delta doesn't close another one -> ordinary text content.
        if (cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and self.tool_call_end_token not in delta_text):
            logger.debug("Generating text content! skipping tool parsing.")
            return DeltaMessage(content=delta_text)

        if self.tool_call_end_token in delta_text:
            logger.debug("tool_call_end_token in delta_text")
            # The delta closes a tool call: isolate the last tool-call
            # span and trim the delta to the part inside the tags.
            full_text = current_text + delta_text
            tool_call_portion = full_text.split(
                self.tool_call_start_token)[-1].split(
                    self.tool_call_end_token)[0].rstrip()
            delta_text = delta_text.split(
                self.tool_call_end_token)[0].rstrip()
            # NOTE(review): delta_text was just truncated at the end
            # token above, so this split is a no-op and text_portion
            # ends up equal to delta_text.lstrip() — confirm intent.
            text_portion = delta_text.split(
                self.tool_call_end_token)[-1].lstrip()

        # Flags for partial JSON parsing: until the tool name has been
        # sent, disallow incomplete strings so we never emit a partial
        # function name. Exported constants from "Allow" are a bit mask.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        # Case -- we're starting a new tool call.
        if (cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count > prev_tool_start_count):
            if len(delta_token_ids) > 1:
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
            else:
                # The delta is just the start token itself; nothing to
                # parse yet.
                tool_call_portion = None
                delta = None

            text_portion = None

            # Set cursors and state appropriately for the new tool.
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # Case -- we're updating an existing (still open) tool call.
        elif (cur_tool_start_count > cur_tool_end_count
              and cur_tool_start_count == prev_tool_start_count):

            # Get the portion of the text that's the tool call.
            tool_call_portion = current_text.split(
                self.tool_call_start_token)[-1]
            text_portion = None

        # Case -- the current tool call is being closed: flush any
        # argument text that hasn't been streamed yet.
        elif (cur_tool_start_count == cur_tool_end_count
              and cur_tool_end_count >= prev_tool_end_count):
            if (self.prev_tool_call_arr is None
                    or len(self.prev_tool_call_arr) == 0):
                logger.debug(
                    "attempting to close tool call, but no tool call")
                return None
            diff = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments")
            if diff:
                # NOTE(review): `diff is str` compares against the type
                # object and is always False, so the unicode_escape
                # branch never runs; `isinstance(diff, str)` was likely
                # intended — confirm upstream before changing.
                diff = diff.encode('utf-8').decode(
                    'unicode_escape') if diff is str else diff
                if ('"}' not in delta_text):
                    return None
                # Emit everything up to (and including) the final '"}'.
                end_loc = delta_text.rindex('"}')
                diff = delta_text[:end_loc] + '"}'
                logger.debug(
                    "Finishing tool and found diff that had not "
                    "been streamed yet: %s", diff)
                self.streamed_args_for_tool[self.current_tool_id] \
                    += diff
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=diff).model_dump(
                                          exclude_none=True))
                ])

        # Case -- otherwise we're just generating text.
        else:
            # Strip the sentinel tokens before forwarding as content.
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            delta = DeltaMessage(tool_calls=[], content=text)
            return delta

        try:

            # Partially parse whatever tool-call JSON we have so far.
            current_tool_call = partial_json_parser.loads(
                tool_call_portion or "{}",
                flags) if tool_call_portion else None
            logger.debug("Parsed tool call %s", current_tool_call)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None
        except json.decoder.JSONDecodeError:
            logger.debug("unable to parse JSON")
            return None

        # Case -- we haven't sent the tool name yet. If it's available,
        # send it; otherwise wait until it is.
        if not self.current_tool_name_sent:
            if (current_tool_call is None):
                return None
            function_name: Union[str, None] = current_tool_call.get("name")
            if function_name:
                self.current_tool_name_sent = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  type="function",
                                  id=f"chatcmpl-tool-{random_uuid()}",
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
            else:
                return None
        # Case -- otherwise, send the tool call delta.

        # If the tool call portion is None, send the delta as text.
        if tool_call_portion is None:
            # If there's text but not tool calls, send that -
            # otherwise None to skip the chunk.
            delta = DeltaMessage(content=delta_text) \
                if text_portion is not None else None
            return delta

        # Now, the nitty-gritty of diffing the partially-parsed JSON.

        logger.debug("Trying to parse current tool call with ID %s",
                     self.current_tool_id)

        # If we're starting a new tool call, push an empty object in as
        # a placeholder for the arguments.
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # Main logic for tool parsing: compare the previous
        # partially-parsed arguments to the current ones.
        prev_arguments = (
            self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        # Case -- no arguments have been created yet; skip the delta.
        if not cur_arguments and not prev_arguments:
            logger.debug("Skipping text %s - no arguments", delta_text)
            delta = None

        # Case -- prev arguments are defined, but none are now.
        # Probably impossible, but not fatal - just keep going.
        elif not cur_arguments and prev_arguments:
            logger.error("should be impossible to have arguments reset "
                         "mid-call. skipping streaming anything.")
            delta = None

        # Case -- first argument info became available via the
        # autocompleted JSON: emit everything up to the end of the
        # delta within the serialized arguments.
        elif cur_arguments and not prev_arguments:

            cur_arguments_json = json.dumps(cur_arguments,
                                            ensure_ascii=False)
            logger.debug("finding %s in %s", delta_text,
                         cur_arguments_json)

            # Locate the delta inside the serialized args ([:-2] skips
            # the auto-closing characters appended by the parser).
            if (delta_text not in cur_arguments_json[:-2]):
                return None
            args_delta_start_loc = cur_arguments_json[:-2]. \
                rindex(delta_text) + \
                len(delta_text)

            # Use that to find the actual delta to stream.
            arguments_delta = cur_arguments_json[:args_delta_start_loc]
            logger.debug("First tokens in arguments received: %s",
                         arguments_delta)

            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=arguments_delta).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += arguments_delta

        # Last case -- an update to already-started arguments: forward
        # the raw delta, trimming a trailing '}' that may close the
        # arguments object rather than belong to them.
        elif cur_arguments and prev_arguments:
            if isinstance(delta_text, str) and len(delta_text.rstrip(
            )) >= 1 and delta_text.rstrip()[-1] == '}':
                delta_text = delta_text.rstrip()[:-1]

            logger.debug("got diff %s", delta_text)

            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=delta_text).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += delta_text

        # Save the state for the current tool into the "prev" list for
        # use in diffing on the next iteration.
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = \
                current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        return delta

    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        return None  # do not stream a delta. skip this token ID.
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ from typing import Dict, Sequence, Union
5
+
6
+ import partial_json_parser
7
+ from partial_json_parser.core.options import Allow
8
+
9
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
10
+ DeltaFunctionCall, DeltaMessage,
11
+ DeltaToolCall,
12
+ ExtractedToolCallInformation,
13
+ FunctionCall, ToolCall)
14
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
15
+ ToolParser, ToolParserManager)
16
+ from vllm.entrypoints.openai.tool_parsers.utils import (
17
+ extract_intermediate_diff)
18
+ from vllm.logger import init_logger
19
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
20
+ from vllm.utils import random_uuid
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
+ @ToolParserManager.register_module(["internlm"])
26
+ class Internlm2ToolParser(ToolParser):
27
+
28
def __init__(self, tokenizer: AnyTokenizer):
    super().__init__(tokenizer)
    # Index into current_text up to which content has already been
    # emitted as plain (non-tool-call) text during streaming.
    self.position = 0
31
+
32
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Tweak the request so tool-call markers survive detokenization."""
    tools_requested = bool(request.tools) and request.tool_choice != 'none'
    if tools_requested:
        # InternLM marks the start/end of tool-call information with
        # special tokens, so they must not be stripped from the output
        # we later parse.
        request.skip_special_tokens = False
    return request
40
+
41
def get_argments(self, obj):
    """Return the call arguments from *obj*, which may use either the
    "parameters" or the "arguments" key; None if neither is present."""
    for key in ("parameters", "arguments"):
        if key in obj:
            return obj.get(key)
    return None
47
+
48
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Incrementally parse an InternLM2 <|action_start|><|plugin|> tool
    call out of the stream, emitting the name first and then argument
    diffs. InternLM2 does not support parallel tool calls, so at most
    one call is tracked. Returns None when nothing can be emitted yet.
    """
    # No action marker yet: everything so far is plain content.
    if '<|action_start|>' not in current_text:
        self.position = len(current_text)
        return DeltaMessage(content=delta_text)
    # If a tool call was already sent, return an empty delta message
    # so the finish_reason is still delivered correctly.
    if self.current_tool_id > 0:
        return DeltaMessage(content='')

    last_pos = self.position
    if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
        return None

    # Split the not-yet-emitted tail into leading text and the action
    # (tool-call) payload.
    new_delta = current_text[last_pos:]
    text, action = new_delta.split('<|action_start|><|plugin|>')

    # Flush any plain text that precedes the action marker first.
    if len(text) > 0:
        self.position = self.position + len(text)
        return DeltaMessage(content=text)

    action = action.strip()
    action = action.split('<|action_end|>'.strip())[0]

    # Bit-mask flags for partial JSON parsing. If the name hasn't been
    # sent yet, don't allow an incomplete string, since OpenAI only
    # ever sends the entire tool/function name at once.
    flags = Allow.ALL if self.current_tool_name_sent \
        else Allow.ALL & ~Allow.STR

    try:
        parsable_arr = action

        # Tool calls are generated as a single JSON object in InternLM2;
        # parallel tool calls are not supported.
        try:
            tool_call_arr: Dict = partial_json_parser.loads(
                parsable_arr, flags)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None

        # If the current tool name hasn't been sent, send it if
        # available - otherwise send nothing.
        if not self.current_tool_name_sent:
            function_name = tool_call_arr.get("name")
            if function_name:
                self.current_tool_id = self.current_tool_id + 1
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  type="function",
                                  id=f"chatcmpl-tool-{random_uuid()}",
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
                self.current_tool_name_sent = True
                self.streamed_args_for_tool.append("")
            else:
                delta = None
        # Now we know we're on the same tool call and we're streaming
        # arguments.
        else:
            prev_arguments = self.get_argments(
                self.prev_tool_call_arr[self.current_tool_id])
            cur_arguments = self.get_argments(tool_call_arr)

            # No arguments generated yet.
            if not cur_arguments and not prev_arguments:
                delta = None
            # Should never happen: arguments reset mid-stream.
            elif not cur_arguments and prev_arguments:
                logger.error(
                    "INVARIANT - impossible to have arguments reset "
                    "mid-arguments")
                delta = None
            # First time parameters become available: emit up to the end
            # of the current delta within the serialized arguments.
            elif cur_arguments and not prev_arguments:
                cur_arguments_json = json.dumps(cur_arguments)

                # NOTE(review): .index() raises ValueError when
                # delta_text isn't in the serialized args; that is
                # swallowed by the outer except and skips the chunk.
                arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                     index(delta_text) +
                                                     len(delta_text)]
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).
                                  model_dump(exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += arguments_delta
            # Both prev and cur parameters: send the incremental diff.
            elif cur_arguments and prev_arguments:
                cur_args_json = json.dumps(cur_arguments)
                prev_args_json = json.dumps(prev_arguments)

                argument_diff = extract_intermediate_diff(
                    cur_args_json, prev_args_json)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=argument_diff).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += argument_diff

        # Normalize the parsed call to use "arguments" and save it for
        # diffing on the next iteration; return None/delta as computed.
        tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
        self.prev_tool_call_arr = [tool_call_arr]
        return delta
    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction "
            "error")
        return None
177
+
178
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse a complete InternLM2 response for a single tool call.

    Splits the output at '<|action_start|><|plugin|>', decodes the JSON
    action payload, and validates the tool name against the request's
    declared tools. Returns plain content when no action marker is
    present or the tool name is unknown.
    """
    text = model_output
    tools = request.tools
    if '<|action_start|><|plugin|>' in text:
        text, action = text.split('<|action_start|><|plugin|>')
        action = action.split('<|action_end|>'.strip())[0]
        # The payload may have leading junk before the JSON object.
        action = action[action.find('{'):]
        action_dict = json.loads(action)
        name, parameters = action_dict['name'], json.dumps(
            action_dict.get('parameters', action_dict.get('arguments',
                                                          {})))

        # BUG FIX: the original constructed this object without
        # returning it, so unknown tool names fell through and were
        # reported as successful tool calls.
        if not tools or name not in [t.function.name for t in tools]:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=text)

        tool_calls = [
            ToolCall(
                function=FunctionCall(name=name, arguments=parameters))
        ]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=text if len(text) > 0 else None)

    return ExtractedToolCallInformation(tools_called=False,
                                        tool_calls=[],
                                        content=text)
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import ast
4
+ import json
5
+ import re
6
+ from typing import Any, Sequence, Tuple, Union
7
+
8
+ from transformers import PreTrainedTokenizerBase
9
+
10
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
11
+ DeltaFunctionCall, DeltaMessage,
12
+ DeltaToolCall,
13
+ ExtractedToolCallInformation,
14
+ FunctionCall, ToolCall)
15
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
16
+ ToolParser, ToolParserManager)
17
+ from vllm.logger import init_logger
18
+
19
+ logger = init_logger(__name__)
20
+
21
+
22
class _UnexpectedAstError(Exception):
    """Raised when model output parses to an AST that is not a plain
    list of function calls with literal arguments."""
    pass
24
+
25
+
26
@ToolParserManager.register_module("pythonic")
class PythonicToolParser(ToolParser):
    """
    Tool call parser for models that produce tool calls in a pythonic style,
    such as Llama 3.2 models.

    Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
    """
    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    #    This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a complete bracketed list of one or more identifier(...) calls
    # with keyword arguments, e.g. [foo(a=1), bar(b="x")]; used as a cheap
    # pre-filter before attempting a full AST parse.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL)
45
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    # All streaming state (current_tool_id, streamed_args_for_tool,
    # prev_tool_call_arr) lives on the ToolParser base class.
    super().__init__(tokenizer)
47
+
48
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
    # Alias for the inherited `current_tool_id` counter, which this
    # parser uses as a positional index rather than an identifier.
    return self.current_tool_id

@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
    self.current_tool_id = value
56
+
57
def extract_tool_calls(
        self, model_output: str,
        request: ChatCompletionRequest) -> ExtractedToolCallInformation:
    """
    Extract the tool calls from a complete model response.
    """

    plain_text = ExtractedToolCallInformation(tools_called=False,
                                              tool_calls=[],
                                              content=model_output)

    # Cheap pre-filter: only attempt AST parsing when the whole output
    # looks like a bracketed list of calls.
    if self.TOOL_CALL_REGEX.match(model_output) is None:
        return plain_text

    try:
        expr = getattr(ast.parse(model_output).body[0], "value", None)
        is_call_list = isinstance(expr, ast.List) and all(
            isinstance(elt, ast.Call) for elt in expr.elts)
        if not is_call_list:
            raise _UnexpectedAstError(
                "Tool output must be a list of function calls")
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=[
                _handle_single_tool(elt)  # type: ignore
                for elt in expr.elts
            ],
            content=None)
    except Exception:
        logger.exception("Error in extracting tool call from response.")
        # Treat as regular text
        return plain_text
90
+
91
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Stream pythonic tool-call deltas.

    On every step, auto-completes the partial call list into valid
    Python, parses it, and diffs each call's JSON-serialized arguments
    against what has already been streamed, withholding whatever part
    only exists because of the auto-completion. Returns None when there
    is nothing new to emit.
    """

    # Pythonic tool calls always start with '['; otherwise plain text.
    if not current_text.startswith("["):
        return DeltaMessage(content=delta_text)

    try:
        # Complete the partial text into parseable Python; added_text is
        # the synthetic closing suffix that must not be streamed as-is.
        valid_and_added_text = _make_valid_python(current_text)
        if valid_and_added_text is None:
            # Can't complete yet (mid-identifier/value); skip this chunk.
            return None
        valid_text, added_text = valid_and_added_text

        module = ast.parse(valid_text)
        parsed = getattr(module.body[0], "value", None)
        if not isinstance(parsed, ast.List) or not all(
                isinstance(e, ast.Call) for e in parsed.elts):
            raise _UnexpectedAstError(
                "Tool output must be a list of function calls")
        tool_calls = [
            _handle_single_tool(e)  # type: ignore
            for e in parsed.elts
        ]

        tool_deltas = []
        for index, new_call in enumerate(tool_calls):
            # Calls before current_tool_index were fully streamed.
            if index < self.current_tool_index:
                continue

            self.current_tool_index = index
            if len(self.streamed_args_for_tool) == index:
                self.streamed_args_for_tool.append("")

            # A call is complete when it isn't the last one, or when the
            # auto-completion didn't have to close the call+list (")]").
            new_call_complete = index < len(
                tool_calls) - 1 or ")]" not in added_text
            if new_call_complete:
                self.current_tool_index += 1

            # Withhold the part of the arguments that only exists
            # because of the synthetic completion suffix.
            withheld_suffix = (added_text[:-2]
                               if not new_call_complete else "")
            if not new_call_complete and added_text[-2] == ")":
                # Function call is incomplete. Withhold the closing bracket.
                withheld_suffix = withheld_suffix + "}"
            # Strings get single quotes in the model-produced string.
            # JSON requires double quotes.
            withheld_suffix = withheld_suffix.replace("'", '"')
            delta = _compute_tool_delta(self.streamed_args_for_tool[index],
                                        new_call, index, withheld_suffix)

            if delta is not None:
                tool_deltas.append(delta)
                if (delta.function is not None
                        and delta.function.arguments is not None):
                    self.streamed_args_for_tool[
                        index] += delta.function.arguments

        # HACK: serving_chat.py inspects the internal state of tool parsers
        # when determining its final streaming delta, automatically
        # adding autocompleted JSON.
        # These two lines avoid that nonsense while ensuring finish_reason
        # is set to tool_calls when at least one tool is called.
        if tool_deltas and not self.prev_tool_call_arr:
            self.prev_tool_call_arr = [{"arguments": {}}]

        if tool_deltas:
            return DeltaMessage(tool_calls=tool_deltas)
        elif not added_text and self.current_tool_id > 0:
            # Return an empty DeltaMessage once the tool calls are all done
            # so that finish_reason gets set.
            return DeltaMessage(content='')
        else:
            return None
    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction "
            "error")
        return None
176
+
177
+
178
+ def _get_parameter_value(val: ast.expr) -> Any:
179
+ if isinstance(val, ast.Constant):
180
+ return val.value
181
+ elif isinstance(val, ast.Dict):
182
+ if not all(isinstance(k, ast.Constant) for k in val.keys):
183
+ raise _UnexpectedAstError(
184
+ "Dict tool call arguments must have literal keys")
185
+ return {
186
+ k.value: _get_parameter_value(v) # type: ignore
187
+ for k, v in zip(val.keys, val.values)
188
+ }
189
+ elif isinstance(val, ast.List):
190
+ return [_get_parameter_value(v) for v in val.elts]
191
+ else:
192
+ raise _UnexpectedAstError("Tool call arguments must be literals")
193
+
194
+
195
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert one `name(kw=literal, ...)` AST call into a ToolCall with
    JSON-encoded arguments."""
    func = call.func
    if not isinstance(func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    arguments = {
        keyword.arg: _get_parameter_value(keyword.value)
        for keyword in call.keywords
    }
    return ToolCall(type="function",
                    function=FunctionCall(name=func.id,
                                          arguments=json.dumps(arguments)))
205
+
206
+
207
+ def _make_valid_python(text: str) -> Union[Tuple[str, str], None]:
208
+ bracket_stack = []
209
+ for index, char in enumerate(text):
210
+ if char in {"[", "(", "{"}:
211
+ bracket_stack.append(char)
212
+ elif char == "]":
213
+ if not bracket_stack or bracket_stack.pop() != "[":
214
+ raise _UnexpectedAstError("Mismatched square brackets")
215
+ elif char == ")":
216
+ if not bracket_stack or bracket_stack.pop() != "(":
217
+ raise _UnexpectedAstError("Mismatched parentheses")
218
+ elif char == "}":
219
+ if not bracket_stack or bracket_stack.pop() != "{":
220
+ raise _UnexpectedAstError("Mismatched curly braces")
221
+ elif char in {"'", '"'}:
222
+ if bracket_stack and bracket_stack[-1] == char:
223
+ if index > 0 and text[index - 1] == "\\":
224
+ # Treat an escaped quote as a regular character
225
+ pass
226
+ else:
227
+ bracket_stack.pop()
228
+ elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
229
+ # Double quote within a single quote string or vice versa.
230
+ pass
231
+ else:
232
+ bracket_stack.append(char)
233
+
234
+ text = text.rstrip()
235
+ if text.endswith("=") or text.endswith(":"):
236
+ # Since we have no type information for this property/parameter value,
237
+ # we can't fill in a valid value.
238
+ return None
239
+ if bracket_stack and bracket_stack[-1] == "{":
240
+ trailing_dict_text = text[:text.rfind("{")]
241
+ num_keys = trailing_dict_text.count(":")
242
+ num_values = trailing_dict_text.count(",")
243
+ if num_keys <= num_values:
244
+ return None # Incomplete property name within parameter value
245
+ if bracket_stack and bracket_stack[-1] == "(":
246
+ trailing_params_text = text[:text.rfind("(")]
247
+ num_full_param_names = trailing_params_text.count("=")
248
+ num_full_param_values = trailing_params_text.count(",")
249
+ if num_full_param_names <= num_full_param_values:
250
+ return None # Incomplete parameter name
251
+ if text.endswith(","):
252
+ text = text[:-1]
253
+ if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
254
+ "[") and not text.endswith(")"):
255
+ return None # Incomplete function name
256
+
257
+ added_text = ""
258
+ for char in reversed(bracket_stack):
259
+ if char == "[":
260
+ added_text += "]"
261
+ elif char == "(":
262
+ added_text += ")"
263
+ elif char == "{":
264
+ added_text += "}"
265
+ elif char == "'":
266
+ added_text += "'"
267
+ elif char == '"':
268
+ added_text += '"'
269
+
270
+ return text + added_text, added_text
271
+
272
+
273
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
                        index: int,
                        withheld_suffix: str) -> Union[DeltaToolCall, None]:
    """Build the incremental DeltaToolCall for *new_call*, or None when
    there is nothing new to send beyond *previously_sent_args*."""
    args = new_call.function.arguments
    if withheld_suffix:
        # The caller trimmed this synthetic suffix on purpose; keep
        # holding it back until the call is actually complete.
        assert args.endswith(withheld_suffix)
        args = args[:-len(withheld_suffix)]

    # First chunk for this call: include the id and function name.
    if not previously_sent_args:
        return DeltaToolCall(id=new_call.id,
                             index=index,
                             function=DeltaFunctionCall(
                                 name=new_call.function.name,
                                 arguments=args,
                             ))

    # Follow-up chunk: only the not-yet-streamed argument tail.
    arg_diff = args[len(previously_sent_args):]
    if not arg_diff:
        return None
    return DeltaToolCall(id="",
                         index=index,
                         function=DeltaFunctionCall(arguments=arg_diff))
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
4
+ from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
5
+
6
+ __all__ = [
7
+ "PunicaWrapperBase",
8
+ "get_punica_wrapper",
9
+ ]
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (431 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc ADDED
Binary file (1.25 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc ADDED
Binary file (7.45 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Based on:
4
+ Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
5
+ Punica: Multi-Tenant LoRA Serving.
6
+ https://arxiv.org/abs/2310.18547
7
+ """
8
+
9
+ from abc import ABC, abstractmethod
10
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
11
+
12
+ import torch
13
+
14
+ from .utils import compute_meta, convert_mapping
15
+
16
+ if TYPE_CHECKING:
17
+ # avoid circuit import
18
+ from vllm.lora.layers import LoRAMapping
19
+ from vllm.lora.models import LongContextLoRAContext
20
+
21
+
22
+ class PunicaWrapperABC(ABC):
23
+ """
24
+ PunicaWrapper ABC.
25
+ """
26
+
27
+ @abstractmethod
28
+ def update_metadata(
29
+ self,
30
+ mapping: "LoRAMapping",
31
+ lora_index_to_id: List[Optional[int]],
32
+ max_loras: int,
33
+ vocab_size: int,
34
+ extra_vocab_size: int,
35
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
36
+ **kwargs,
37
+ ) -> None:
38
+ """
39
+ Update the lora-related metadata
40
+ """
41
+ raise NotImplementedError
42
+
43
+ @abstractmethod
44
+ def add_shrink(
45
+ self,
46
+ y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
47
+ x: torch.Tensor,
48
+ lora_a_stacked: Tuple[torch.Tensor, ...],
49
+ scale: float,
50
+ **kwargs,
51
+ ) -> None:
52
+ """
53
+ Performs GEMM for multiple slices of lora_a.
54
+ """
55
+
56
+ raise NotImplementedError
57
+
58
+ @abstractmethod
59
+ def add_expand(
60
+ self,
61
+ y: torch.Tensor,
62
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
63
+ lora_b_stacked: Tuple[torch.Tensor, ...],
64
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
65
+ output_slices: Tuple[int, ...],
66
+ offset_start: int = 0,
67
+ add_inputs=True,
68
+ **kwargs,
69
+ ) -> None:
70
+ """
71
+ Performs GEMM and bias addition for multiple slices of lora_b.
72
+ """
73
+ raise NotImplementedError
74
+
75
+ @abstractmethod
76
+ def add_lora_embedding(
77
+ self,
78
+ y: torch.Tensor,
79
+ x: torch.Tensor,
80
+ lora_b_stacked: torch.Tensor,
81
+ add_inputs: bool = True,
82
+ **kwargs,
83
+ ) -> None:
84
+ """
85
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA,
86
+ and this layer only requires the expand operation.
87
+ """
88
+ raise NotImplementedError
89
+
90
+ @abstractmethod
91
+ def add_lora_linear(self,
92
+ y: torch.Tensor,
93
+ x: torch.Tensor,
94
+ lora_a_stacked: Tuple[torch.Tensor, ...],
95
+ lora_b_stacked: Tuple[torch.Tensor, ...],
96
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
97
+ scale: float,
98
+ output_slices: Tuple[int, ...],
99
+ *,
100
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
101
+ **kwargs) -> None:
102
+ """
103
+ Applicable to linear-related lora.
104
+ """
105
+
106
+ raise NotImplementedError
107
+
108
+ @abstractmethod
109
+ def add_lora_logits(self,
110
+ y: torch.Tensor,
111
+ x: torch.Tensor,
112
+ lora_a_stacked: torch.Tensor,
113
+ lora_b_stacked: torch.Tensor,
114
+ scale,
115
+ *,
116
+ buffer: Optional[torch.Tensor] = None,
117
+ **kwargs) -> None:
118
+ """
119
+ Applies lora specifically for LogitsProcessorWithLoRA.
120
+ """
121
+ raise NotImplementedError
122
+
123
+
124
+ class PunicaWrapperBase(PunicaWrapperABC):
125
+ """
126
+ PunicaWrapperBase is designed to manage and provide metadata for the punica
127
+ kernel. The main function is to maintain the state information for
128
+ Multi-LoRA, and to provide the interface for the punica.
129
+ """
130
+
131
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
132
+ device: Union[torch.device, str], **kwargs):
133
+ self._token_lora_indices = torch.empty(max_num_batched_tokens,
134
+ dtype=torch.long,
135
+ device=device)
136
+ self._sampler_indices = torch.empty(max_num_batched_tokens,
137
+ dtype=torch.long,
138
+ device=device)
139
+ self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
140
+ dtype=torch.long,
141
+ device=device)
142
+ self._embeddings_indices = torch.empty(2,
143
+ max_num_batched_tokens,
144
+ dtype=torch.long,
145
+ device=device)
146
+ self._long_lora_indices = torch.empty(max_num_batched_tokens,
147
+ dtype=torch.long,
148
+ device=device)
149
+
150
+ # 5 is the number of indicies tensors.
151
+ # base_indices, sampler_indices, sampler_indices_padded,
152
+ # embeddings_indices,long_lora_indices
153
+ self.indices_len: List[Optional[int]] = [None] * 5
154
+ # these attributes are the information required for sgmv kernel
155
+ self._seq_start_locs = torch.empty(max_batches,
156
+ dtype=torch.long,
157
+ device=device)
158
+ self._seq_lengths = torch.empty(max_batches,
159
+ dtype=torch.long,
160
+ device=device)
161
+ self._lora_indices_per_batch = torch.empty(max_batches,
162
+ dtype=torch.long,
163
+ device=device)
164
+ self.device: torch.device = device
165
+ self.max_length: int = 0
166
+ self.token_nums: int = 0
167
+ self.batch_size: int = -1
168
+ self.is_prefill = False
169
+ self.no_lora = False
170
+
171
+ def _update_base_metadata(
172
+ self,
173
+ mapping: "LoRAMapping",
174
+ lora_index_to_id: List[Optional[int]],
175
+ max_loras: int,
176
+ vocab_size: int,
177
+ extra_vocab_size: int,
178
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
179
+ ):
180
+ (
181
+ base_indices,
182
+ sampler_indices,
183
+ sampler_indices_padded,
184
+ embeddings_indices,
185
+ long_lora_offsets_tensor,
186
+ indices_len,
187
+ ) = convert_mapping(
188
+ mapping,
189
+ lora_index_to_id,
190
+ max_loras,
191
+ vocab_size,
192
+ extra_vocab_size,
193
+ self.device,
194
+ long_lora_context,
195
+ )
196
+ self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
197
+ self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
198
+ self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
199
+ sampler_indices_padded)
200
+ self._embeddings_indices[:embeddings_indices.
201
+ shape[0], :embeddings_indices.shape[1]].copy_(
202
+ embeddings_indices)
203
+ if long_lora_offsets_tensor is not None:
204
+ self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
205
+ long_lora_offsets_tensor)
206
+ else:
207
+ self._long_lora_indices.zero_()
208
+ self.indices_len[:] = indices_len
209
+
210
+ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
211
+
212
+ (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
213
+ batch_size, max_length, token_nums,
214
+ no_lora) = compute_meta(token_lora_tensor)
215
+
216
+ self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
217
+ b_seq_start_tensor)
218
+ self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
219
+ self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
220
+ lora_indices_tensor)
221
+ self.batch_size = batch_size
222
+ self.max_length = max_length
223
+ self.token_nums = token_nums
224
+ self.no_lora = no_lora
225
+
226
+ def _apply_bias(
227
+ self,
228
+ indices: torch.Tensor,
229
+ output: torch.Tensor,
230
+ output_slices: Tuple[int, ...],
231
+ lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
232
+ ):
233
+ """Applies bias to output
234
+
235
+ Input shapes:
236
+ lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
237
+ indices: (batch_size)
238
+ output: (batch_size, q_slice_size + 2*kv_slice_size)
239
+ output_slices: n-1 element tuple of (slice_size...),
240
+ where n is number of slices
241
+ """
242
+ org_output = output
243
+ output = output.view(-1, output.shape[-1])
244
+ indices = indices.view(-1)
245
+
246
+ offset_left = 0
247
+ for slice_idx, slice in enumerate(output_slices):
248
+ bias = lora_bias_stacked[slice_idx]
249
+ if bias is not None:
250
+ bias = bias.view(-1, bias.shape[-1])
251
+ bias = bias[indices]
252
+ bias[indices == -1] = 0
253
+ output[:, offset_left:offset_left + slice] += bias
254
+ offset_left += slice
255
+
256
+ return output.view_as(org_output)
257
+
258
+ @property
259
+ def prefill_metadata(
260
+ self
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
262
+ """
263
+ This property provides a convenient way to access the necessary
264
+ metadata for prefill-related kernel computations.
265
+ 1. seq_start_locs: Tensor of sequence start positions.
266
+ 2. seq_lengths: Tensor of sequence lengths.
267
+ 3. lora_indices_per_batch: Tensor of lora indices, and an index of
268
+ -1 means no lora should be applied.
269
+ 4. batch_size: Batch size after clustering identical lora indices.
270
+ 5. max_length: The maximum sequence length in the batch.
271
+ 6. token_nums: The token numbers in the batch.
272
+ """
273
+ return (self._seq_start_locs[:self.batch_size],
274
+ self._seq_lengths[:self.batch_size],
275
+ self._lora_indices_per_batch[:self.batch_size],
276
+ self.batch_size, self.max_length, self.token_nums)
277
+
278
+ @property
279
+ def token_lora_indices(self) -> torch.Tensor:
280
+ """
281
+ This property provides the lora indices corresponding to each token
282
+ in the batch. An index of -1 means no lora should be applied.
283
+ """
284
+ token_lora_len = self.indices_len[0]
285
+ return self._token_lora_indices[:token_lora_len]
286
+
287
+ @property
288
+ def sampler_indices(self) -> torch.Tensor:
289
+ """
290
+ This property is used to access the lora indices specifically for
291
+ LogitsProcessorWithLoRA.
292
+ """
293
+ sampler_indices_len = self.indices_len[1]
294
+ return self._sampler_indices[:sampler_indices_len]
295
+
296
+ @property
297
+ def sampler_indices_padded(self) -> torch.Tensor:
298
+ """
299
+ This property provides access to padded sampler indices.
300
+ """
301
+ indices_padded_len = self.indices_len[2]
302
+ return self._sampler_indices_padded[:indices_padded_len]
303
+
304
+ @property
305
+ def embeddings_indices(self) -> torch.Tensor:
306
+ """
307
+ This property provides access to the indices used for lora embeddings,
308
+ specifically for VocabParallelEmbeddingWithLoRA.
309
+ """
310
+ embeddings_indices_len = self.indices_len[3]
311
+ return self._embeddings_indices[:, :embeddings_indices_len]
312
+
313
+ @property
314
+ def long_lora_indices(self) -> torch.Tensor:
315
+ """
316
+ This property provides access to the indices used for long context
317
+ lora, specifically for LinearScalingRotaryEmbeddingWithLora.
318
+ """
319
+ long_lora_len = self.indices_len[4]
320
+ return self._long_lora_indices[:long_lora_len]
321
+
322
+ def update_metadata(
323
+ self,
324
+ mapping: "LoRAMapping",
325
+ lora_index_to_id: List[Optional[int]],
326
+ max_loras: int,
327
+ vocab_size: int,
328
+ extra_vocab_size: int,
329
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
330
+ **kwargs):
331
+
332
+ self._update_base_metadata(mapping, lora_index_to_id, max_loras,
333
+ vocab_size, extra_vocab_size,
334
+ long_lora_context)
335
+ if mapping.is_prefill:
336
+ # Update metadata required for prefill-related operators.
337
+ self._update_prefill_metada(self.token_lora_indices)
338
+ self.is_prefill = True
339
+ else:
340
+ self.is_prefill = False
341
+
342
+ @abstractmethod
343
+ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
344
+ x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
345
+ scale: float, **kwargs) -> None:
346
+ """
347
+ Performs GEMM for multiple slices of lora_a.
348
+
349
+ Semantics:
350
+ for i in range(len(lora_a_stacked)):
351
+ y[i] += (x @ lora_a_stacked[i]) * scale
352
+
353
+ Args:
354
+ y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
355
+ x (torch.Tensor): Input tensor
356
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
357
+ scale (float): Scaling factor for the operation
358
+
359
+ """
360
+ # TODO: implement it based on torch ops
361
+ raise NotImplementedError
362
+
363
+ @abstractmethod
364
+ def add_expand(self,
365
+ y: torch.Tensor,
366
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
367
+ lora_b_stacked: Tuple[torch.Tensor, ...],
368
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
369
+ output_slices: Tuple[int, ...],
370
+ offset_start: int = 0,
371
+ add_inputs=True,
372
+ **kwargs) -> None:
373
+ """
374
+ Performs GEMM and bias addition for multiple slices of lora_b.
375
+
376
+ Semantics:
377
+ offset = offset_start
378
+ for i in range(len(lora_b_stacked)):
379
+ slice = output_slices[i]
380
+ y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
381
+ lora_bias_stacked[i]
382
+ offset += slice
383
+
384
+ Args:
385
+ y (torch.Tensor): Output tensor.
386
+ x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
387
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
388
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
389
+ bias's weight
390
+ output_slices (Tuple[int, ...]): Every slice's size
391
+ offset_start (int): The starting position of y, defaults to 0
392
+ add_inputs (bool): Defaults to True.
393
+
394
+ """
395
+ # TODO: implement it based on torch ops
396
+ raise NotImplementedError
397
+
398
+ @abstractmethod
399
+ def add_lora_embedding(self,
400
+ y: torch.Tensor,
401
+ x: torch.Tensor,
402
+ lora_b_stacked: torch.Tensor,
403
+ add_inputs: bool = True,
404
+ **kwargs) -> None:
405
+ """
406
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA.
407
+ and this layer only requires the expand operation.
408
+ Semantics:
409
+ y += x @ lora_b_stacked
410
+
411
+ Args:
412
+ y (torch.Tensor): Output tensor.
413
+ x (torch.Tensor): Input tensor.
414
+ lora_b_stacked (torch.Tensor): lora_b's weights.
415
+ add_inputs (bool): Default to True.
416
+ """
417
+ # TODO: implement it based on torch ops
418
+ raise NotImplementedError
419
+
420
+ @abstractmethod
421
+ def add_lora_linear(self,
422
+ y: torch.Tensor,
423
+ x: torch.Tensor,
424
+ lora_a_stacked: Tuple[torch.Tensor, ...],
425
+ lora_b_stacked: Tuple[torch.Tensor, ...],
426
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
427
+ scale: float,
428
+ output_slices: Tuple[int, ...],
429
+ *,
430
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
431
+ **kwargs) -> None:
432
+ """
433
+ Applicable to linear-related lora.
434
+
435
+ Semantics:
436
+ for i in range(len(lora_a_stacked)):
437
+ y[i] += (
438
+ x[i].unsqueeze(0)
439
+ @ lora_a_stacked[indices[i], layer_idx, :, :]
440
+ @ lora_b_stacked[indices[i], layer_idx, :, :]
441
+ * scale
442
+ ).squeeze(0)+lora_bias_stacked[i]
443
+
444
+ Args:
445
+ y (torch.Tensor): Output tensor. Will be changed in-place.
446
+ x (torch.Tensor): Input tensor
447
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
448
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
449
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
450
+ scale (float): Scaling factor.
451
+ output_slices (Tuple[int, ...]): Every slice's size.
452
+ buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
453
+ """
454
+ # TODO: implement it based on torch ops
455
+ raise NotImplementedError
456
+
457
+ @abstractmethod
458
+ def add_lora_logits(self,
459
+ y: torch.Tensor,
460
+ x: torch.Tensor,
461
+ lora_a_stacked: torch.Tensor,
462
+ lora_b_stacked: torch.Tensor,
463
+ scale,
464
+ *,
465
+ buffer: Optional[torch.Tensor] = None,
466
+ **kwargs) -> None:
467
+ """
468
+ Applies lora specifically for LogitsProcessorWithLoRA.
469
+
470
+ Semantics:
471
+ buffer = (x @ lora_a_stacked) * scale
472
+ y += buffer @ lora_b_stacked
473
+
474
+ Args:
475
+ y (torch.Tensor): Output tensor.
476
+ x (torch.Tensor): Input tensor.
477
+ lora_a_stacked (torch.Tensor): lora_a's weights.
478
+ lora_b_stacked (torch.Tensor):lora_b's weights.
479
+ scale (float): Scaling factor.
480
+ buffer (Optional[torch.Tensor]):Default to None.
481
+ """
482
+ # TODO: implement it based on torch ops
483
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Callable, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
8
+ bgmv_shrink, sgmv_expand,
9
+ sgmv_expand_slice, sgmv_shrink)
10
+
11
+ from .punica_base import PunicaWrapperBase
12
+
13
+
14
+ # The platforms that are compatible with the PyTorch-native implementation can
15
+ # inherit this class
16
+ class PunicaWrapperCPU(PunicaWrapperBase):
17
+ """
18
+ PunicaWrapperCPU is designed to manage and provide metadata for the punica
19
+ kernel. The main function is to maintain the state information for
20
+ Multi-LoRA, and to provide the interface for the pytorch punica ops.
21
+ """
22
+
23
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
24
+ device: Union[torch.device, str], **kwargs):
25
+ PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
26
+ device)
27
+
28
+ def _shrink_prefill(
29
+ self,
30
+ y: torch.Tensor,
31
+ x: torch.Tensor,
32
+ w_t_all: torch.Tensor,
33
+ scale: float,
34
+ ):
35
+ #No LoRA request, so return directly
36
+ if self.no_lora:
37
+ return
38
+ sgmv_shrink(
39
+ x,
40
+ w_t_all,
41
+ y,
42
+ *self.prefill_metadata,
43
+ scale,
44
+ )
45
+
46
+ def _shrink_decode(
47
+ self,
48
+ y: torch.Tensor,
49
+ x: torch.Tensor,
50
+ w_t_all: torch.Tensor,
51
+ scale: float,
52
+ ):
53
+ bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
54
+
55
+ def _expand_prefill(
56
+ self,
57
+ y: torch.Tensor,
58
+ x: torch.Tensor,
59
+ w_t_all: torch.Tensor,
60
+ add_inputs: bool,
61
+ ):
62
+ #No LoRA request, so return directly
63
+ if self.no_lora:
64
+ return
65
+ sgmv_expand(
66
+ x,
67
+ w_t_all,
68
+ y,
69
+ *self.prefill_metadata,
70
+ add_inputs,
71
+ )
72
+
73
+ def _expand_decode(
74
+ self,
75
+ y: torch.Tensor,
76
+ x: torch.Tensor,
77
+ w_t_all: torch.Tensor,
78
+ add_inputs: bool,
79
+ ):
80
+ bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
81
+
82
+ def _expand_slice_prefill(
83
+ self,
84
+ y: torch.Tensor,
85
+ x: torch.Tensor,
86
+ w_t_all: torch.Tensor,
87
+ y_offset: int,
88
+ y_slice_size: int,
89
+ add_inputs: bool,
90
+ ):
91
+ #No LoRA request, so return directly
92
+ if self.no_lora:
93
+ return
94
+ sgmv_expand_slice(
95
+ x,
96
+ w_t_all,
97
+ y,
98
+ *self.prefill_metadata,
99
+ y_offset,
100
+ y_slice_size,
101
+ add_inputs,
102
+ )
103
+
104
+ def _expand_slice_decode(
105
+ self,
106
+ y: torch.Tensor,
107
+ x: torch.Tensor,
108
+ w_t_all: torch.Tensor,
109
+ y_offset: int,
110
+ y_slice_size: int,
111
+ add_inputs: bool,
112
+ ):
113
+ bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
114
+ y_slice_size, add_inputs)
115
+
116
+ def _apply_expand(
117
+ self,
118
+ y: torch.Tensor,
119
+ x: torch.Tensor,
120
+ w_t_all: torch.Tensor,
121
+ y_offset: int,
122
+ y_slice_size: int,
123
+ add_inputs: bool = True,
124
+ ):
125
+ """
126
+ Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
127
+ computation, which is suitable for the
128
+ GEMM of lora'b.
129
+ """
130
+
131
+ expand_slice_fun: Callable = (self._expand_slice_prefill
132
+ if self.is_prefill else
133
+ self._expand_slice_decode)
134
+ expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
135
+
136
+ def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
137
+ w_t_all: torch.Tensor, scale: float):
138
+ """
139
+ Perform the ` y+=x@w_t_all` computation, which is suitable for the
140
+ GEMM of lora'a.
141
+ When `is_prefill is` true, it indicates that it is currently the
142
+ prefill stage, and the `_shrink_prefill` function should be called.
143
+ Otherwise, it is the decode stage, and the _shrink_decode function
144
+ should be called.
145
+ """
146
+ y_org = y
147
+ y = y.view(-1, y.shape[-1])
148
+ shrink_fun: Callable = (self._shrink_prefill
149
+ if self.is_prefill else self._shrink_decode)
150
+ shrink_fun(y, x, w_t_all, scale)
151
+ y = y.view_as(y_org)
152
+
153
+ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
154
+ x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
155
+ scale: float, **kwargs):
156
+ """
157
+ Performs GEMM for multiple slices of lora_a.
158
+ When `is_prefill is` true, it indicates that it is currently the
159
+ prefill stage, and the `_shrink_prefill` function should be called.
160
+ Otherwise, it is the decode stage, and the _shrink_decode function
161
+ should be called.
162
+
163
+ Semantics:
164
+ for i in range(len(lora_a_stacked)):
165
+ y[i] += (x @ lora_a_stacked[i]) * scale
166
+
167
+ Args:
168
+ y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
169
+ x (torch.Tensor): Input tensor
170
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
171
+ scale (float): Scaling factor for the operation
172
+ """
173
+
174
+ x = x.view(-1, x.shape[-1])
175
+ # TODO fuse these kernels
176
+ for slice_idx in range(len(lora_a_stacked)):
177
+ self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
178
+ scale)
179
+
180
+ def add_expand(self,
181
+ y: torch.Tensor,
182
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
183
+ lora_b_stacked: Tuple[torch.Tensor, ...],
184
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
185
+ output_slices: Tuple[int, ...],
186
+ offset_start: int = 0,
187
+ add_inputs=True,
188
+ **kwargs) -> None:
189
+ """
190
+ Performs GEMM and bias addition for multiple slices of lora_b.
191
+
192
+ Semantics:
193
+ for i in range(len(lora_b_stacked)):
194
+ slice = output_slices[i]
195
+ y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
196
+ lora_bias_stacked[i]
197
+ offset += slice
198
+
199
+ Args:
200
+ y (torch.Tensor): Output tensor.
201
+ x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
202
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
203
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
204
+ bias's weight
205
+ output_slices (Tuple[int, ...]): Every slice's size
206
+ add_inputs (bool): Defaults to True.
207
+ """
208
+ y_org = y
209
+ y = y.view(-1, y.shape[-1])
210
+ offset_left = offset_start
211
+ if lora_bias_stacked is not None:
212
+ self._apply_bias(self.token_lora_indices, y, output_slices,
213
+ lora_bias_stacked)
214
+ for slice_idx in range(len(lora_b_stacked)):
215
+ self._apply_expand(
216
+ y,
217
+ x[slice_idx],
218
+ lora_b_stacked[slice_idx],
219
+ offset_left,
220
+ output_slices[slice_idx],
221
+ add_inputs=add_inputs,
222
+ )
223
+ offset_left += output_slices[slice_idx]
224
+ y = y.view_as(y_org)
225
+
226
+ def add_lora_embedding(self,
227
+ y: torch.Tensor,
228
+ x: torch.Tensor,
229
+ lora_b_stacked: torch.Tensor,
230
+ add_inputs: bool = True,
231
+ **kwargs) -> None:
232
+ """
233
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA.
234
+
235
+ Semantics:
236
+ y += x @ lora_b_stacked
237
+
238
+ Args:
239
+ y (torch.Tensor): Output tensor.
240
+ x (torch.Tensor): Input tensor.
241
+ lora_b_stacked (torch.Tensor): lora_b's weights.
242
+ add_inputs (bool): Default to True.
243
+ """
244
+
245
+ # Embedding layer only need expand op
246
+ expand_fun: Callable = (self._expand_prefill
247
+ if self.is_prefill else self._expand_decode)
248
+ expand_fun(y, x, lora_b_stacked, add_inputs)
249
+
250
+ def add_lora_linear(self,
251
+ y: torch.Tensor,
252
+ x: torch.Tensor,
253
+ lora_a_stacked: Tuple[torch.Tensor, ...],
254
+ lora_b_stacked: Tuple[torch.Tensor, ...],
255
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
256
+ scale: float,
257
+ output_slices: Tuple[int, ...],
258
+ *,
259
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
260
+ **kwargs) -> None:
261
+ """
262
+ Applicable to linear-related lora.
263
+
264
+ Semantics:
265
+ for i in range(len(lora_a_stacked)):
266
+ y[i] += (
267
+ x[i].unsqueeze(0)
268
+ @ lora_a_stacked[indices[i], layer_idx, :, :]
269
+ @ lora_b_stacked[indices[i], layer_idx, :, :]
270
+ * scale
271
+ ).squeeze(0)+lora_bias_stacked[i]
272
+
273
+ Args:
274
+ y (torch.Tensor): Output tensor. Will be changed in-place.
275
+ x (torch.Tensor): Input tensor
276
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
277
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
278
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
279
+ scale (float): Scaling factor.
280
+ output_slices (Tuple[int, ...]): Every slice's size.
281
+ buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
282
+ """
283
+
284
+ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
285
+ if lora_bias_stacked is not None:
286
+ assert len(lora_bias_stacked) == len(output_slices)
287
+ y = self._apply_bias(self.token_lora_indices, y, output_slices,
288
+ lora_bias_stacked)
289
+
290
+ if buffer is None:
291
+ r = lora_b_stacked[0].size(-1)
292
+ # We set the buffer to be float32 by default, consistent with the
293
+ # triton op
294
+ buffer = tuple(
295
+ torch.zeros(
296
+ (x.size(0), r), dtype=torch.float32, device=x.device)
297
+ for _ in range(len(output_slices)))
298
+ self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
299
+ self.add_expand(y,
300
+ buffer,
301
+ lora_b_stacked,
302
+ None,
303
+ output_slices,
304
+ add_inputs=True,
305
+ **kwargs)
306
+
307
+ def add_lora_logits(self,
308
+ y: torch.Tensor,
309
+ x: torch.Tensor,
310
+ lora_a_stacked: torch.Tensor,
311
+ lora_b_stacked: torch.Tensor,
312
+ scale,
313
+ *,
314
+ buffer: Optional[torch.Tensor] = None,
315
+ **kwargs) -> None:
316
+ """
317
+ Applies lora specifically for LogitsProcessorWithLoRA.
318
+
319
+ Semantics:
320
+ buffer = (x @ lora_a_stacked) * scale
321
+ y += buffer @ lora_b_stacked
322
+
323
+ Args:
324
+ y (torch.Tensor): Output tensor.
325
+ x (torch.Tensor): Input tensor.
326
+ lora_a_stacked (torch.Tensor): lora_a's weights.
327
+ lora_b_stacked (torch.Tensor):lora_b's weights.
328
+ scale (float): Scaling factor.
329
+ buffer (Optional[torch.Tensor]):Default to None.
330
+ """
331
+ y_org = y
332
+ y = y.view(-1, y.shape[-1])
333
+ x = x.view(-1, x.shape[-1])
334
+ r = lora_b_stacked.size(-1)
335
+ if buffer is None:
336
+ # We set the buffer to be float32 by default, consistent with the
337
+ # triton op
338
+ buffer = torch.zeros((x.size(0), r),
339
+ dtype=torch.float32,
340
+ device=x.device)
341
+ # LogitsProcessorWithLoRA always using bgmv.
342
+ bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
343
+ bgmv_expand(buffer,
344
+ lora_b_stacked,
345
+ y,
346
+ self.sampler_indices,
347
+ add_inputs=True)
348
+ y = y.view_as(y_org)
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Optional, Tuple, Union, final
4
+
5
+ import torch
6
+ from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
7
+ dispatch_bgmv_linear)
8
+
9
+ from .punica_base import PunicaWrapperBase
10
+
11
+
12
+ @final
13
+ class PunicaWrapperHPU(PunicaWrapperBase):
14
+
15
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
16
+ device: Union[torch.device, str], **kwargs):
17
+ # Increasing max_num_batched_tokens by 3x to handle increase in
18
+ # tensor size due to padding.
19
+ PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
20
+ max_batches, device)
21
+
22
+ def add_lora_embedding(self,
23
+ y: torch.Tensor,
24
+ x: torch.Tensor,
25
+ lora_b_stacked: torch.Tensor,
26
+ add_inputs: bool = True,
27
+ **kwargs) -> None:
28
+ dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
29
+
30
+ def add_lora_linear(self,
31
+ y: torch.Tensor,
32
+ x: torch.Tensor,
33
+ lora_a_stacked: Tuple[torch.Tensor, ...],
34
+ lora_b_stacked: Tuple[torch.Tensor, ...],
35
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
36
+ scale: float,
37
+ output_slices: Tuple[int, ...],
38
+ *,
39
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
40
+ **kwargs) -> None:
41
+ y_org = y
42
+ x = x.view(-1, x.shape[-1])
43
+ y = y.view(-1, y.shape[-1])
44
+ offset_left = 0
45
+
46
+ for slice_idx in range(len(output_slices)):
47
+ dispatch_bgmv_linear(
48
+ y[:, offset_left:offset_left + output_slices[slice_idx]], x,
49
+ lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
50
+ offset_left += output_slices[slice_idx]
51
+ y = y.view_as(y_org)
52
+
53
+ def add_lora_logits(self,
54
+ y: torch.Tensor,
55
+ x: torch.Tensor,
56
+ lora_a_stacked: torch.Tensor,
57
+ lora_b_stacked: torch.Tensor,
58
+ scale,
59
+ *,
60
+ buffer: Optional[torch.Tensor] = None,
61
+ **kwargs) -> None:
62
+ y_org = y
63
+ y = y.view(-1, y.shape[-1])
64
+ x = x.view(-1, x.shape[-1])
65
+ dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
66
+ y = y.view_as(y_org)
67
+
68
+ def add_shrink(
69
+ self,
70
+ y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
71
+ x: torch.Tensor,
72
+ lora_a_stacked: Tuple[torch.Tensor, ...],
73
+ scale: float,
74
+ **kwargs,
75
+ ) -> None:
76
+ raise NotImplementedError
77
+
78
+ def add_expand(
79
+ self,
80
+ y: torch.Tensor,
81
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
82
+ lora_b_stacked: Tuple[torch.Tensor, ...],
83
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
84
+ output_slices: Tuple[int, ...],
85
+ offset_start: int = 0,
86
+ add_inputs=True,
87
+ **kwargs,
88
+ ) -> None:
89
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.logger import init_logger
4
+ from vllm.platforms import current_platform
5
+ from vllm.utils import resolve_obj_by_qualname
6
+
7
+ from .punica_base import PunicaWrapperBase
8
+
9
+ logger = init_logger(__name__)
10
+
11
+
12
+ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
13
+ punica_wrapper_qualname = current_platform.get_punica_wrapper()
14
+ punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
15
+ punica_wrapper = punica_wrapper_cls(*args, **kwargs)
16
+ assert punica_wrapper is not None, \
17
+ "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
18
+ logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
19
+ ".")
20
+ return punica_wrapper
.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ if TYPE_CHECKING:
8
+ # avoid circuit import
9
+ from vllm.lora.layers import LoRAMapping
10
+ from vllm.lora.models import LongContextLoRAContext
11
+
12
+
13
+ def compute_meta(
14
+ token_lora_tensor: torch.Tensor
15
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
16
+ """
17
+ Get the information required for the sgmv kernel. With the features:
18
+ 1. If consecutive requests in the batch use the same LoRA, this function
19
+ will combine them into a single request, improving sgmv kernel inference
20
+ performance.
21
+ 2. At the beginning of each prefill stage inference, recalculations are
22
+ needed based on the input, but only once.
23
+ """
24
+
25
+ lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
26
+ token_lora_tensor, return_counts=True)
27
+ cum_result = torch.cumsum(seq_length_tensor, dim=0)
28
+ b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
29
+ b_seq_start_tensor[1:].copy_(cum_result[:-1])
30
+ max_length = seq_length_tensor.max().item()
31
+ token_nums = seq_length_tensor.sum().item()
32
+ batch_size = lora_indices_tensor.size(0)
33
+ no_lora = False
34
+ # -1 means no lora should be applied. Use `no_lora` to determine whether
35
+ # the current step requires LoRA. If LoRA is not needed, the prefill stage
36
+ # does not need to launch the triton kernel, which can improve performance
37
+ if batch_size == 1 and lora_indices_tensor == -1:
38
+ no_lora = True
39
+ return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
40
+ batch_size, max_length, token_nums, no_lora)
41
+
42
+
43
+ # TODO see if this can be vectorized
44
+ def convert_mapping(
45
+ mapping: "LoRAMapping",
46
+ lora_index_to_id: List[Optional[int]],
47
+ max_loras: int,
48
+ vocab_size: int,
49
+ extra_vocab_size: int,
50
+ device: torch.device,
51
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
52
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
53
+ Optional[torch.Tensor], List[int]]:
54
+ """Converts LoRAMapping to index tensors.
55
+
56
+ Args:
57
+ mapping: LoRAMapping mapping rows in a batch to LoRA ids.
58
+ lora_index_to_id: List mapping LoRA ids to LoRA indices.
59
+ max_loras: Maximum number of LoRAs.
60
+ vocab_size: Model vocab size.
61
+ extra_vocab_size: Extra vocab size each LoRA can have.
62
+ long_lora_context: Passed if there are long context lora in a batch.
63
+
64
+ Returns:
65
+ A tuple of tensors:
66
+ base_indices: Tensor of shape [batch_size] mapping batch rows to
67
+ LoRA indices.
68
+ sampler_indices: Tensor of shape [batch_size] mapping requests to
69
+ LoRA indices for sampler. For generation, this will be the
70
+ same as base_indicies. For prefill, this will map requests
71
+ to LoRA indices.
72
+ sampler_indices_padded: Tensor of shape [batch_size] mapping
73
+ requests to LoRA indices for sampler with padding.
74
+ Same as sampler_indicies, but -1 is replaced with
75
+ max_loras.
76
+ embeddings_indices: Tensor of shape [2, batch_size] mapping
77
+ requests to embedding indices. First row is for embeddings
78
+ added by the LoRAs, second row is for the LoRA.lora_a
79
+ embeddings.
80
+ long_lora_indices: Tensor of shape [batch_size] mapping
81
+ requests to RoPE offsets and rot dims for long LoRAs.
82
+ None if long context lora doesn't exist.
83
+ indices_len: List of lengths of the above tensors. It contains
84
+ (base_indices, sampler_indices, sampler_indices_padded,
85
+ embeddings_indices, long_lora_indices).
86
+ """
87
+ index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
88
+ embedding_indices = index_mapping_indices.copy()
89
+ lora_indices = index_mapping_indices.copy()
90
+ long_lora_offsets: Optional[torch.Tensor] = None
91
+ if long_lora_context:
92
+ long_lora_offsets = torch.zeros(len(index_mapping_indices),
93
+ device=device,
94
+ dtype=torch.long)
95
+ prompt_mapping: List[int] = [
96
+ lora_index_to_id.index(x) if x > 0 else -1
97
+ for x in mapping.prompt_mapping
98
+ ]
99
+ lora_idx = None
100
+ for i in range(len(index_mapping_indices)):
101
+ # TODO index can be slow. optimize
102
+ lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
103
+ if index_mapping_indices[i] > 0 else -1)
104
+ embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
105
+ lora_indices[i] = lora_idx
106
+ if long_lora_context:
107
+ assert long_lora_offsets is not None
108
+ lora_offset: int = long_lora_context.offsets_by_lora_id.get(
109
+ index_mapping_indices[i], 0)
110
+ long_lora_offsets[i] = lora_offset
111
+
112
+ indices_list: List[Union[List[int], torch.Tensor]] = [
113
+ index_mapping_indices,
114
+ lora_indices,
115
+ embedding_indices,
116
+ ]
117
+ if long_lora_context:
118
+ assert long_lora_offsets is not None
119
+ indices_list.append(long_lora_offsets)
120
+ indices = torch.tensor(indices_list, dtype=torch.long, device=device)
121
+ prompt_mapping_tensor = torch.tensor(prompt_mapping,
122
+ dtype=torch.long,
123
+ device=device)
124
+ embeddings_indices = torch.stack([
125
+ indices[2] * extra_vocab_size,
126
+ indices[2] * (vocab_size + extra_vocab_size),
127
+ ])
128
+ embeddings_indices[embeddings_indices == -1] = max_loras - 1
129
+ base_indices = indices[1]
130
+ sampler_indices = prompt_mapping_tensor
131
+ sampler_indices_padded = sampler_indices.clone()
132
+ sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
133
+ sampler_indices_padded = torch.arange(
134
+ 0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
135
+ sampler_indices_padded * len(sampler_indices_padded))
136
+ long_lora_indices = None
137
+ long_lora_indices_len: Optional[int] = None
138
+ if long_lora_context:
139
+ long_lora_indices = indices[3]
140
+ long_lora_indices_len = long_lora_indices.shape[-1]
141
+ # Contain length of indices tensors. Used to index into each tensor.
142
+ indices_len = [
143
+ base_indices.shape[-1],
144
+ sampler_indices.shape[-1],
145
+ sampler_indices_padded.shape[-1],
146
+ embeddings_indices.shape[-1],
147
+ ]
148
+ if long_lora_indices_len is not None:
149
+ indices_len.append(long_lora_indices_len)
150
+ else:
151
+ # If long_lora doesn't exist,append None
152
+ indices_len.append(None)
153
+
154
+ return (
155
+ base_indices,
156
+ sampler_indices,
157
+ sampler_indices_padded,
158
+ embeddings_indices,
159
+ long_lora_indices,
160
+ indices_len,
161
+ )
.venv/lib/python3.11/site-packages/vllm/v1/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (180 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc ADDED
Binary file (4.76 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc ADDED
Binary file (1.57 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc ADDED
Binary file (8.36 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc ADDED
Binary file (842 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc ADDED
Binary file (10 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention layer with FlashAttention."""
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Type
5
+
6
+ import numpy as np
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
+ AttentionMetadata, AttentionType)
13
+ from vllm.envs import VLLM_FLASH_ATTN_VERSION
14
+ from vllm.logger import init_logger
15
+ from vllm.platforms import current_platform
16
+ from vllm.utils import cdiv
17
+ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
18
+ flash_attn_varlen_func,
19
+ is_fa_version_supported)
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
+ class FlashAttentionBackend(AttentionBackend):
25
+
26
+ accept_output_buffer: bool = True
27
+
28
+ @staticmethod
29
+ def get_supported_head_sizes() -> List[int]:
30
+ return [32, 64, 96, 128, 160, 192, 224, 256]
31
+
32
+ @staticmethod
33
+ def get_name() -> str:
34
+ return "FLASH_ATTN_VLLM_V1"
35
+
36
+ @staticmethod
37
+ def get_impl_cls() -> Type["FlashAttentionImpl"]:
38
+ return FlashAttentionImpl
39
+
40
+ @staticmethod
41
+ def get_metadata_cls() -> Type["AttentionMetadata"]:
42
+ return FlashAttentionMetadata
43
+
44
+ @staticmethod
45
+ def get_kv_cache_shape(
46
+ num_blocks: int,
47
+ block_size: int,
48
+ num_kv_heads: int,
49
+ head_size: int,
50
+ ) -> Tuple[int, ...]:
51
+ if block_size % 16 != 0:
52
+ raise ValueError("Block size must be a multiple of 16.")
53
+ return (2, num_blocks, block_size, num_kv_heads, head_size)
54
+
55
+ @staticmethod
56
+ def use_cascade_attention(*args, **kwargs) -> bool:
57
+ return use_cascade_attention(*args, **kwargs)
58
+
59
+
60
+ @dataclass
61
+ class FlashAttentionMetadata:
62
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
63
+ # |---------- N-1 iteration --------|
64
+ # |---------------- N iteration ---------------------|
65
+ # |- tokenA -|......................|-- newTokens ---|
66
+ # |---------- context_len ----------|
67
+ # |-------------------- seq_len ---------------------|
68
+ # |-- query_len ---|
69
+
70
+ num_actual_tokens: int # Number of tokens excluding padding.
71
+ max_query_len: int
72
+ query_start_loc: torch.Tensor
73
+ max_seq_len: int
74
+ seq_lens: torch.Tensor
75
+ block_table: torch.Tensor
76
+ slot_mapping: torch.Tensor
77
+
78
+ # For cascade attention.
79
+ use_cascade: bool
80
+ common_prefix_len: int
81
+ cu_prefix_query_lens: Optional[torch.Tensor]
82
+ prefix_kv_lens: Optional[torch.Tensor]
83
+ suffix_kv_lens: Optional[torch.Tensor]
84
+
85
+ # For logging.
86
+ num_input_tokens: int = 0 # Number of tokens including padding.
87
+
88
+
89
+ class FlashAttentionImpl(AttentionImpl):
90
+
91
+ def __init__(
92
+ self,
93
+ num_heads: int,
94
+ head_size: int,
95
+ scale: float,
96
+ num_kv_heads: int,
97
+ alibi_slopes: Optional[List[float]],
98
+ sliding_window: Optional[int],
99
+ kv_cache_dtype: str,
100
+ blocksparse_params: Optional[Dict[str, Any]] = None,
101
+ logits_soft_cap: Optional[float] = None,
102
+ attn_type: AttentionType = AttentionType.DECODER,
103
+ ) -> None:
104
+ if blocksparse_params is not None:
105
+ raise ValueError(
106
+ "FlashAttention does not support block-sparse attention.")
107
+ self.num_heads = num_heads
108
+ self.head_size = head_size
109
+ self.scale = float(scale)
110
+ self.num_kv_heads = num_kv_heads
111
+ if alibi_slopes is not None:
112
+ alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
113
+ self.alibi_slopes = alibi_slopes
114
+ if sliding_window is None:
115
+ self.sliding_window = (-1, -1)
116
+ else:
117
+ self.sliding_window = (sliding_window - 1, 0)
118
+ self.kv_cache_dtype = kv_cache_dtype
119
+ if logits_soft_cap is None:
120
+ # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
121
+ logits_soft_cap = 0
122
+ self.logits_soft_cap = logits_soft_cap
123
+
124
+ assert self.num_heads % self.num_kv_heads == 0
125
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
126
+
127
+ support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
128
+ if head_size not in support_head_sizes:
129
+ raise ValueError(
130
+ f"Head size {head_size} is not supported by FlashAttention. "
131
+ f"Supported head sizes are: {support_head_sizes}.")
132
+
133
+ if attn_type != AttentionType.DECODER:
134
+ raise NotImplementedError("Encoder self-attention and "
135
+ "encoder/decoder cross-attention "
136
+ "are not implemented for "
137
+ "FlashAttentionImpl")
138
+
139
+ # if hopper default to FA3, otherwise stick to FA2 for now
140
+ # TODO(lucas): profile FA3 on ampere to see if it makes sense to
141
+ # use FA3 as default for both
142
+ if current_platform.get_device_capability()[0] >= 9:
143
+ self.fa_version = 3 if is_fa_version_supported(3) else 2
144
+ else:
145
+ self.fa_version = 2
146
+
147
+ if VLLM_FLASH_ATTN_VERSION is not None:
148
+ assert VLLM_FLASH_ATTN_VERSION in [2, 3]
149
+ self.fa_version = VLLM_FLASH_ATTN_VERSION
150
+
151
+ if not is_fa_version_supported(self.fa_version):
152
+ logger.error("Cannot use FA version %d is not supported due to %s",
153
+ self.fa_version,
154
+ fa_version_unsupported_reason(self.fa_version))
155
+
156
+ assert is_fa_version_supported(self.fa_version)
157
+
158
+ def forward(
159
+ self,
160
+ layer: torch.nn.Module,
161
+ query: torch.Tensor,
162
+ key: torch.Tensor,
163
+ value: torch.Tensor,
164
+ kv_cache: torch.Tensor,
165
+ attn_metadata: FlashAttentionMetadata,
166
+ output: Optional[torch.Tensor] = None,
167
+ ) -> torch.Tensor:
168
+ """Forward pass with FlashAttention.
169
+
170
+ Args:
171
+ query: shape = [num_tokens, num_heads, head_size]
172
+ key: shape = [num_tokens, num_kv_heads, head_size]
173
+ value: shape = [num_tokens, num_kv_heads, head_size]
174
+ kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
175
+ attn_metadata: Metadata for attention.
176
+ Returns:
177
+ shape = [num_tokens, num_heads * head_size]
178
+ """
179
+ assert output is not None, "Output tensor must be provided."
180
+
181
+ if attn_metadata is None:
182
+ # Profiling run.
183
+ return output
184
+
185
+ # IMPORTANT!
186
+ # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
187
+ # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
188
+ # in this method. For example, `view` and `slice` (or `[:n]`) operations
189
+ # are surprisingly slow even in the case they do not invoke any GPU ops.
190
+ # Minimize the PyTorch ops in this method as much as possible.
191
+ # Whenever making a change in this method, please benchmark the
192
+ # performance to make sure it does not introduce any overhead.
193
+
194
+ num_actual_tokens = attn_metadata.num_actual_tokens
195
+ # Reshape the input keys and values and store them in the cache.
196
+ # NOTE(woosuk): Here, key and value are padded while slot_mapping is
197
+ # not padded. However, we don't need to do key[:num_actual_tokens] and
198
+ # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
199
+ # the slot_mapping's shape to determine the number of actual tokens.
200
+ key_cache, value_cache = kv_cache.unbind(0)
201
+ torch.ops._C_cache_ops.reshape_and_cache_flash(
202
+ key,
203
+ value,
204
+ key_cache,
205
+ value_cache,
206
+ attn_metadata.slot_mapping,
207
+ self.kv_cache_dtype,
208
+ layer._k_scale,
209
+ layer._v_scale,
210
+ )
211
+
212
+ # Compute attention and update output up to `num_actual_tokens`.
213
+ if not attn_metadata.use_cascade:
214
+ # Regular attention (common case).
215
+ flash_attn_varlen_func(
216
+ q=query[:num_actual_tokens],
217
+ k=key_cache,
218
+ v=value_cache,
219
+ out=output[:num_actual_tokens],
220
+ cu_seqlens_q=attn_metadata.query_start_loc,
221
+ max_seqlen_q=attn_metadata.max_query_len,
222
+ seqused_k=attn_metadata.seq_lens,
223
+ max_seqlen_k=attn_metadata.max_seq_len,
224
+ softmax_scale=self.scale,
225
+ causal=True,
226
+ alibi_slopes=self.alibi_slopes,
227
+ window_size=self.sliding_window,
228
+ block_table=attn_metadata.block_table,
229
+ softcap=self.logits_soft_cap,
230
+ fa_version=self.fa_version,
231
+ )
232
+ return output
233
+
234
+ # Cascade attention (rare case).
235
+ cascade_attention(
236
+ output[:num_actual_tokens],
237
+ query[:num_actual_tokens],
238
+ key_cache,
239
+ value_cache,
240
+ cu_query_lens=attn_metadata.query_start_loc,
241
+ max_query_len=attn_metadata.max_query_len,
242
+ cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
243
+ prefix_kv_lens=attn_metadata.prefix_kv_lens,
244
+ suffix_kv_lens=attn_metadata.suffix_kv_lens,
245
+ max_kv_len=attn_metadata.max_seq_len,
246
+ softmax_scale=self.scale,
247
+ alibi_slopes=self.alibi_slopes,
248
+ sliding_window=self.sliding_window,
249
+ logits_soft_cap=self.logits_soft_cap,
250
+ block_table=attn_metadata.block_table,
251
+ common_prefix_len=attn_metadata.common_prefix_len,
252
+ fa_version=self.fa_version,
253
+ )
254
+ return output
255
+
256
+
257
+ def use_cascade_attention(
258
+ common_prefix_len: int,
259
+ query_lens: np.ndarray,
260
+ num_query_heads: int,
261
+ num_kv_heads: int,
262
+ use_alibi: bool,
263
+ use_sliding_window: bool,
264
+ num_sms: int,
265
+ ) -> bool:
266
+ """Decide whether to use cascade attention.
267
+
268
+ This function 1) checks whether cascade attention is supported with the
269
+ given configuration, and 2) heuristically decides whether using cascade
270
+ attention can improve performance.
271
+ """
272
+ # Too short common prefix. Probably not worth using cascade attention.
273
+ # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
274
+ # NOTE(woosuk): This is the common case. We should return False as soon as
275
+ # possible to avoid any unnecessary computation.
276
+ if common_prefix_len < 256:
277
+ return False
278
+ # Cascade attention is currently not supported with these variants.
279
+ if use_alibi or use_sliding_window:
280
+ return False
281
+ # Too few queries. Probably not worth using cascade attention.
282
+ # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
283
+ num_reqs = len(query_lens)
284
+ if num_reqs < 8:
285
+ return False
286
+
287
+ # Heuristics to decide whether using cascade attention is beneficial.
288
+ # 1. When FlashDecoding is not used for normal attention, cascade attention
289
+ # is likely to be faster since it saves memory bandwidth.
290
+ num_queries_per_kv = num_query_heads // num_kv_heads
291
+ # The criteria for using FlashDecoding can be found in the following link:
292
+ # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
293
+ use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
294
+ and not use_alibi and np.all(query_lens == 1))
295
+ if not use_flash_decoding:
296
+ # Use cascade attention.
297
+ return True
298
+
299
+ # 2. When FlashDecoding is used for normal attention, it is not clear
300
+ # whether cascade attention is beneficial, because FlashDecoding can
301
+ # launch more CTAs than cascade attention.
302
+ # We use a simple performance model to compare the two methods.
303
+ # NOTE(woosuk): The performance model is very rough and may not be
304
+ # accurate.
305
+ num_tokens = num_reqs
306
+ # NOTE(woosuk): These are default tile sizes. flash-attn might use
307
+ # different tile sizes (e.g., 64 or 256) depending on the configuration.
308
+ q_tile_size = 128
309
+ kv_tile_size = 128
310
+ num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)
311
+
312
+ cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
313
+ cascade_waves = cdiv(cascade_ctas, num_sms)
314
+ cascade_time = cascade_waves * num_prefix_tiles
315
+
316
+ flash_decoding_ctas = (num_reqs * num_kv_heads *
317
+ cdiv(num_queries_per_kv, q_tile_size))
318
+ flash_decoding_ctas *= num_prefix_tiles
319
+ flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
320
+
321
+ # Use cascade attention if it is faster than FlashDecoding.
322
+ return cascade_time < flash_decoding_time
323
+
324
+
325
+ def cascade_attention(
326
+ output: torch.Tensor,
327
+ query: torch.Tensor,
328
+ key_cache: torch.Tensor,
329
+ value_cache: torch.Tensor,
330
+ cu_query_lens: torch.Tensor,
331
+ max_query_len: int,
332
+ cu_prefix_query_lens: torch.Tensor,
333
+ prefix_kv_lens: torch.Tensor,
334
+ suffix_kv_lens: torch.Tensor,
335
+ max_kv_len: int,
336
+ softmax_scale: float,
337
+ alibi_slopes: Optional[torch.Tensor],
338
+ sliding_window: Tuple[int, int],
339
+ logits_soft_cap: float,
340
+ block_table: torch.Tensor,
341
+ common_prefix_len: int,
342
+ fa_version: int,
343
+ ) -> torch.Tensor:
344
+ assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
345
+ # TODO: Support sliding window.
346
+ assert sliding_window == (-1, -1), (
347
+ "Cascade attention does not support sliding window.")
348
+
349
+ num_tokens = query.shape[0]
350
+ block_size = key_cache.shape[-3]
351
+ assert common_prefix_len % block_size == 0
352
+ num_common_kv_blocks = common_prefix_len // block_size
353
+ assert num_common_kv_blocks > 0
354
+
355
+ # Process shared prefix.
356
+ prefix_output, prefix_lse = flash_attn_varlen_func(
357
+ q=query,
358
+ k=key_cache,
359
+ v=value_cache,
360
+ cu_seqlens_q=cu_prefix_query_lens,
361
+ seqused_k=prefix_kv_lens,
362
+ max_seqlen_q=num_tokens,
363
+ max_seqlen_k=common_prefix_len,
364
+ softmax_scale=softmax_scale,
365
+ causal=False,
366
+ window_size=sliding_window,
367
+ block_table=block_table[:1],
368
+ softcap=logits_soft_cap,
369
+ return_softmax_lse=True,
370
+ fa_version=fa_version,
371
+ )
372
+
373
+ # Process suffix per query.
374
+ suffix_output, suffix_lse = flash_attn_varlen_func(
375
+ q=query,
376
+ k=key_cache,
377
+ v=value_cache,
378
+ cu_seqlens_q=cu_query_lens,
379
+ seqused_k=suffix_kv_lens,
380
+ max_seqlen_q=max_query_len,
381
+ max_seqlen_k=max_kv_len - common_prefix_len,
382
+ softmax_scale=softmax_scale,
383
+ causal=True,
384
+ window_size=sliding_window,
385
+ block_table=block_table[:, num_common_kv_blocks:],
386
+ softcap=logits_soft_cap,
387
+ return_softmax_lse=True,
388
+ fa_version=fa_version,
389
+ )
390
+
391
+ # Merge prefix and suffix outputs, and store the result in output.
392
+ merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
393
+ suffix_lse)
394
+
395
+
396
+ def merge_attn_states(
397
+ output: torch.Tensor,
398
+ prefix_output: torch.Tensor,
399
+ prefix_lse: torch.Tensor,
400
+ suffix_output: torch.Tensor,
401
+ suffix_lse: torch.Tensor,
402
+ ) -> None:
403
+ num_tokens = output.shape[0]
404
+ num_query_heads = output.shape[1]
405
+ head_size = output.shape[2]
406
+ padded_head_size = triton.next_power_of_2(head_size)
407
+
408
+ # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
409
+ merge_attn_states_kernel[(num_tokens, num_query_heads)](
410
+ output,
411
+ prefix_output,
412
+ prefix_lse,
413
+ suffix_output,
414
+ suffix_lse,
415
+ head_size,
416
+ padded_head_size,
417
+ )
418
+
419
+
420
+ @triton.jit
421
+ def merge_attn_states_kernel(
422
+ output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
423
+ prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
424
+ prefix_lse, # [NUM_HEADS, NUM_TOKENS]
425
+ suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
426
+ suffix_lse, # [NUM_HEADS, NUM_TOKENS]
427
+ HEAD_SIZE: tl.constexpr,
428
+ PADDED_HEAD_SIZE: tl.constexpr,
429
+ ):
430
+ token_idx = tl.program_id(0)
431
+ num_tokens = tl.num_programs(0)
432
+ head_idx = tl.program_id(1)
433
+ num_heads = tl.num_programs(1)
434
+
435
+ p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
436
+ s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
437
+ max_lse = tl.maximum(p_lse, s_lse)
438
+ p_lse = p_lse - max_lse
439
+ s_lse = s_lse - max_lse
440
+
441
+ head_arange = tl.arange(0, PADDED_HEAD_SIZE)
442
+ head_mask = head_arange < HEAD_SIZE
443
+ p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
444
+ head_idx * HEAD_SIZE + head_arange,
445
+ mask=head_mask)
446
+ s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
447
+ head_idx * HEAD_SIZE + head_arange,
448
+ mask=head_mask)
449
+
450
+ # NOTE(woosuk): Be careful with the numerical stability.
451
+ # We should compute the scale first, and then multiply it with the output.
452
+ # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
453
+ p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
454
+ s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
455
+ out = p_out * p_scale + s_out * s_scale
456
+ tl.store(output + token_idx * num_heads * HEAD_SIZE +
457
+ head_idx * HEAD_SIZE + head_arange,
458
+ out,
459
+ mask=head_mask)
.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc ADDED
Binary file (6.78 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc ADDED
Binary file (18.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import TYPE_CHECKING, Dict, List, Set, Tuple
4
+
5
+ from vllm.logger import init_logger
6
+ from vllm.multimodal import MULTIMODAL_REGISTRY
7
+ from vllm.v1.request import Request
8
+
9
+ if TYPE_CHECKING:
10
+ from vllm.config import ModelConfig, SchedulerConfig
11
+
12
+ logger = init_logger(__name__)
13
+
14
+
15
+ class EncoderCacheManager:
16
+
17
+ def __init__(self, cache_size: int):
18
+ self.cache_size = cache_size
19
+ self.num_free_slots = cache_size
20
+ # req_id -> cached input ids
21
+ self.cached: Dict[str, Set[int]] = {}
22
+ # List of [req_id, input_id]
23
+ self.freed: List[Tuple[str, int]] = []
24
+
25
+ def has_cache(self, request: Request, input_id: int) -> bool:
26
+ req_id = request.request_id
27
+ return req_id in self.cached and input_id in self.cached[req_id]
28
+
29
+ def can_allocate(self, request: Request, input_id: int) -> bool:
30
+ num_tokens = request.get_num_encoder_tokens(input_id)
31
+ return num_tokens <= self.num_free_slots
32
+
33
+ def allocate(self, request: Request, input_id: int) -> None:
34
+ req_id = request.request_id
35
+ if req_id not in self.cached:
36
+ self.cached[req_id] = set()
37
+ self.cached[req_id].add(input_id)
38
+ self.num_free_slots -= request.get_num_encoder_tokens(input_id)
39
+
40
+ def get_cached_input_ids(self, request: Request) -> Set[int]:
41
+ return self.cached.get(request.request_id, set())
42
+
43
+ def free_encoder_input(self, request: Request, input_id: int) -> None:
44
+ """Free a single encoder input id for the request."""
45
+ req_id = request.request_id
46
+ if req_id not in self.cached:
47
+ return
48
+
49
+ self.cached[req_id].discard(input_id)
50
+ if len(self.cached[req_id]) == 0:
51
+ del self.cached[req_id]
52
+ self.num_free_slots += request.get_num_encoder_tokens(input_id)
53
+ self.freed.append((req_id, input_id))
54
+
55
+ def free(self, request: Request) -> None:
56
+ """Free all cached input ids for the request."""
57
+ input_ids = self.get_cached_input_ids(request)
58
+ for input_id in input_ids:
59
+ self.free_encoder_input(request, input_id)
60
+
61
+ def get_freed_ids(self) -> List[Tuple[str, int]]:
62
+ freed = self.freed
63
+ self.freed = []
64
+ return freed
65
+
66
+
67
+ def compute_encoder_budget(
68
+ model_config: "ModelConfig",
69
+ scheduler_config: "SchedulerConfig",
70
+ ) -> Tuple[int, int]:
71
+ """Compute the encoder cache budget based on the model and scheduler
72
+ configurations.
73
+
74
+ Args:
75
+ model_config: Model configuration.
76
+ scheduler_config: Scheduler configuration.
77
+
78
+ Returns:
79
+ - Compute budget for encoder execution, in unit of number of tokens
80
+ in the input sequence.
81
+ - Space budget for encoder cache size, in unit of number of tokens
82
+ in the input sequence.
83
+ """
84
+
85
+ if not model_config.is_multimodal_model:
86
+ return 0, 0
87
+
88
+ # TODO: handle encoder-decoder models once we support them.
89
+ (
90
+ encoder_compute_budget,
91
+ encoder_cache_size,
92
+ ) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
93
+
94
+ return encoder_compute_budget, encoder_cache_size
95
+
96
+
97
+ def _compute_encoder_budget_multimodal(
98
+ model_config: "ModelConfig",
99
+ scheduler_config: "SchedulerConfig",
100
+ ) -> Tuple[int, int]:
101
+ """Compute the encoder cache budget based on the model and scheduler
102
+ configurations for a multimodal model.
103
+
104
+ Args:
105
+ model_config: Model configuration.
106
+ scheduler_config: Scheduler configuration.
107
+
108
+ Returns:
109
+ - Compute budget for encoder execution, in unit of number of tokens
110
+ in the input sequence.
111
+ - Space budget for encoder cache size, in unit of number of tokens
112
+ in the input sequence.
113
+ """
114
+
115
+ max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
116
+ model_config)
117
+
118
+ if not max_tokens_by_modality_dict:
119
+ logger.warning(
120
+ "All non-text modalities supported by the model have been "
121
+ "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
122
+ "not be initialized.")
123
+ return 0, 0
124
+
125
+ _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
126
+ key=lambda item: item[1])
127
+
128
+ encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
129
+ max_tokens_per_mm_item)
130
+ encoder_cache_size = max(scheduler_config.encoder_cache_size,
131
+ max_tokens_per_mm_item)
132
+
133
+ return encoder_compute_budget, encoder_cache_size
.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from collections import defaultdict
4
+ from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
5
+
6
+ from vllm.logger import init_logger
7
+ from vllm.utils import cdiv
8
+ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
9
+ KVCacheBlock,
10
+ generate_block_hash_extra_keys,
11
+ hash_block_tokens,
12
+ hash_request_tokens)
13
+ from vllm.v1.request import Request, RequestStatus
14
+
15
+ logger = init_logger(__name__)
16
+
17
+
18
class KVCacheManager:
    """Manages the GPU KV-cache blocks for the v1 scheduler.

    Responsibilities visible in this class:
    - Maintains a fixed pool of ``num_gpu_blocks`` blocks and a free-block
      queue ordered for LRU-style eviction.
    - When ``enable_caching`` is True, maps block hashes to cached blocks so
      that requests sharing a token prefix can reuse full blocks.
    - Tracks per-request block tables (``req_to_blocks``) so blocks can be
      freed when a request finishes.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        max_model_len: int,
        sliding_window: Optional[int] = None,
        enable_caching: bool = True,
        num_preallocate_tokens: int = 64,
    ) -> None:
        """Initialize the manager.

        Args:
            block_size: Number of tokens stored per KV-cache block.
            num_gpu_blocks: Total number of blocks in the GPU pool.
            max_model_len: Maximum model context length, used to bound the
                per-request block table size.
            sliding_window: Sliding-window size, if any. NOTE(review): stored
                but not referenced elsewhere in this class — presumably used
                by callers; confirm.
            enable_caching: Whether prefix caching is enabled.
            num_preallocate_tokens: Number of tokens' worth of blocks to
                preallocate per request (see NOTE below).
        """
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
        self.sliding_window = sliding_window
        self.enable_caching = enable_caching
        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
        # blocks for each request. For example, when a request reaches the end
        # of its block table, we preallocate N blocks in advance. This way, we
        # reduce the overhead of updating free_block_ids and ref_cnts for each
        # request every step (at the cost of some memory waste).
        # NOTE(woosuk): This is different from the "lookahead" slots since this
        # does not guarantee that the request always has N empty blocks. After
        # the request gets N empty blocks, it starts to use the blocks without
        # further allocation. When it uses up all the N empty blocks, it gets
        # N new empty blocks.
        self.num_preallocate_tokens = num_preallocate_tokens
        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

        # A Block pool of all kv-cache blocks.
        self.block_pool: List[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)

        # {block_hash: {block ID: block}}. A cached block is
        # a full block with a block hash that can be used for prefix caching.
        # The cached block may be used by running requests or in the
        # free_block_queue that could potentially be evicted.
        # NOTE: We currently don't de-duplicate the blocks in the cache,
        # meaning that if a block becomes full and is cached, we don't check
        # if there is already an identical block in the cache. This is because
        # we want to make sure the allocated block IDs won't change so that
        # block tables are append-only.
        self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
            int, KVCacheBlock]] = defaultdict(dict)

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: DefaultDict[str,
                                        List[KVCacheBlock]] = defaultdict(list)

    @property
    def usage(self) -> float:
        """Fraction of the GPU block pool currently in use (0.0 to 1.0)."""
        return 1.0 - (self.free_block_queue.num_free_blocks /
                      self.num_gpu_blocks)

    def get_computed_blocks(
            self, request: Request) -> Tuple[List[KVCacheBlock], int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple containing:
                - A list of blocks that are computed for the request.
                - The number of computed tokens.
        """
        if not self.enable_caching:
            # Prefix caching is disabled.
            return [], 0

        computed_blocks: List[KVCacheBlock] = []

        # The block hashes for the request may already be computed
        # if the request was preempted and resumed.
        if not request.kv_block_hashes:
            request.set_kv_block_hashes(
                hash_request_tokens(self.block_size, request))
        block_hashes = request.kv_block_hashes

        for block_hash in block_hashes:
            # block_hashes is a chain of block hashes. If a block hash is not
            # in the cached_block_hash_to_id, the following block hashes are
            # not computed yet for sure.
            if cached_block := self._get_cached_block(block_hash):
                computed_blocks.append(cached_block)
            else:
                break

        # NOTE(woosuk): Since incomplete blocks are not eligible for
        # sharing, `num_computed_tokens` is always a multiple of
        # `block_size`.
        num_computed_tokens = len(computed_blocks) * self.block_size
        return computed_blocks, num_computed_tokens

    def allocate_slots(
        self,
        request: Request,
        num_tokens: int,
        new_computed_blocks: Optional[List[KVCacheBlock]] = None
    ) -> Optional[List[KVCacheBlock]]:
        """Add slots for a request with new tokens to append.

        Args:
            request: The request to allocate slots.
            num_tokens: The number of tokens to allocate. Note that this does
                not include the tokens that have already been computed.
            new_computed_blocks: A list of new computed blocks just hitting the
                prefix caching.

        Blocks layout:
        -----------------------------------------------------------------------
        | < computed > | < new computed > |    < new >    | < pre-allocated > |
        -----------------------------------------------------------------------
        |                        < required >             |
        --------------------------------------------------
        |                    < full >                  |
        ------------------------------------------------
                                          | <new full> |
                                          --------------
        The following *_blocks are illustrated in this layout.

        Returns:
            A list of new allocated blocks, or None if the free pool cannot
            satisfy the allocation.
        """
        if num_tokens == 0:
            raise ValueError("num_tokens must be greater than 0")

        new_computed_blocks = new_computed_blocks or []

        # The number of computed tokens is the number of computed tokens plus
        # the new prefix caching hits
        num_computed_tokens = (request.num_computed_tokens +
                               len(new_computed_blocks) * self.block_size)
        num_required_blocks = cdiv(num_computed_tokens + num_tokens,
                                   self.block_size)
        req_blocks = self.req_to_blocks[request.request_id]
        num_new_blocks = (num_required_blocks - len(req_blocks) -
                          len(new_computed_blocks))

        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it cannot be counted as a free block
        # when allocating this request.
        num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
                                            if blk.ref_cnt == 0)
        if (num_new_blocks > self.free_block_queue.num_free_blocks -
                num_evictable_computed_blocks):
            # Cannot allocate new blocks
            return None

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self._touch(new_computed_blocks)
        else:
            assert not new_computed_blocks, (
                "Computed blocks should be empty when "
                "prefix caching is disabled")

        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        req_blocks.extend(new_computed_blocks)

        # Start to handle new blocks

        if num_new_blocks <= 0:
            # No new block is needed.
            new_blocks = []
        else:
            # Get new blocks from the free block pool considering
            # preallocated blocks.
            num_new_blocks = min(
                num_new_blocks + self.num_preallocate_blocks,
                self.free_block_queue.num_free_blocks,
                # Should not exceed the maximum number of blocks per request.
                # This is especially because the block table has the shape
                # [..., max_num_blocks_per_req].
                # TODO(woosuk): Check and reject requests if
                # num_prompt_tokens + max_tokens > max_model_len.
                self.max_num_blocks_per_req - len(req_blocks),
            )
            assert num_new_blocks > 0

            # Concatenate the computed block IDs and the new block IDs.
            new_blocks = self._get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)

        if not self.enable_caching:
            return new_blocks

        # NOTE(rickyx): We are assuming the `num_tokens` are actual
        # tokens rather than lookahead slots (e.g. for speculative decoding).
        # TODO(rickyx): When supporting speculative decoding, we will need to
        # differentiate between them so that we can know how many blocks are
        # full after appending the actual tokens.
        num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
        num_computed_full_blocks = num_computed_tokens // self.block_size
        new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
        if new_full_blocks:
            self._cache_full_blocks(
                request=request,
                blk_start_idx=num_computed_full_blocks,
                # The new full blocks are the full blocks that are not computed.
                full_blocks=new_full_blocks,
                prev_block=(req_blocks[num_computed_full_blocks - 1]
                            if num_computed_full_blocks > 0 else None))

        return new_blocks

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.
        When caching is enabled, we free the blocks in reverse order so that
        the tail blocks are evicted first.

        Args:
            request: The request to free the blocks.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        blocks = self.req_to_blocks.pop(request.request_id, [])
        ordered_blocks: Iterable[KVCacheBlock] = blocks
        if self.enable_caching:
            # Free blocks in reverse order so that the tail blocks are
            # freed first.
            ordered_blocks = reversed(blocks)

        for block in ordered_blocks:
            block.decr_ref()
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalid prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.num_gpu_blocks -
                           self.free_block_queue.num_free_blocks)
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.block_pool:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> int:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current step.
        This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step. As of 1/1/2025, the scheduler does not
        allow this case, but it is possible in the future, as we allow more
        flexible scheduling.

        This can result in an edge case where the number of common prefix blocks
        is 0, even though all scheduled requests share a common prefix. This
        occurs because there may be unscheduled RUNNING requests that do not
        share the common prefix. Currently, this case cannot be easily detected,
        so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            int: The number of common prefix blocks.
        """
        assert request.status == RequestStatus.RUNNING
        blocks = self.req_to_blocks[request.request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == num_running_requests:
                num_common_blocks += 1
            else:
                break
        return num_common_blocks

    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new block.
        """
        if num_blocks > self.free_block_queue.num_free_blocks:
            raise ValueError(
                f"Cannot get {num_blocks} free blocks from the pool")

        ret: List[KVCacheBlock] = []
        idx = 0
        while idx < num_blocks:
            # First allocate blocks.
            curr_block = self.free_block_queue.popleft()
            assert curr_block.ref_cnt == 0

            # If the block is cached, evict it.
            if self.enable_caching:
                self._maybe_evict_cached_block(curr_block)

            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        return ret

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
            block.reset_hash()
            del self.cached_block_hash_to_block[block_hash][block.block_id]

            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

            return True
        return False

    def _get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
        """Get a cached block by the block hash, or None if cache miss.
        If there are duplicated blocks, we return the first block in the cache.

        Args:
            block_hash: The hash value of the block.

        Returns:
            The cached block if it exists, or None.
        """
        if block_hash in self.cached_block_hash_to_block:
            first_block_id = list(
                self.cached_block_hash_to_block[block_hash].keys())[0]
            return self.cached_block_hash_to_block[block_hash][first_block_id]
        return None

    def _touch(self, blocks: List[KVCacheBlock]) -> None:
        """Touch a block increases its reference count by 1, and may remove
        the block from the free queue. This is used when a block is hit by
        another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
        for block in blocks:
            # ref_cnt=0 means this block is in the free list (i.e. eviction
            # candidate), so remove it.
            if block.ref_cnt == 0:
                self.free_block_queue.remove(block)
            block.incr_ref()

    def _cache_full_blocks(
        self,
        request: Request,
        blk_start_idx: int,
        full_blocks: List[KVCacheBlock],
        prev_block: Optional[KVCacheBlock],
    ) -> None:
        """Cache a list of full blocks for prefix caching.

        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it computes the
        block hashes for the blocks starting from `blk_start_idx` to the end
        of the request's full blocks, updating the metadata for each block
        and caching them in the `cached_block_hash_to_block`.

        Args:
            request: The request to cache the blocks.
            blk_start_idx: The index of the first block in the request's blocks
                to cache.
            full_blocks: The list of blocks to update hash metadata.
            prev_block: The previous block in the chain.
        """
        num_cached_block_hashes = len(request.kv_block_hashes)

        # Update the new blocks with the block hashes through the chain.
        prev_block_hash_value = None
        if prev_block is not None:
            # Previous block must have a block hash because it must be
            # a full, cached block.
            assert prev_block.block_hash is not None
            prev_block_hash_value = prev_block.block_hash.hash_value

        # Find the first uncached block. This case should only happen when
        # speculative decoding is used.
        offset = 0
        for blk in full_blocks:
            if blk.block_hash is None:
                break
            else:
                prev_block_hash_value = blk.block_hash.hash_value
                offset += 1
        else:
            # All blocks are cached.
            return

        for i, blk in enumerate(full_blocks[offset:]):
            blk_idx = blk_start_idx + offset + i
            assert blk.block_hash is None

            if blk_idx < num_cached_block_hashes:
                # The block hash may already be computed in
                # "get_computed_blocks" if the tokens are not generated by
                # this request (either the prompt tokens or the previously
                # generated tokens with preemption). In this case we simply
                # reuse the block hash.
                block_hash = request.kv_block_hashes[blk_idx]
            else:
                # Otherwise compute the block hash and cache it in the request
                # in case it will be preempted in the future.
                start_token_idx = blk_idx * self.block_size
                end_token_idx = (blk_idx + 1) * self.block_size
                block_tokens = request.all_token_ids[
                    start_token_idx:end_token_idx]
                assert len(block_tokens) == self.block_size, (
                    f"Expected {self.block_size} tokens, got "
                    f"{len(block_tokens)} at {blk_idx}th block for request "
                    f"{request.request_id}({request})")

                # Generate extra keys for multi-modal inputs. Note that since
                # we reach to this branch only when the block is completed with
                # generated tokens, we only need to consider the last mm input.
                extra_keys, _ = generate_block_hash_extra_keys(
                    request, start_token_idx, end_token_idx, -1)

                # Compute the hash of the current block.
                block_hash = hash_block_tokens(prev_block_hash_value,
                                               block_tokens, extra_keys)
                request.append_kv_block_hashes(block_hash)

            # Update and add the full block to the cache.
            blk.block_hash = block_hash
            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
            prev_block_hash_value = block_hash.hash_value
500
+ prev_block_hash_value = block_hash.hash_value
.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """KV-Cache Utilities."""
3
+ from collections.abc import Sequence
4
+ from dataclasses import dataclass
5
+ from typing import Any, List, NamedTuple, Optional, Tuple
6
+
7
+ from vllm.config import VllmConfig
8
+ from vllm.logger import init_logger
9
+ from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
10
+ KVCacheTensor)
11
+ from vllm.v1.request import Request
12
+
13
+ logger = init_logger(__name__)
14
+
15
+
16
class BlockHashType(NamedTuple):
    """Prefix-caching key for a single KV-cache block.

    Besides the integer hash value, the block's token IDs and extra keys are
    kept in the tuple so that two blocks whose integer hashes happen to
    collide still compare unequal. A collision on the full tuple remains
    theoretically possible, but with extremely low probability.
    """
    # Integer hash of the block contents (including its prefix chain).
    hash_value: int
    # The token IDs contained in the block.
    token_ids: Tuple[int, ...]
    # Optional extra keys (e.g. multi-modal input hashes) for the block.
    extra_keys: Optional[Any] = None
29
+
30
+
31
@dataclass
class KVCacheBlock:
    """Metadata for one block in the GPU KV-cache pool."""
    # Index of the block in the pool: 0 <= block_id < num_gpu_blocks.
    block_id: int
    # Number of requests currently referencing this block.
    ref_cnt: int = 0
    # Prefix-caching hash of the block contents. Set only once the block is
    # full, and cleared again when the block is evicted.
    _block_hash: Optional[BlockHashType] = None

    # Links for the doubly linked free list. Only FreeKVCacheBlockQueue is
    # supposed to manipulate these two fields.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None

    def incr_ref(self):
        """Increase the reference count by one."""
        self.ref_cnt += 1

    def decr_ref(self):
        """Decrease the reference count by one."""
        self.ref_cnt -= 1

    @property
    def block_hash(self) -> Optional[BlockHashType]:
        """The cached hash of this block, or None if not yet computed."""
        return self._block_hash

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashType):
        # A hash may be assigned only once; eviction must reset_hash() first.
        assert self._block_hash is None, (
            "The block already has a hash. This should not happen.")
        self._block_hash = block_hash

    def reset_hash(self):
        """Reset the block hash when the block is evicted."""
        self._block_hash = None
66
+
67
+
68
class FreeKVCacheBlockQueue:
    """Doubly linked list of free KVCacheBlock objects.

    A Python deque cannot remove an element from the middle in O(1), which
    this queue needs when a free (evictable) block is re-used by a prefix
    cache hit. To stay close to deque's C-level speed, no wrapper objects
    are allocated: the list is threaded directly through each block's
    prev_free_block / next_free_block attributes.

    Initially the queue is ordered by block ID. Once blocks are allocated
    and freed, the eviction order becomes:
    1. The least recently used block is at the front (LRU).
    2. If two blocks have the same last accessed time (allocated by the
       same sequence), the one with more hash tokens (the tail of a block
       chain) is at the front.
    That second property is maintained by the caller freeing a request's
    blocks in reverse order, outside of this class.

    Args:
        blocks: A list of KVCacheBlock objects.
    """

    def __init__(self, blocks: List[KVCacheBlock]) -> None:
        self.num_free_blocks = len(blocks)

        # Thread the doubly linked list through the blocks, in order.
        self.free_list_head: Optional[KVCacheBlock] = blocks[0]
        self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
        for left, right in zip(blocks, blocks[1:]):
            left.next_free_block = right
            right.prev_free_block = left

    def popleft(self) -> KVCacheBlock:
        """Pop the first free block and reduce num_free_blocks by 1.

        Returns:
            The first free block.

        Raises:
            ValueError: If the queue is empty.
        """
        head = self.free_list_head
        if not head:
            raise ValueError("No free blocks available")

        self.remove(head)
        return head

    def remove(self, block: KVCacheBlock) -> None:
        """Unlink a block from anywhere in the free list (O(1)) and reduce
        num_free_blocks by 1.

        Args:
            block: The block to remove.
        """
        prev_blk = block.prev_free_block
        next_blk = block.next_free_block

        # Splice the neighbors together around the removed block.
        if prev_blk is not None:
            prev_blk.next_free_block = next_blk
        if next_blk is not None:
            next_blk.prev_free_block = prev_blk

        # Move the head/tail markers off the removed block if needed.
        if block == self.free_list_head:
            self.free_list_head = next_blk
        if block == self.free_list_tail:
            self.free_list_tail = prev_blk

        # Fully detach the block from the list.
        block.prev_free_block = block.next_free_block = None
        self.num_free_blocks -= 1

    def append(self, block: KVCacheBlock) -> None:
        """Put a block back at the tail of the free list and increase
        num_free_blocks by 1.

        Args:
            block: The block to append.
        """
        tail = self.free_list_tail
        if tail is not None:
            # Chain the new block after the current tail.
            tail.next_free_block = block
            block.prev_free_block = tail
            self.free_list_tail = block
        else:
            # The free list is empty.
            assert self.free_list_head is None
            self.free_list_head = self.free_list_tail = block

        block.next_free_block = None
        self.num_free_blocks += 1

    def get_all_free_blocks(self) -> List[KVCacheBlock]:
        """Get all free blocks in the free list. Mainly used for testing.

        Returns:
            A list of free blocks, front to back.
        """
        blocks = []
        node = self.free_list_head
        while node is not None:
            blocks.append(node)
            node = node.next_free_block
        return blocks
171
+
172
+
173
def generate_block_hash_extra_keys(
        request: Request, start_token_idx: int, end_token_idx: int,
        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
    """Generate extra keys for the block hash. The extra keys can come from
    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
    indicate a mm input contained in the block and its starting offset in
    the block tokens.

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.
        start_mm_idx: The start multi-modal index of the block.

    Returns:
        A tuple of extra keys and the next multi-modal index.
    """

    # mm_positions entries are dicts with "offset"/"length" keys describing
    # each mm input's token span; mm_hashes are the parallel content hashes.
    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
    if not mm_positions:
        # Text-only request: no extra keys needed.
        return None, start_mm_idx

    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match. This "
            "is likely because you do not enable MM preprocessor hashing. "
            "Please set disable_mm_preprocessor_cache=False.")

    # Note that we assume mm_positions is sorted by offset.
    # We do not need to check all mm inputs if the start token index is out of
    # range. This usually happens in the late prefill phase and decoding phase.
    if mm_positions[-1]["offset"] + mm_positions[-1][
            "length"] < start_token_idx:
        return None, start_mm_idx

    # Support start_mm_idx == -1 to indicate the last mm input.
    if start_mm_idx < 0:
        assert -start_mm_idx <= len(mm_positions)
        start_mm_idx = len(mm_positions) + start_mm_idx

    extra_keys = []
    curr_mm_idx = start_mm_idx
    # Walk the mm inputs that overlap the token range [start, end) and
    # collect their hashes as extra hash keys for this block.
    while mm_positions and curr_mm_idx < len(mm_positions):
        assert mm_hashes[curr_mm_idx] is not None
        offset = mm_positions[curr_mm_idx]["offset"]
        length = mm_positions[curr_mm_idx]["length"]
        if end_token_idx > offset:
            # NOTE(review): when start_token_idx == offset + length the block
            # no longer overlaps this mm input, yet the strict ">" below still
            # treats it as contained and appends its hash. Presumably harmless
            # (extra keys only make caching more conservative) — confirm.
            if start_token_idx > offset + length:
                # This block has passed the current mm input.
                curr_mm_idx += 1
                continue

            # The block contains the current mm input.
            extra_keys.append(mm_hashes[curr_mm_idx])

            if end_token_idx >= offset + length:
                # If this block contains the end of the current mm input,
                # move to the next mm input as this block may also contain
                # the next mm input.
                curr_mm_idx += 1
            else:
                # Otherwise this block is done with mm inputs.
                break
        else:
            # This block has not reached the current mm input.
            break
    return tuple(extra_keys), curr_mm_idx
241
+
242
+
243
def hash_block_tokens(
        parent_block_hash: Optional[int],
        curr_block_token_ids: Sequence[int],
        extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
    """Compute the prefix-caching hash for one full block.

    The hash covers the parent block chain, the tokens in the block, and
    any extra keys, so equal keys imply equal prefixes (up to unlikely
    collisions on the full key tuple).

    TODO: Support arbitrary metadata so that we could support more
    features such as LoRA adapter.

    Args:
        parent_block_hash: The hash of the parent block. None
            if this is the first block.
        curr_block_token_ids: A list of token ids in the current
            block. The current block is assumed to be full.
        extra_keys: Extra keys for the block.

    Returns:
        The hash value of the block and the token ids in the block.
        The entire tuple is used as the hash key of the block.
    """
    if not parent_block_hash:
        # Seed the chain with hash('None') rather than hash(None): as of
        # Python 3.12, hash(None) is a constant predictable value, which
        # could make it easier to find and exploit hash collisions. The
        # string hash still differs per process but is consistent within
        # one process, matching the pre-3.12 behavior of hash(None).
        parent_block_hash = hash('None')

    token_ids = tuple(curr_block_token_ids)
    hash_value = hash((parent_block_hash, token_ids, extra_keys))
    return BlockHashType(hash_value, token_ids, extra_keys)
279
+
280
+
281
+ def hash_request_tokens(block_size: int,
282
+ request: Request) -> List[BlockHashType]:
283
+ """Computes hash values of a chain of blocks given a sequence of
284
+ token IDs. The hash value is used for prefix caching.
285
+
286
+ Args:
287
+ block_size: The size of each block.
288
+ request: The request object.
289
+
290
+ Returns:
291
+ The list of computed hash values.
292
+ """
293
+ token_ids = request.all_token_ids
294
+ mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
295
+ if mm_positions and len(mm_positions) != len(mm_hashes):
296
+ raise ValueError(
297
+ "The number of multi-modal positions and hashes must match.")
298
+
299
+ # TODO: Extend this to support other features such as LoRA.
300
+ need_extra_keys = bool(mm_positions)
301
+ extra_keys = None
302
+ curr_mm_idx = 0
303
+
304
+ ret = []
305
+ parent_block_hash_value = None
306
+ for start in range(0, len(token_ids), block_size):
307
+ end = start + block_size
308
+ block_token_ids = token_ids[start:end]
309
+ # Do not hash the block if it is not full.
310
+ if len(block_token_ids) < block_size:
311
+ break
312
+
313
+ # Add extra keys if the block is a multi-modal block.
314
+ if need_extra_keys:
315
+ extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
316
+ request, start, end, curr_mm_idx)
317
+
318
+ block_hash = hash_block_tokens(parent_block_hash_value,
319
+ block_token_ids, extra_keys)
320
+ ret.append(block_hash)
321
+ parent_block_hash_value = block_hash.hash_value
322
+ return ret
323
+
324
+
325
+ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
326
+ kv_cache_spec: KVCacheSpec,
327
+ available_memory: int):
328
+ """
329
+ Checks whether `available_memory` is enough for the KV cache to hold at
330
+ least one request with the model's max_model_len.
331
+
332
+ Args:
333
+ vllm_config: The global VllmConfig
334
+ kv_cache_spec: The kv cache spec of the model
335
+ available_memory: Memory available for KV cache in bytes.
336
+
337
+ Raises:
338
+ ValueError: If there is not enough memory available for the KV cache.
339
+ """
340
+
341
+ if available_memory <= 0:
342
+ raise ValueError("No available memory for the cache blocks. "
343
+ "Try increasing `gpu_memory_utilization` when "
344
+ "initializing the engine.")
345
+
346
+ max_model_len = vllm_config.model_config.max_model_len
347
+ needed_memory = 0
348
+ for layer_spec in kv_cache_spec.values():
349
+ needed_memory += layer_spec.bytes_for_tokens(max_model_len)
350
+
351
+ if needed_memory > available_memory:
352
+ raise ValueError(
353
+ f"To serve at least one request with the models's max seq len "
354
+ f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
355
+ f"cache is needed, which is larger than the available KV cache "
356
+ f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
357
+ f"increasing `gpu_memory_utilization` or decreasing "
358
+ f"`max_model_len` when initializing the engine.")
359
+
360
+
361
+ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
362
+ """
363
+ Whether all layers in the given KVCacheSpec have the same type of KV cache.
364
+
365
+ Args:
366
+ kv_cache_spec: The KVCacheSpec of the model
367
+
368
+ Returns:
369
+ True if all layers have the same type, False otherwise.
370
+ """
371
+
372
+ layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
373
+ return len(layer_keys) == 1
374
+
375
+
376
+ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
377
+ kv_cache_spec: KVCacheSpec,
378
+ available_memory: int) -> KVCacheConfig:
379
+ """
380
+ Generates the KV cache configuration for a model with one type of KV cache.
381
+ Divide the available memory equally among all layers.
382
+
383
+ Args:
384
+ vllm_config: The global VllmConfig
385
+ kv_cache_spec: The kv cache spec of the model
386
+ available_memory: Memory available for KV cache in bytes.
387
+
388
+ Returns:
389
+ The generated KVCacheConfig
390
+ """
391
+
392
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
393
+ assert len(page_sizes) == 1
394
+ page_size = page_sizes.pop()
395
+
396
+ num_blocks = int(available_memory // page_size // len(kv_cache_spec))
397
+ num_blocks = max(num_blocks, 0)
398
+
399
+ if vllm_config.cache_config.num_gpu_blocks_override is not None:
400
+ num_gpu_blocks_override = \
401
+ vllm_config.cache_config.num_gpu_blocks_override
402
+ logger.info(
403
+ "Overriding num_gpu_blocks=%d with "
404
+ "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
405
+ num_blocks = num_gpu_blocks_override
406
+
407
+ logger.info("# GPU blocks: %d", num_blocks)
408
+ max_concurrency = (num_blocks * vllm_config.cache_config.block_size /
409
+ vllm_config.model_config.max_model_len)
410
+ logger.info("Maximum concurrency for %s tokens per request: %.2fx",
411
+ vllm_config.model_config.max_model_len, max_concurrency)
412
+
413
+ per_layer_size = page_size * num_blocks
414
+
415
+ kv_cache_config = KVCacheConfig(
416
+ num_blocks=num_blocks,
417
+ tensors={
418
+ layer_name: KVCacheTensor(size=per_layer_size)
419
+ for layer_name in kv_cache_spec
420
+ },
421
+ groups=[[layer_name for layer_name in kv_cache_spec]],
422
+ kv_cache_spec=kv_cache_spec)
423
+ return kv_cache_config
424
+
425
+
426
+ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
427
+ available_memory: int) -> KVCacheConfig:
428
+ """
429
+ Generates the KV cache configuration for a model
430
+ TODO: support hybrid models with more than one type of KV cache.
431
+
432
+ Args:
433
+ vllm_config: The global VllmConfig
434
+ kv_cache_spec: The kv cache spec of the model
435
+ available_memory: Memory available for KV cache in bytes.
436
+
437
+ Returns:
438
+ The generated KVCacheConfig
439
+ """
440
+ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
441
+ if is_kv_cache_type_uniform(kv_cache_spec):
442
+ # KV cache of all layers are the same, which is true for most models.
443
+ # Allocate the same amount of memory for each layer.
444
+ return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
445
+ available_memory)
446
+ else:
447
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from collections import deque
4
+ from dataclasses import dataclass
5
+ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
6
+ Tuple, Union)
7
+
8
+ from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
9
+ from vllm.logger import init_logger
10
+ from vllm.sampling_params import SamplingParams
11
+ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
12
+ compute_encoder_budget)
13
+ from vllm.v1.core.kv_cache_manager import KVCacheManager
14
+ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
15
+ from vllm.v1.metrics.stats import SchedulerStats
16
+ from vllm.v1.outputs import ModelRunnerOutput
17
+ from vllm.v1.request import Request, RequestStatus
18
+
19
+ if TYPE_CHECKING:
20
+ from vllm.multimodal import MultiModalKwargs
21
+ from vllm.multimodal.base import PlaceholderRange
22
+
23
+ logger = init_logger(__name__)
24
+
25
+
26
+ class Scheduler:
27
+
28
+ def __init__(
29
+ self,
30
+ scheduler_config: SchedulerConfig,
31
+ model_config: ModelConfig,
32
+ cache_config: CacheConfig,
33
+ lora_config: Optional[LoRAConfig],
34
+ ) -> None:
35
+ self.scheduler_config = scheduler_config
36
+ self.cache_config = cache_config
37
+ self.lora_config = lora_config
38
+ # TODO: Support LoRA.
39
+ assert lora_config is None, "V1 does not support LoRA yet."
40
+
41
+ # Scheduling constraints.
42
+ self.max_num_running_reqs = self.scheduler_config.max_num_seqs
43
+ self.max_num_scheduled_tokens = \
44
+ self.scheduler_config.max_num_batched_tokens
45
+ self.max_model_len = self.scheduler_config.max_model_len
46
+
47
+ num_gpu_blocks = cache_config.num_gpu_blocks
48
+ assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
49
+ # Create the KV cache manager.
50
+ self.kv_cache_manager = KVCacheManager(
51
+ block_size=self.cache_config.block_size,
52
+ num_gpu_blocks=num_gpu_blocks,
53
+ max_model_len=self.max_model_len,
54
+ sliding_window=self.cache_config.sliding_window,
55
+ enable_caching=self.cache_config.enable_prefix_caching)
56
+ self.block_size = self.cache_config.block_size
57
+
58
+ # req_id -> Request
59
+ self.requests: Dict[str, Request] = {}
60
+ # Priority queues for requests.
61
+ self.waiting: Deque[Request] = deque()
62
+ self.running: List[Request] = []
63
+
64
+ # The request IDs that are finished in between the previous and the
65
+ # current steps. This is used to notify the workers about the finished
66
+ # requests so that they can free the cached states for those requests.
67
+ # This is flushed at the end of each scheduling step.
68
+ self.finished_req_ids: Set[str] = set()
69
+
70
+ # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
71
+ # them at each scheduling step.
72
+ # Request id -> CachedRequestData
73
+ self._cached_reqs_data: Dict[str, CachedRequestData] = {}
74
+
75
+ # Encoder-related.
76
+ # Calculate encoder cache size if applicable
77
+ # NOTE: For now we use the same budget for both compute and space.
78
+ # This can be changed when we make encoder cache for embedding caching
79
+ # across requests.
80
+ encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
81
+ model_config=model_config,
82
+ scheduler_config=scheduler_config,
83
+ )
84
+
85
+ # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
86
+ # projector if needed). Currently, we assume that the encoder also
87
+ # has the Transformer architecture (e.g., ViT).
88
+ self.max_num_encoder_input_tokens = encoder_compute_budget
89
+ # NOTE: For the models without encoder (e.g., text-only models),
90
+ # the encoder cache will not be initialized because cache size is 0
91
+ # for these models.
92
+ self.encoder_cache_manager = EncoderCacheManager(
93
+ cache_size=encoder_cache_size)
94
+
95
+ def schedule(self) -> "SchedulerOutput":
96
+ # NOTE(woosuk) on the scheduling algorithm:
97
+ # There's no "decoding phase" nor "prefill phase" in the scheduler.
98
+ # Each request just has the num_computed_tokens and num_tokens,
99
+ # which is equal to len(prompt_token_ids) + len(output_token_ids).
100
+ # At each step, the scheduler tries to assign tokens to the requests
101
+ # so that each request's num_computed_tokens can catch up its
102
+ # num_tokens. This is general enough to cover chunked prefills,
103
+ # prefix caching, and the "jump decoding" optimization in the future.
104
+
105
+ scheduled_new_reqs: List[Request] = []
106
+ scheduled_resumed_reqs: List[Request] = []
107
+ scheduled_running_reqs: List[Request] = []
108
+ preempted_reqs: List[Request] = []
109
+
110
+ req_to_new_block_ids: Dict[str, List[int]] = {}
111
+ num_scheduled_tokens: Dict[str, int] = {}
112
+ token_budget = self.max_num_scheduled_tokens
113
+ # Encoder-related.
114
+ scheduled_encoder_inputs: Dict[str, List[int]] = {}
115
+ encoder_budget = self.max_num_encoder_input_tokens
116
+
117
+ # First, schedule the RUNNING requests.
118
+ req_index = 0
119
+ while req_index < len(self.running) and token_budget > 0:
120
+ request = self.running[req_index]
121
+ num_new_tokens = request.num_tokens - request.num_computed_tokens
122
+ num_new_tokens = min(num_new_tokens, token_budget)
123
+ assert num_new_tokens > 0
124
+
125
+ # Schedule encoder inputs.
126
+ encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
127
+ self._try_schedule_encoder_inputs(request,
128
+ request.num_computed_tokens,
129
+ num_new_tokens,
130
+ encoder_budget))
131
+ if num_new_tokens == 0:
132
+ # The request cannot be scheduled because the encoder budget
133
+ # or the encoder cache is exhausted.
134
+ # NOTE(woosuk): Here, by doing `continue` instead of `break`,
135
+ # we do not strictly follow the FCFS scheduling policy and
136
+ # allow the lower-priority requests to be scheduled.
137
+ req_index += 1
138
+ continue
139
+
140
+ while True:
141
+ new_blocks = self.kv_cache_manager.allocate_slots(
142
+ request, num_new_tokens)
143
+ if new_blocks is None:
144
+ # The request cannot be scheduled.
145
+ # Preempt the lowest-priority request.
146
+ preempted_req = self.running.pop()
147
+ self.kv_cache_manager.free(preempted_req)
148
+ preempted_req.status = RequestStatus.PREEMPTED
149
+ preempted_req.num_computed_tokens = 0
150
+
151
+ self.waiting.appendleft(preempted_req)
152
+ preempted_reqs.append(preempted_req)
153
+ if preempted_req == request:
154
+ # No more request to preempt.
155
+ can_schedule = False
156
+ break
157
+ else:
158
+ # The request can be scheduled.
159
+ can_schedule = True
160
+ break
161
+ if not can_schedule:
162
+ break
163
+ assert new_blocks is not None
164
+
165
+ # Schedule the request.
166
+ scheduled_running_reqs.append(request)
167
+ req_to_new_block_ids[request.request_id] = [
168
+ b.block_id for b in new_blocks
169
+ ]
170
+ num_scheduled_tokens[request.request_id] = num_new_tokens
171
+ token_budget -= num_new_tokens
172
+ req_index += 1
173
+
174
+ # Encoder-related.
175
+ if encoder_inputs_to_schedule:
176
+ scheduled_encoder_inputs[request.request_id] = (
177
+ encoder_inputs_to_schedule)
178
+ # Allocate the encoder cache.
179
+ for i in encoder_inputs_to_schedule:
180
+ self.encoder_cache_manager.allocate(request, i)
181
+ encoder_budget = new_encoder_budget
182
+
183
+ # Next, schedule the WAITING requests.
184
+ if not preempted_reqs:
185
+ while self.waiting and token_budget > 0:
186
+ if len(self.running) == self.max_num_running_reqs:
187
+ break
188
+
189
+ request = self.waiting[0]
190
+ # Get already-cached tokens.
191
+ computed_blocks, num_computed_tokens = \
192
+ self.kv_cache_manager.get_computed_blocks(request)
193
+ # Number of tokens to be scheduled.
194
+ # We use `request.num_tokens` instead of
195
+ # `request.num_prompt_tokens` to consider the resumed requests,
196
+ # which have output tokens.
197
+ num_new_tokens = request.num_tokens - num_computed_tokens
198
+ if num_new_tokens == 0:
199
+ # This happens when prompt length is divisible by the block
200
+ # size and all blocks are cached. Now we force to recompute
201
+ # the last block. Note that we have to re-compute an entire
202
+ # block because allocate_slots() assumes num_computed_tokens
203
+ # is always a multiple of the block size. This limitation
204
+ # can potentially be removed in the future to slightly
205
+ # improve the performance.
206
+ num_computed_tokens -= self.block_size
207
+ num_new_tokens = self.block_size
208
+ computed_blocks.pop()
209
+ num_new_tokens = min(num_new_tokens, token_budget)
210
+ assert num_new_tokens > 0
211
+
212
+ # Schedule encoder inputs.
213
+ (encoder_inputs_to_schedule, num_new_tokens,
214
+ new_encoder_budget) = self._try_schedule_encoder_inputs(
215
+ request, num_computed_tokens, num_new_tokens,
216
+ encoder_budget)
217
+ if num_new_tokens == 0:
218
+ # The request cannot be scheduled.
219
+ break
220
+
221
+ new_blocks = self.kv_cache_manager.allocate_slots(
222
+ request, num_new_tokens, computed_blocks)
223
+ if new_blocks is None:
224
+ # The request cannot be scheduled.
225
+ break
226
+
227
+ self.waiting.popleft()
228
+ self.running.append(request)
229
+ if request.status == RequestStatus.WAITING:
230
+ scheduled_new_reqs.append(request)
231
+ elif request.status == RequestStatus.PREEMPTED:
232
+ scheduled_resumed_reqs.append(request)
233
+ else:
234
+ raise RuntimeError(
235
+ f"Invalid request status: {request.status}")
236
+
237
+ req_to_new_block_ids[request.request_id] = [
238
+ b.block_id for b in computed_blocks + new_blocks
239
+ ]
240
+ num_scheduled_tokens[request.request_id] = num_new_tokens
241
+ token_budget -= num_new_tokens
242
+ request.status = RequestStatus.RUNNING
243
+ request.num_computed_tokens = num_computed_tokens
244
+
245
+ # Encoder-related.
246
+ if encoder_inputs_to_schedule:
247
+ scheduled_encoder_inputs[request.request_id] = (
248
+ encoder_inputs_to_schedule)
249
+ # Allocate the encoder cache.
250
+ for i in encoder_inputs_to_schedule:
251
+ self.encoder_cache_manager.allocate(request, i)
252
+ encoder_budget = new_encoder_budget
253
+
254
+ # Check if the scheduling constraints are satisfied.
255
+ total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
256
+ assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
257
+ assert token_budget >= 0
258
+ assert len(self.running) <= self.max_num_running_reqs
259
+ # Since some requests in the RUNNING queue may not be scheduled in
260
+ # this step, the total number of scheduled requests can be smaller than
261
+ # len(self.running).
262
+ assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
263
+ len(scheduled_running_reqs) <= len(self.running))
264
+
265
+ # Get the longest common prefix among all requests in the running queue.
266
+ # This can be potentially used for cascade attention.
267
+ num_common_prefix_blocks = 0
268
+ if self.running:
269
+ any_request = self.running[0]
270
+ num_common_prefix_blocks = (
271
+ self.kv_cache_manager.get_num_common_prefix_blocks(
272
+ any_request, len(self.running)))
273
+
274
+ # Construct the scheduler output.
275
+ new_reqs_data = [
276
+ NewRequestData.from_request(req,
277
+ req_to_new_block_ids[req.request_id],
278
+ req.num_computed_tokens)
279
+ for req in scheduled_new_reqs
280
+ ]
281
+ resumed_reqs_data = [
282
+ self._make_cached_request_data(
283
+ req,
284
+ req_to_new_block_ids[req.request_id],
285
+ req.num_computed_tokens,
286
+ resumed_from_preemption=True,
287
+ ) for req in scheduled_resumed_reqs
288
+ ]
289
+ running_reqs_data = [
290
+ self._make_cached_request_data(
291
+ req,
292
+ req_to_new_block_ids[req.request_id],
293
+ req.num_computed_tokens,
294
+ resumed_from_preemption=False,
295
+ ) for req in scheduled_running_reqs
296
+ ]
297
+ scheduler_output = SchedulerOutput(
298
+ scheduled_new_reqs=new_reqs_data,
299
+ scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
300
+ num_scheduled_tokens=num_scheduled_tokens,
301
+ total_num_scheduled_tokens=total_num_scheduled_tokens,
302
+ scheduled_encoder_inputs=scheduled_encoder_inputs,
303
+ num_common_prefix_blocks=num_common_prefix_blocks,
304
+ # finished_req_ids is an existing state in the scheduler,
305
+ # instead of being newly scheduled in this step.
306
+ # It contains the request IDs that are finished in between
307
+ # the previous and the current steps.
308
+ finished_req_ids=self.finished_req_ids,
309
+ free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
310
+ )
311
+
312
+ self.finished_req_ids = set()
313
+ return scheduler_output
314
+
315
+ def _make_cached_request_data(
316
+ self,
317
+ request: Request,
318
+ new_block_ids: List[int],
319
+ num_computed_tokens: int,
320
+ resumed_from_preemption: bool,
321
+ ) -> "CachedRequestData":
322
+ # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
323
+ # them at each scheduling step.
324
+ if request.request_id in self._cached_reqs_data:
325
+ req_data = self._cached_reqs_data[request.request_id]
326
+ req_data.resumed_from_preemption = resumed_from_preemption
327
+ req_data.new_block_ids = new_block_ids
328
+ req_data.num_computed_tokens = num_computed_tokens
329
+ else:
330
+ req_data = CachedRequestData.from_request(request,
331
+ resumed_from_preemption,
332
+ new_block_ids,
333
+ num_computed_tokens)
334
+ self._cached_reqs_data[request.request_id] = req_data
335
+ return req_data
336
+
337
+ def _try_schedule_encoder_inputs(
338
+ self,
339
+ request: Request,
340
+ num_computed_tokens: int,
341
+ num_new_tokens: int,
342
+ encoder_budget: int,
343
+ ) -> Tuple[List[int], int, int]:
344
+ """
345
+ Determine which encoder inputs need to be scheduled in the current step,
346
+ and update `num_new_tokens` and encoder token budget accordingly.
347
+
348
+ An encoder input will be scheduled if:
349
+ - Its output tokens overlap with the range of tokens being computed
350
+ in this step, i.e.,
351
+ [num_computed_tokens, num_computed_tokens + num_new_tokens).
352
+ - It is not already computed and stored in the encoder cache.
353
+ - There is sufficient encoder token budget to process it.
354
+ - The encoder cache has space to store it.
355
+
356
+ If an encoder input cannot be scheduled due to cache or budget
357
+ limitations, the method adjusts `num_new_tokens` to schedule only the
358
+ decoder tokens up to just before the unschedulable encoder input.
359
+ """
360
+ if not request.has_encoder_inputs():
361
+ return [], num_new_tokens, encoder_budget
362
+
363
+ encoder_inputs_to_schedule: List[int] = []
364
+ mm_positions = request.mm_positions
365
+ assert mm_positions is not None
366
+ assert len(mm_positions) > 0
367
+ for i, pos_info in enumerate(mm_positions):
368
+ start_pos = pos_info["offset"]
369
+ num_encoder_tokens = pos_info["length"]
370
+
371
+ # The encoder output is needed if the two ranges overlap:
372
+ # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
373
+ # [start_pos, start_pos + num_encoder_tokens)
374
+ if start_pos >= num_computed_tokens + num_new_tokens:
375
+ # The encoder input is not needed in this step.
376
+ break
377
+ if start_pos + num_encoder_tokens <= num_computed_tokens:
378
+ # The encoder input is already computed and stored
379
+ # in the decoder's KV cache.
380
+ continue
381
+
382
+ if self.encoder_cache_manager.has_cache(request, i):
383
+ # The encoder input is already computed and cached.
384
+ continue
385
+ if (not self.encoder_cache_manager.can_allocate(request, i)
386
+ or num_encoder_tokens > encoder_budget):
387
+ # The encoder cache is full or the encoder budget is exhausted.
388
+ # NOTE(woosuk): We assume that the encoder input tokens should
389
+ # be processed altogether, as the encoder usually uses
390
+ # bidirectional attention.
391
+ if num_computed_tokens < start_pos:
392
+ # We only schedule the decoder tokens just before the
393
+ # encoder input.
394
+ num_new_tokens = start_pos - num_computed_tokens
395
+ else:
396
+ # Because of prefix caching, num_computed_tokens is greater
397
+ # than start_pos even though its encoder input is not
398
+ # available. In this case, we can't schedule any token for
399
+ # the request in this step.
400
+ num_new_tokens = 0
401
+ break
402
+
403
+ encoder_budget -= num_encoder_tokens
404
+ encoder_inputs_to_schedule.append(i)
405
+ return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
406
+
407
+ def update_from_output(
408
+ self,
409
+ scheduler_output: "SchedulerOutput",
410
+ model_runner_output: "ModelRunnerOutput",
411
+ ) -> EngineCoreOutputs:
412
+ # NOTE(woosuk): This method doesn't consider speculative decoding.
413
+ sampled_token_ids = model_runner_output.sampled_token_ids
414
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens
415
+ new_running: List[Request] = []
416
+ outputs: List[EngineCoreOutput] = []
417
+
418
+ # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
419
+ # loop can be a performance bottleneck. We should do our best to avoid
420
+ # expensive operations inside the loop.
421
+ for request in self.running:
422
+ req_id = request.request_id
423
+ num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
424
+ if num_tokens_scheduled == 0:
425
+ # The request was not scheduled in this step.
426
+ new_running.append(request)
427
+ continue
428
+
429
+ request.num_computed_tokens += num_tokens_scheduled
430
+ # When the request's num_computed_tokens catches up its num_tokens,
431
+ # the request generates output tokens. Otherwise, we ignore the
432
+ # sampler output for the request.
433
+ assert request.num_computed_tokens <= request.num_tokens
434
+
435
+ cached_encoder_input_ids = (
436
+ self.encoder_cache_manager.get_cached_input_ids(request))
437
+ # OPTIMIZATION: Avoid list(set) if the set is empty.
438
+ if cached_encoder_input_ids:
439
+ for input_id in list(cached_encoder_input_ids):
440
+ start_pos = request.mm_positions[input_id]["offset"]
441
+ num_tokens = request.mm_positions[input_id]["length"]
442
+ if start_pos + num_tokens <= request.num_computed_tokens:
443
+ # The encoder output is already processed and stored
444
+ # in the decoder's KV cache.
445
+ self.encoder_cache_manager.free_encoder_input(
446
+ request, input_id)
447
+
448
+ if request.num_computed_tokens == request.num_tokens:
449
+ req_index = model_runner_output.req_id_to_index[req_id]
450
+ # NOTE(woosuk): Currently, we assume that each request
451
+ # generates at most one token at each step.
452
+ token_id = sampled_token_ids[req_index]
453
+ request.append_output_token_ids(token_id)
454
+ num_new_tokens = 1
455
+ # TODO: Update the KV cache manager for prefix caching.
456
+
457
+ # Check for stop and update request state.
458
+ # This must be called before we make the EngineCoreOutput.
459
+ stopped = self._check_stop(request)
460
+ if stopped:
461
+ self._free_request(request)
462
+
463
+ # Add EngineCoreOutput for this Request.
464
+ output = EngineCoreOutput(
465
+ request_id=req_id,
466
+ new_token_ids=request.output_token_ids[-num_new_tokens:],
467
+ finished=request.is_finished(),
468
+ finish_reason=request.get_finished_reason(),
469
+ stop_reason=request.stop_reason)
470
+ outputs.append(output)
471
+
472
+ # Breakout of the loop.
473
+ if stopped:
474
+ continue
475
+
476
+ new_running.append(request)
477
+ self.running = new_running
478
+ return EngineCoreOutputs(
479
+ outputs=outputs,
480
+ scheduler_stats=self.make_stats(),
481
+ )
482
+
483
+ def _check_stop(self, request: Request) -> bool:
484
+ if (request.num_tokens >= self.max_model_len
485
+ or request.num_output_tokens >= request.max_tokens):
486
+ request.status = RequestStatus.FINISHED_LENGTH_CAPPED
487
+ return True
488
+
489
+ sampling_params = request.sampling_params
490
+ last_token_id = request.output_token_ids[-1]
491
+ if (not sampling_params.ignore_eos
492
+ and last_token_id == request.eos_token_id):
493
+ request.status = RequestStatus.FINISHED_STOPPED
494
+ return True
495
+
496
+ if last_token_id in (sampling_params.stop_token_ids or ()):
497
+ request.status = RequestStatus.FINISHED_STOPPED
498
+ request.stop_reason = last_token_id
499
+ return True
500
+ return False
501
+
502
+ def add_request(self, request: Request) -> None:
503
+ self.waiting.append(request)
504
+ self.requests[request.request_id] = request
505
+
506
+ def finish_requests(
507
+ self,
508
+ request_ids: Union[str, Iterable[str]],
509
+ finished_status: RequestStatus,
510
+ ) -> None:
511
+ """Handles the finish signal from outside the scheduler.
512
+
513
+ For example, the API server can abort a request when the client
514
+ disconnects.
515
+ """
516
+ assert RequestStatus.is_finished(finished_status)
517
+ if isinstance(request_ids, str):
518
+ request_ids = (request_ids, )
519
+ request_ids = set(request_ids)
520
+
521
+ for req_id in request_ids:
522
+ request = self.requests.get(req_id)
523
+ if request is None:
524
+ # Invalid request ID.
525
+ continue
526
+
527
+ if request.status == RequestStatus.RUNNING:
528
+ self.running.remove(request)
529
+ else:
530
+ self.waiting.remove(request)
531
+ request.status = finished_status
532
+ self._free_request(request)
533
+
534
+ def _free_request(self, request: Request) -> None:
535
+ assert request.is_finished()
536
+ self.kv_cache_manager.free(request)
537
+ self.encoder_cache_manager.free(request)
538
+ self._cached_reqs_data.pop(request.request_id, None)
539
+ del self.requests[request.request_id]
540
+ self.finished_req_ids.add(request.request_id)
541
+
542
+ def get_num_unfinished_requests(self) -> int:
543
+ return len(self.waiting) + len(self.running)
544
+
545
+ def has_unfinished_requests(self) -> bool:
546
+ return self.get_num_unfinished_requests() > 0
547
+
548
+ def reset_prefix_cache(self) -> bool:
549
+ return self.kv_cache_manager.reset_prefix_cache()
550
+
551
+ def make_stats(self) -> SchedulerStats:
552
+ return SchedulerStats(
553
+ num_running_reqs=len(self.running),
554
+ num_waiting_reqs=len(self.waiting),
555
+ gpu_cache_usage=self.kv_cache_manager.usage,
556
+ )
557
+
558
+
559
+ @dataclass
560
+ class NewRequestData:
561
+
562
+ req_id: str
563
+ prompt_token_ids: List[int]
564
+ prompt: Optional[str]
565
+ mm_inputs: List["MultiModalKwargs"]
566
+ mm_hashes: List[str]
567
+ mm_positions: List["PlaceholderRange"]
568
+ sampling_params: SamplingParams
569
+ block_ids: List[int]
570
+ num_computed_tokens: int
571
+
572
+ @classmethod
573
+ def from_request(
574
+ cls,
575
+ request: Request,
576
+ block_ids: List[int],
577
+ num_computed_tokens: int,
578
+ ) -> "NewRequestData":
579
+ return cls(
580
+ req_id=request.request_id,
581
+ prompt_token_ids=request.prompt_token_ids,
582
+ prompt=request.prompt,
583
+ mm_inputs=request.mm_inputs,
584
+ mm_hashes=request.mm_hashes,
585
+ mm_positions=request.mm_positions,
586
+ sampling_params=request.sampling_params,
587
+ block_ids=block_ids,
588
+ num_computed_tokens=num_computed_tokens,
589
+ )
590
+
591
+
592
+ @dataclass
593
+ class CachedRequestData:
594
+
595
+ req_id: str
596
+ # If resumed_from_preemption is False, new_block_ids will be appended to
597
+ # the request's block IDs. If True, new_block_ids will be used as the
598
+ # request's block IDs instead of appending to the existing block IDs.
599
+ resumed_from_preemption: bool
600
+ new_block_ids: List[int]
601
+ num_computed_tokens: int
602
+
603
+ @classmethod
604
+ def from_request(
605
+ cls,
606
+ request: Request,
607
+ resumed_from_preemption: bool,
608
+ new_block_ids: List[int],
609
+ num_computed_tokens: int,
610
+ ) -> "CachedRequestData":
611
+ return cls(
612
+ req_id=request.request_id,
613
+ resumed_from_preemption=resumed_from_preemption,
614
+ new_block_ids=new_block_ids,
615
+ num_computed_tokens=num_computed_tokens,
616
+ )
617
+
618
+
619
+ @dataclass
620
+ class SchedulerOutput:
621
+
622
+ scheduled_new_reqs: List[NewRequestData]
623
+ scheduled_cached_reqs: List[CachedRequestData]
624
+
625
+ num_scheduled_tokens: Dict[str, int]
626
+ total_num_scheduled_tokens: int
627
+ scheduled_encoder_inputs: Dict[str, List[int]]
628
+ num_common_prefix_blocks: int
629
+
630
+ finished_req_ids: Set[str]
631
+ free_encoder_input_ids: List[Tuple[str, int]]
.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc ADDED
Binary file (4.57 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Type
4
+
5
+ from vllm.config import VllmConfig
6
+ from vllm.executor.executor_base import ExecutorBase
7
+ from vllm.executor.ray_distributed_executor import ( # noqa
8
+ RayDistributedExecutor as RayDistributedExecutorV0)
9
+ from vllm.executor.uniproc_executor import ( # noqa
10
+ ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
11
+ from vllm.executor.uniproc_executor import ( # noqa
12
+ UniProcExecutor as UniProcExecutorV0)
13
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
14
+ from vllm.v1.outputs import ModelRunnerOutput
15
+
16
+
17
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        """Resolve the concrete Executor subclass for this configuration."""
        backend = (
            vllm_config.parallel_config.distributed_executor_backend)
        if backend is None:
            # If the user does not specify the distributed executor backend,
            # we will choose the backend based on the world size.
            world_size = vllm_config.parallel_config.world_size
            backend = "mp" if world_size > 1 else "uni"

        if backend == "ray":
            return RayDistributedExecutor
        if backend == "mp":
            # Imported lazily so the multiprocessing machinery is only
            # pulled in when this backend is actually selected.
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            return MultiprocExecutor
        if backend == "uni":
            return UniProcExecutor
        if backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            return ExecutorWithExternalLauncher
        raise ValueError(
            f"Unknown distributed executor backend: {backend}")

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> int:  # in bytes
        """Smallest available-memory figure reported by any worker."""
        per_worker = self.collective_rpc("determine_available_memory")
        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(per_worker)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """KV cache spec shared by the workers; all workers must agree."""
        specs = self.collective_rpc("get_kv_cache_spec")
        assert all(spec == specs[0] for spec in specs)
        return specs[0]

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on every worker; the first worker's output
        is returned."""
        return self.collective_rpc("execute_model",
                                   args=(scheduler_output, ))[0]

    def profile(self, is_start: bool = True):
        """Start (is_start=True) or stop profiling on the workers."""
        self.collective_rpc("profile", args=(is_start, ))
83
+
84
+
85
# Single-process v1 executor: v0 uniproc behavior plus the v1 Executor API.
class UniProcExecutor(UniProcExecutorV0, Executor):
    pass
87
+
88
+
89
# v1 executor for externally launched processes: v0 behavior plus the v1
# Executor API.
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    pass
91
+
92
+
93
# Ray-based distributed v1 executor: v0 Ray behavior plus the v1 Executor API.
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    pass