diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80354d69b50afe64a38df6532af78e10cc6858fa
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
+from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+
+__all__ = [
+ "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
+]
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76a288d2682124cf8137c644ea5d91b8bbc217d9
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e1fee820f81efd899377c804f6ca490c44f7c84
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ed067b2e1de1e3c8853a6eb0680be94ec84945a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5df7e47446b7acc74cd77c4f905f87a95581716
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from functools import cached_property
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import import_from_path, is_list_of
+
+logger = init_logger(__name__)
+
+
+class ReasoningParser:
+ """
+ Abstract reasoning parser class that should not be used directly.
+ Provided and methods should be used in derived classes.
+
+ It is used to extract reasoning content from the model output.
+ """
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ self.model_tokenizer = tokenizer
+
+ @cached_property
+ def vocab(self) -> Dict[str, int]:
+ # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
+ # whereas all tokenizers have .get_vocab()
+ return self.model_tokenizer.get_vocab()
+
+ def extract_reasoning_content(
+ self, model_output: str, request: ChatCompletionRequest
+ ) -> Tuple[Optional[str], Optional[str]]:
+ """
+ Extract reasoning content from a complete model-generated string.
+
+ Used for non-streaming responses where we have the entire model response
+ available before sending to the client.
+
+ Parameters:
+ model_output: str
+ The model-generated string to extract reasoning content from.
+
+ request: ChatCompletionRequest
+ The request object that was used to generate the model_output.
+
+ Returns:
+ Tuple[Optional[str], Optional[str]]
+ A tuple containing the reasoning content and the content.
+ """
+
+ raise NotImplementedError(
+ "AbstractReasoningParser.extract_reasoning_calls "
+ "has not been implemented!")
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+ """
+ Instance method that should be implemented for extracting reasoning
+ from an incomplete response; for use when handling reasoning calls and
+ streaming. Has to be an instance method because it requires state -
+ the current tokens/diffs, but also the information about what has
+ previously been parsed and extracted (see constructor)
+ """
+ raise NotImplementedError(
+ "AbstractReasoningParser.extract_reasoning_content_streaming "
+ "has not been implemented!")
+
+
+class ReasoningParserManager:
+ reasoning_parsers: Dict[str, Type] = {}
+
+ @classmethod
+ def get_reasoning_parser(cls, name) -> Type:
+ """
+ Get reasoning parser by name which is registered by `register_module`.
+
+ Raise a KeyError exception if the name is not registered.
+ """
+ if name in cls.reasoning_parsers:
+ return cls.reasoning_parsers[name]
+
+ raise KeyError(f"reasoning helper: '{name}' not found in "
+ "reasoning_parsers")
+
+ @classmethod
+ def _register_module(cls,
+ module: Type,
+ module_name: Optional[Union[str, List[str]]] = None,
+ force: bool = True) -> None:
+ if not issubclass(module, ReasoningParser):
+ raise TypeError("module must be subclass of ReasoningParser, "
+ f"but got {type(module)}")
+ if module_name is None:
+ module_name = module.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in cls.reasoning_parsers:
+ existed_module = cls.reasoning_parsers[name]
+ raise KeyError(f"{name} is already registered "
+ f"at {existed_module.__module__}")
+ cls.reasoning_parsers[name] = module
+
+ @classmethod
+ def register_module(
+ cls,
+ name: Optional[Union[str, List[str]]] = None,
+ force: bool = True,
+ module: Union[Type, None] = None) -> Union[type, Callable]:
+ """
+ Register module with the given name or name list. it can be used as a
+ decoder(with module as None) or normal function(with module as not
+ None).
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str)
+ or is_list_of(name, str)):
+ raise TypeError(
+ "name must be None, an instance of str, or a sequence of str, "
+ f"but got {type(name)}")
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ cls._register_module(module=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(module):
+ cls._register_module(module=module, module_name=name, force=force)
+ return module
+
+ return _register
+
+ @classmethod
+ def import_reasoning_parser(cls, plugin_path: str) -> None:
+ """
+ Import a user-defined reasoning parser by the path
+ of the reasoning parser define file.
+ """
+ module_name = os.path.splitext(os.path.basename(plugin_path))[0]
+
+ try:
+ import_from_path(module_name, plugin_path)
+ except Exception:
+ logger.exception("Failed to load module '%s' from %s.",
+ module_name, plugin_path)
+ return
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c19888d4540137fb7d07150720b0ad5f5e849b9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from typing import Optional, Sequence, Tuple, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage)
+from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
+ ReasoningParser, ReasoningParserManager)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("deepseek_r1")
+class DeepSeekR1ReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for DeepSeek R1 model.
+
+    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
+ text. This parser extracts the reasoning content from the model output.
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
+ super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+ self.reasoning_regex = re.compile(
+ rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ReasoningParser "
+ "constructor during construction.")
+
+ self.think_start_token_id = self.vocab.get(self.think_start_token)
+ self.think_end_token_id = self.vocab.get(self.think_end_token)
+ if (self.think_start_token_id is None
+ or self.think_end_token_id is None):
+ raise RuntimeError(
+ "DeepSeek R1 reasoning parser could not locate think start/end "
+ "tokens in the tokenizer!")
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+ """
+ Extract reasoning content from a delta message.
+ Handles streaming output where previous + delta = current.
+ Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+ - 'abc' goes to reasoning_content
+ - 'xyz' goes to content
+ """
+ # Skip single special tokens
+ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+ self.think_start_token_id, self.think_end_token_id
+ ]):
+ return None
+
+ if self.think_start_token_id in previous_token_ids:
+ if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+ # extract reasoning content
+ end_index = delta_text.find(self.think_end_token)
+ reasoning_content = delta_text[:end_index]
+ content = delta_text[end_index + len(self.think_end_token):]
+ return DeltaMessage(reasoning_content=reasoning_content,
+ content=content if content else None)
+ elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+ # reasoning content continues
+ return DeltaMessage(content=delta_text)
+ else:
+                # <think> in previous, no </think> in previous or delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ elif self.think_start_token_id in delta_token_ids:
+ logger.info(delta_text)
+ if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta, extract reasoning content
+ start_index = delta_text.find(self.think_start_token)
+ end_index = delta_text.find(self.think_end_token)
+ reasoning_content = delta_text[start_index +
+ len(self.think_start_token
+ ):end_index]
+ content = delta_text[end_index + len(self.think_end_token):]
+ return DeltaMessage(reasoning_content=reasoning_content,
+ content=content if content else None)
+ else:
+                # <think> in delta, no </think> in delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ else:
+            # No <think> in previous or delta, reasoning content continues.
+ return DeltaMessage(content=delta_text)
+
+ def extract_reasoning_content(
+ self, model_output: str, request: ChatCompletionRequest
+ ) -> Tuple[Optional[str], Optional[str]]:
+
+ # Check if the model output contains the tokens.
+ if (self.think_start_token not in model_output
+ or self.think_end_token not in model_output):
+ return None, model_output
+ else:
+ # Use a regex to find the reasoning content
+ reasoning_content = self.reasoning_regex.findall(model_output)[0]
+
+ # Remove the reasoning content from the model output
+            # Although deepseek's <think> token is always at the
+ # beginning of the line, we cannot guarantee that the
+ # other models will follow this convention.
+ # Therefore, we need to add :start_index.
+ start_index = model_output.find(self.think_start_token)
+ if start_index != -1:
+ end_index = start_index + len(
+ f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
+ )
+ model_output = model_output[:start_index] + \
+ model_output[end_index:]
+
+ if len(model_output) == 0:
+ return reasoning_content, None
+
+ return reasoning_content, model_output
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..002bf173883086f80bedcd61477ce9a0501e28fc
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
@@ -0,0 +1,253 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import re
+from json import JSONDecoder
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
+ find_common_prefix,
+ is_complete_json,
+ partial_json_loads)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("granite-20b-fc")
+class Granite20bFCToolParser(ToolParser):
+ """
+ Tool call parser for the granite-20b-functioncalling model intended
+ for use with the examples/tool_chat_template_granite20b_fc.jinja
+ template.
+
+ Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
+ are all set
+ """
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+        self.bot_token = "<function_call>"
+        self.tool_start_token = self.bot_token
+        self.tool_call_regex = re.compile(r"<function_call>\s*")
+
+ def extract_tool_calls(
+ self, model_output: str,
+ request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ if self.tool_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ dec = JSONDecoder()
+ try:
+ matches = list(self.tool_call_regex.finditer(model_output))
+ logger.debug("Found %d tool call matches", len(matches))
+
+ raw_function_calls = []
+
+ for i, match in enumerate(matches):
+                # position after the <function_call> tag
+ start_of_json = match.end()
+ # end_index == the start of the next function call
+ # (if exists)
+ next_function_call_start = (matches[i + 1].start() if i +
+ 1 < len(matches) else None)
+
+ raw_function_calls.append(
+ dec.raw_decode(
+ model_output[start_of_json:next_function_call_start])
+ [0])
+
+ logger.debug("Extracted %d tool calls", len(raw_function_calls))
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_call["name"],
+ # function call args are JSON but as a string
+ arguments=json.dumps(function_call["arguments"]),
+ ),
+ ) for function_call in raw_function_calls
+ ]
+
+ content = model_output[:model_output.find(self.bot_token)]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None,
+ )
+
+ except Exception as e:
+ logger.error("Error in extracting tool call from response %s", e)
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ if len(current_text) < len(
+ self.bot_token) and self.bot_token.startswith(current_text):
+ return None
+
+ if not current_text.startswith(self.bot_token):
+ return DeltaMessage(content=delta_text)
+
+ # bit mask flags for partial JSON parsing. If the name hasn't been
+ # sent yet, don't allow sending
+ # an incomplete string since OpenAI only ever (as far as I have
+ # seen) allows sending the entire tool/ function name at once.
+ flags = Allow.ALL if self.current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
+ try:
+ tool_call_arr = []
+ is_complete = []
+ try:
+ start_idx = len(self.bot_token)
+ start_idx = consume_space(start_idx, current_text)
+
+ while start_idx < len(current_text):
+ (obj,
+ end_idx) = partial_json_loads(current_text[start_idx:],
+ flags)
+ is_complete.append(
+ is_complete_json(current_text[start_idx:start_idx +
+ end_idx]))
+ start_idx += end_idx
+ start_idx = consume_space(start_idx, current_text)
+ start_idx += len(self.bot_token)
+ start_idx = consume_space(start_idx, current_text)
+ tool_call_arr.append(obj)
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug('not enough tokens to parse into JSON yet')
+ return None
+
+ # select as the current tool call the one we're on the state at
+ current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+ if len(tool_call_arr) > 0 else {}
+
+ # case -- if no tokens have been streamed for the tool, e.g.
+ # only the array brackets, stream nothing
+ if len(tool_call_arr) == 0:
+ return None
+
+ # case: we are starting a new tool in the array
+ # -> array has > 0 length AND length has moved past cursor
+ elif (len(tool_call_arr) > 0
+ and len(tool_call_arr) > self.current_tool_id + 1):
+
+ # if we're moving on to a new call, first make sure we
+ # haven't missed anything in the previous one that was
+ # auto-generated due to JSON completions, but wasn't
+ # streamed to the client yet.
+ if self.current_tool_id >= 0:
+ cur_arguments = current_tool_call.get("arguments")
+ if cur_arguments:
+ cur_args_json = json.dumps(cur_arguments)
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ argument_diff = cur_args_json[sent:]
+
+ logger.debug("got arguments diff: %s", argument_diff)
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+ else:
+ delta = None
+ else:
+ delta = None
+ # re-set stuff pertaining to progress in the current tool
+ self.current_tool_id = len(tool_call_arr) - 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("starting on new tool %d", self.current_tool_id)
+ return delta
+
+ # if the current tool name hasn't been sent, send if available
+ # - otherwise send nothing
+ elif not self.current_tool_name_sent:
+ function_name = current_tool_call.get("name")
+ if function_name:
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True))
+ ])
+ self.current_tool_name_sent = True
+ else:
+ delta = None
+
+ # now we know we're on the same tool call and we're streaming
+ # arguments
+ else:
+ cur_arguments = current_tool_call.get("arguments")
+ delta = None
+
+ if cur_arguments:
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ cur_args_json = json.dumps(cur_arguments)
+ prev_arguments = self.prev_tool_call_arr[
+ self.current_tool_id].get("arguments")
+
+ argument_diff = None
+ if is_complete[self.current_tool_id]:
+ argument_diff = cur_args_json[sent:]
+ elif prev_arguments:
+ prev_args_json = json.dumps(prev_arguments)
+ if cur_args_json != prev_args_json:
+
+ prefix = find_common_prefix(
+ prev_args_json, cur_args_json)
+ argument_diff = prefix[sent:]
+
+ if argument_diff is not None:
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ self.prev_tool_call_arr = tool_call_arr
+ return delta
+
+ except Exception as e:
+ logger.error("Error trying to handle streaming tool call: %s", e)
+ logger.debug(
+ "Skipping chunk as a result of tool streaming extraction "
+ "error")
+ return None
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c948ed78f503b9bb9f760463846e5d459f11c21b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
@@ -0,0 +1,231 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
+ find_common_prefix,
+ is_complete_json,
+ partial_json_loads)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("granite")
+class GraniteToolParser(ToolParser):
+ """
+ Tool call parser for the granite 3.0 models. Intended
+ for use with the examples/tool_chat_template_granite.jinja
+ template.
+
+ Used when --enable-auto-tool-choice --tool-call-parser granite
+ are all set
+ """
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+ # for granite 3.0, the token `<|tool_call|>`
+ self.bot_token = "<|tool_call|>"
+        # for granite 3.1, the string `<tool_call>`
+        self.bot_string = "<tool_call>"
+
+ def extract_tool_calls(
+ self, model_output: str,
+ request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ stripped = model_output.strip()\
+ .removeprefix(self.bot_token)\
+ .removeprefix(self.bot_string)\
+ .lstrip()
+ if not stripped or stripped[0] != '[':
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+ try:
+ raw_function_calls = json.loads(stripped)
+ if not isinstance(raw_function_calls, list):
+ raise Exception(
+ f"Expected dict or list, got {type(raw_function_calls)}")
+
+ logger.debug("Extracted %d tool calls", len(raw_function_calls))
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_call["name"],
+ # function call args are JSON but as a string
+ arguments=json.dumps(function_call["arguments"]),
+ ),
+ ) for function_call in raw_function_calls
+ ]
+
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=None,
+ )
+
+ except Exception as e:
+ logger.error("Error in extracting tool call from response %s", e)
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ start_idx = consume_space(0, current_text)
+ if current_text[start_idx:].startswith(self.bot_token):
+ start_idx = consume_space(start_idx + len(self.bot_token),
+ current_text)
+ if current_text[start_idx:].startswith(self.bot_string):
+ start_idx = consume_space(start_idx + len(self.bot_string),
+ current_text)
+ if not current_text or start_idx >= len(current_text)\
+ or current_text[start_idx] != '[':
+ return DeltaMessage(content=delta_text)
+
+ # bit mask flags for partial JSON parsing. If the name hasn't been
+ # sent yet, don't allow sending
+ # an incomplete string since OpenAI only ever (as far as I have
+ # seen) allows sending the entire tool/ function name at once.
+ flags = Allow.ALL if self.current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
+ try:
+ tool_call_arr = None
+ is_complete = None
+ try:
+ tool_calls, end_idx = partial_json_loads(
+ current_text[start_idx:], flags)
+ if type(tool_calls) is list:
+ tool_call_arr = tool_calls
+ else:
+ return DeltaMessage(content=delta_text)
+
+ is_complete = [True] * len(tool_calls)
+ if not is_complete_json(
+ current_text[start_idx:start_idx + end_idx]):
+ is_complete[-1] = False
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug('not enough tokens to parse into JSON yet')
+ return None
+
+ # case -- if no tokens have been streamed for the tool, e.g.
+ # only the array brackets, stream nothing
+ if not tool_call_arr:
+ return None
+
+ # select as the current tool call the one we're on the state at
+ current_tool_call: Dict = tool_call_arr[self.current_tool_id]
+
+ delta = None
+ # case: we are starting a new tool in the array
+ # -> array has > 0 length AND length has moved past cursor
+ if len(tool_call_arr) > self.current_tool_id + 1:
+
+ # if we're moving on to a new call, first make sure we
+ # haven't missed anything in the previous one that was
+ # auto-generated due to JSON completions, but wasn't
+ # streamed to the client yet.
+ if self.current_tool_id >= 0:
+ cur_arguments = current_tool_call.get("arguments")
+ if cur_arguments:
+ cur_args_json = json.dumps(cur_arguments)
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ argument_diff = cur_args_json[sent:]
+
+ logger.debug("got arguments diff: %s", argument_diff)
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ # re-set stuff pertaining to progress in the current tool
+ self.current_tool_id = len(tool_call_arr) - 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("starting on new tool %d", self.current_tool_id)
+ return delta
+
+ # if the current tool name hasn't been sent, send if available
+ # - otherwise send nothing
+ elif not self.current_tool_name_sent:
+ function_name = current_tool_call.get("name")
+ if function_name:
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True))
+ ])
+ self.current_tool_name_sent = True
+
+ # now we know we're on the same tool call and we're streaming
+ # arguments
+ else:
+ cur_arguments = current_tool_call.get("arguments")
+
+ if cur_arguments:
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ cur_args_json = json.dumps(cur_arguments)
+ prev_arguments = self.prev_tool_call_arr[
+ self.current_tool_id].get("arguments")
+
+ argument_diff = None
+ if is_complete[self.current_tool_id]:
+ argument_diff = cur_args_json[sent:]
+ elif prev_arguments:
+ prev_args_json = json.dumps(prev_arguments)
+ if cur_args_json != prev_args_json:
+ prefix = find_common_prefix(
+ prev_args_json, cur_args_json)
+ argument_diff = prefix[sent:]
+
+ if argument_diff is not None:
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ self.prev_tool_call_arr = tool_call_arr
+ return delta
+
+ except Exception as e:
+ logger.error("Error trying to handle streaming tool call: %s", e)
+ logger.debug(
+ "Skipping chunk as a result of tool streaming extraction "
+ "error")
+ return None
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..4841b28703ee3beff672150f465577d57a17251b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -0,0 +1,369 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import re
+from typing import Dict, List, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("hermes")
+class Hermes2ProToolParser(ToolParser):
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ if isinstance(self.model_tokenizer, MistralTokenizer):
+ logger.error(
+ "Detected Mistral tokenizer when using a Hermes model")
+ self.model_tokenizer = self.model_tokenizer.tokenizer
+
+ self.current_tool_name_sent: bool = False
+ self.prev_tool_call_arr: List[Dict] = []
+ self.current_tool_id: int = -1
+ self.streamed_args_for_tool: List[str] = [
+ ] # map what has been streamed for each tool so far to a list
+
+ self.tool_call_start_token: str = ""
+ self.tool_call_end_token: str = ""
+
+ self.tool_call_regex = re.compile(
+ r"(.*?)|(.*)", re.DOTALL)
+ self.scratch_pad_regex = re.compile(
+ r"(.*?)", re.DOTALL)
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolParser "
+ "constructor during construction.")
+ self.tool_call_start_token_id = self.vocab.get(
+ self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+ if (self.tool_call_start_token_id is None
+ or self.tool_call_end_token_id is None):
+ raise RuntimeError(
+ "Hermes 2 Pro Tool parser could not locate tool call start/end "
+ "tokens in the tokenizer!")
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+
+ # sanity check; avoid unnecessary processing
+ if self.tool_call_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ else:
+
+ try:
+ # there are two possible captures - between tags, or between a
+ # tag and end-of-string so the result of
+ # findall is an array of tuples where one is a function call and
+ # the other is None
+ function_call_tuples = (
+ self.tool_call_regex.findall(model_output))
+
+ # load the JSON, and then use it to build the Function and
+ # Tool Call
+ raw_function_calls = [
+ json.loads(match[0] if match[0] else match[1])
+ for match in function_call_tuples
+ ]
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_call["name"],
+ # function call args are JSON but as a string
+ arguments=json.dumps(function_call["arguments"],
+ ensure_ascii=False)))
+ for function_call in raw_function_calls
+ ]
+
+ content = model_output[:model_output.
+ find(self.tool_call_start_token)]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None)
+
+ except Exception:
+ logger.exception(
+ "Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ logger.debug("delta_text: %s", delta_text)
+ logger.debug("delta_token_ids: %s", delta_token_ids)
+ # check to see if we should be streaming a tool call - is there a
+ if self.tool_call_start_token_id not in current_token_ids:
+ logger.debug("No tool call tokens found!")
+ return DeltaMessage(content=delta_text)
+
+ try:
+
+ # figure out where we are in the parsing by counting tool call
+ # start & end tags
+ prev_tool_start_count = previous_token_ids.count(
+ self.tool_call_start_token_id)
+ prev_tool_end_count = previous_token_ids.count(
+ self.tool_call_end_token_id)
+ cur_tool_start_count = current_token_ids.count(
+ self.tool_call_start_token_id)
+ cur_tool_end_count = current_token_ids.count(
+ self.tool_call_end_token_id)
+ tool_call_portion = None
+ text_portion = None
+
+ # case: if we're generating text, OR rounding out a tool call
+ if (cur_tool_start_count == cur_tool_end_count
+ and prev_tool_end_count == cur_tool_end_count
+ and self.tool_call_end_token not in delta_text):
+ logger.debug("Generating text content! skipping tool parsing.")
+ return DeltaMessage(content=delta_text)
+
+ if self.tool_call_end_token in delta_text:
+ logger.debug("tool_call_end_token in delta_text")
+ full_text = current_text + delta_text
+ tool_call_portion = full_text.split(
+ self.tool_call_start_token)[-1].split(
+ self.tool_call_end_token)[0].rstrip()
+ delta_text = delta_text.split(
+ self.tool_call_end_token)[0].rstrip()
+ text_portion = delta_text.split(
+ self.tool_call_end_token)[-1].lstrip()
+
+ # case: if tool open & close tag counts don't match, we're doing
+ # imaginary "else" block here
+ # something with tools with this diff.
+ # flags for partial JSON parting. exported constants from
+ # "Allow" are handled via BIT MASK
+ flags = Allow.ALL if self.current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
+
+ # case -- we're starting a new tool call
+ if (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count > prev_tool_start_count):
+ if len(delta_token_ids) > 1:
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ else:
+ tool_call_portion = None
+ delta = None
+
+ text_portion = None
+
+ # set cursors and state appropriately
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("Starting on a new tool %s", self.current_tool_id)
+
+ # case -- we're updating an existing tool call
+ elif (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count == prev_tool_start_count):
+
+ # get the portion of the text that's the tool call
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ text_portion = None
+
+ # case -- the current tool call is being closed.
+ elif (cur_tool_start_count == cur_tool_end_count
+ and cur_tool_end_count >= prev_tool_end_count):
+ if (self.prev_tool_call_arr is None
+ or len(self.prev_tool_call_arr) == 0):
+ logger.debug(
+ "attempting to close tool call, but no tool call")
+ return None
+ diff = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments")
+ if diff:
+ diff = diff.encode('utf-8').decode(
+ 'unicode_escape') if diff is str else diff
+ if ('"}' not in delta_text):
+ return None
+ end_loc = delta_text.rindex('"}')
+ diff = delta_text[:end_loc] + '"}'
+ logger.debug(
+ "Finishing tool and found diff that had not "
+ "been streamed yet: %s", diff)
+ self.streamed_args_for_tool[self.current_tool_id] \
+ += diff
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=diff).model_dump(
+ exclude_none=True))
+ ])
+
+ # case -- otherwise we're just generating text
+ else:
+ text = delta_text.replace(self.tool_call_start_token, "")
+ text = text.replace(self.tool_call_end_token, "")
+ delta = DeltaMessage(tool_calls=[], content=text)
+ return delta
+
+ try:
+
+ current_tool_call = partial_json_parser.loads(
+ tool_call_portion or "{}",
+ flags) if tool_call_portion else None
+ logger.debug("Parsed tool call %s", current_tool_call)
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug('not enough tokens to parse into JSON yet')
+ return None
+ except json.decoder.JSONDecodeError:
+ logger.debug("unable to parse JSON")
+ return None
+
+ # case - we haven't sent the tool name yet. If it's available, send
+ # it. otherwise, wait until it's available.
+ if not self.current_tool_name_sent:
+ if (current_tool_call is None):
+ return None
+ function_name: Union[str, None] = current_tool_call.get("name")
+ if function_name:
+ self.current_tool_name_sent = True
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True))
+ ])
+ else:
+ return None
+ # case -- otherwise, send the tool call delta
+
+ # if the tool call portion is None, send the delta as text
+ if tool_call_portion is None:
+ # if there's text but not tool calls, send that -
+ # otherwise None to skip chunk
+ delta = DeltaMessage(content=delta_text) \
+ if text_portion is not None else None
+ return delta
+
+ # now, the nitty-gritty of tool calls
+ # now we have the portion to parse as tool call.
+
+ logger.debug("Trying to parse current tool call with ID %s",
+ self.current_tool_id)
+
+ # if we're starting a new tool call, push an empty object in as
+ # a placeholder for the arguments
+ if len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+
+ # main logic for tool parsing here - compare prev. partially-parsed
+ # JSON to the current partially-parsed JSON
+ prev_arguments = (
+ self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
+ cur_arguments = current_tool_call.get("arguments")
+
+ logger.debug("diffing old arguments: %s", prev_arguments)
+ logger.debug("against new ones: %s", cur_arguments)
+
+ # case -- no arguments have been created yet. skip sending a delta.
+ if not cur_arguments and not prev_arguments:
+ logger.debug("Skipping text %s - no arguments", delta_text)
+ delta = None
+
+ # case -- prev arguments are defined, but non are now.
+ # probably impossible, but not a fatal error - just keep going
+ elif not cur_arguments and prev_arguments:
+ logger.error("should be impossible to have arguments reset "
+ "mid-call. skipping streaming anything.")
+ delta = None
+
+ # case -- we now have the first info about arguments available from
+ # autocompleting the JSON
+ elif cur_arguments and not prev_arguments:
+
+ cur_arguments_json = json.dumps(cur_arguments,
+ ensure_ascii=False)
+ logger.debug("finding %s in %s", delta_text,
+ cur_arguments_json)
+
+ # get the location where previous args differ from current
+ if (delta_text not in cur_arguments_json[:-2]):
+ return None
+ args_delta_start_loc = cur_arguments_json[:-2]. \
+ rindex(delta_text) + \
+ len(delta_text)
+
+ # use that to find the actual delta
+ arguments_delta = cur_arguments_json[:args_delta_start_loc]
+ logger.debug("First tokens in arguments received: %s",
+ arguments_delta)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=arguments_delta).model_dump(
+ exclude_none=True))
+ ])
+ self.streamed_args_for_tool[self.current_tool_id] \
+ += arguments_delta
+
+ # last case -- we have an update to existing arguments.
+ elif cur_arguments and prev_arguments:
+ if isinstance(delta_text, str) and len(delta_text.rstrip(
+ )) >= 1 and delta_text.rstrip()[-1] == '}':
+ delta_text = delta_text.rstrip()[:-1]
+
+ logger.debug("got diff %s", delta_text)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=delta_text).model_dump(
+ exclude_none=True))
+ ])
+ self.streamed_args_for_tool[self.current_tool_id] \
+ += delta_text
+
+ # handle saving the state for the current tool into
+ # the "prev" list for use in diffing for the next iteration
+ if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+ self.prev_tool_call_arr[self.current_tool_id] = \
+ current_tool_call
+ else:
+ self.prev_tool_call_arr.append(current_tool_call)
+
+ return delta
+
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ return None # do not stream a delta. skip this token ID.
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9215e7979bf534303ada53e5e7b8c9b54a89c08
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
@@ -0,0 +1,210 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (
+ extract_intermediate_diff)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module(["internlm"])
+class Internlm2ToolParser(ToolParser):
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+ self.position = 0
+
+ def adjust_request(
+ self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+ if request.tools and request.tool_choice != 'none':
+ # do not skip special tokens because internlm use the special
+ # tokens to indicated the start and end of the tool calls
+ # information.
+ request.skip_special_tokens = False
+ return request
+
+ def get_argments(self, obj):
+ if "parameters" in obj:
+ return obj.get("parameters")
+ elif "arguments" in obj:
+ return obj.get("arguments")
+ return None
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+ if '<|action_start|>' not in current_text:
+ self.position = len(current_text)
+ return DeltaMessage(content=delta_text)
+ # if the tool call is sended, return a empty delta message
+ # to make sure the finish_reason will be send correctly.
+ if self.current_tool_id > 0:
+ return DeltaMessage(content='')
+
+ last_pos = self.position
+ if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
+ return None
+
+ new_delta = current_text[last_pos:]
+ text, action = new_delta.split('<|action_start|><|plugin|>')
+
+ if len(text) > 0:
+ self.position = self.position + len(text)
+ return DeltaMessage(content=text)
+
+ action = action.strip()
+ action = action.split('<|action_end|>'.strip())[0]
+
+ # bit mask flags for partial JSON parsing. If the name hasn't been
+ # sent yet, don't allow sending
+ # an incomplete string since OpenAI only ever (as far as I have
+ # seen) allows sending the entire tool/ function name at once.
+ flags = Allow.ALL if self.current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
+
+ try:
+ parsable_arr = action
+
+ # tool calls are generated in an object in inernlm2
+ # it's not support parallel tool calls
+ try:
+ tool_call_arr: Dict = partial_json_parser.loads(
+ parsable_arr, flags)
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug('not enough tokens to parse into JSON yet')
+ return None
+
+ # if the current tool name hasn't been sent, send if available
+ # - otherwise send nothing
+ if not self.current_tool_name_sent:
+ function_name = tool_call_arr.get("name")
+ if function_name:
+ self.current_tool_id = self.current_tool_id + 1
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True))
+ ])
+ self.current_tool_name_sent = True
+ self.streamed_args_for_tool.append("")
+ else:
+ delta = None
+ # now we know we're on the same tool call and we're streaming
+ # arguments
+ else:
+ prev_arguments = self.get_argments(
+ self.prev_tool_call_arr[self.current_tool_id])
+ cur_arguments = self.get_argments(tool_call_arr)
+
+ # not arguments generated
+ if not cur_arguments and not prev_arguments:
+ delta = None
+ # will never happen
+ elif not cur_arguments and prev_arguments:
+ logger.error(
+ "INVARIANT - impossible to have arguments reset "
+ "mid-arguments")
+ delta = None
+ # first time to get parameters
+ elif cur_arguments and not prev_arguments:
+ cur_arguments_json = json.dumps(cur_arguments)
+
+ arguments_delta = cur_arguments_json[:cur_arguments_json.
+ index(delta_text) +
+ len(delta_text)]
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=arguments_delta).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += arguments_delta
+ # both prev and cur parameters, send the increase parameters
+ elif cur_arguments and prev_arguments:
+ cur_args_json = json.dumps(cur_arguments)
+ prev_args_json = json.dumps(prev_arguments)
+
+ argument_diff = extract_intermediate_diff(
+ cur_args_json, prev_args_json)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).model_dump(
+ exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ # check to see if the name is defined and has been sent. if so,
+ # stream the name - otherwise keep waiting
+ # finish by setting old and returning None as base case
+ tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
+ self.prev_tool_call_arr = [tool_call_arr]
+ return delta
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ logger.debug(
+ "Skipping chunk as a result of tool streaming extraction "
+ "error")
+ return None
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ text = model_output
+ tools = request.tools
+ if '<|action_start|><|plugin|>' in text:
+ text, action = text.split('<|action_start|><|plugin|>')
+ action = action.split('<|action_end|>'.strip())[0]
+ action = action[action.find('{'):]
+ action_dict = json.loads(action)
+ name, parameters = action_dict['name'], json.dumps(
+ action_dict.get('parameters', action_dict.get('arguments',
+ {})))
+
+ if not tools or name not in [t.function.name for t in tools]:
+ ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=text)
+
+ tool_calls = [
+ ToolCall(
+ function=FunctionCall(name=name, arguments=parameters))
+ ]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=text if len(text) > 0 else None)
+
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=text)
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c282b5c2605a6cc8a9c85a377cf7f31c8aab967
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import ast
+import json
+import re
+from typing import Any, Sequence, Tuple, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class _UnexpectedAstError(Exception):
+ pass
+
+
+@ToolParserManager.register_module("pythonic")
+class PythonicToolParser(ToolParser):
+ """
+ Tool call parser for models that produce tool calls in a pythonic style,
+ such as Llama 3.2 models.
+
+ Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
+ """
+ # TODO(mdepinet): Possible future improvements:
+ # 1. Support text + tools separated by either <|python_tag|> or \n\n
+ # 2. Support tools outside of a list (or separated by a semicolon).
+ # This depends on item 1 for consistent streaming.
+ # Neither of these are necessary for e.g. ToolACE, but both would help make
+ # Llama3.2 models more reliable.
+
+ TOOL_CALL_REGEX = re.compile(
+ r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
+ re.DOTALL)
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
+ super().__init__(tokenizer)
+
+ # Rename for readability. This is NOT a tool id.
+ @property
+ def current_tool_index(self) -> int:
+ return self.current_tool_id
+
+ @current_tool_index.setter
+ def current_tool_index(self, value: int) -> None:
+ self.current_tool_id = value
+
+ def extract_tool_calls(
+ self, model_output: str,
+ request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ """
+ Extract the tool calls from a complete model response.
+ """
+
+ if not (self.TOOL_CALL_REGEX.match(model_output)):
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ try:
+ module = ast.parse(model_output)
+ parsed = getattr(module.body[0], "value", None)
+ if isinstance(parsed, ast.List) and all(
+ isinstance(e, ast.Call) for e in parsed.elts):
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=[
+ _handle_single_tool(e) # type: ignore
+ for e in parsed.elts
+ ],
+ content=None)
+ else:
+ raise _UnexpectedAstError(
+ "Tool output must be a list of function calls")
+ except Exception:
+ logger.exception("Error in extracting tool call from response.")
+ # Treat as regular text
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ if not current_text.startswith("["):
+ return DeltaMessage(content=delta_text)
+
+ try:
+ valid_and_added_text = _make_valid_python(current_text)
+ if valid_and_added_text is None:
+ return None
+ valid_text, added_text = valid_and_added_text
+
+ module = ast.parse(valid_text)
+ parsed = getattr(module.body[0], "value", None)
+ if not isinstance(parsed, ast.List) or not all(
+ isinstance(e, ast.Call) for e in parsed.elts):
+ raise _UnexpectedAstError(
+ "Tool output must be a list of function calls")
+ tool_calls = [
+ _handle_single_tool(e) # type: ignore
+ for e in parsed.elts
+ ]
+
+ tool_deltas = []
+ for index, new_call in enumerate(tool_calls):
+ if index < self.current_tool_index:
+ continue
+
+ self.current_tool_index = index
+ if len(self.streamed_args_for_tool) == index:
+ self.streamed_args_for_tool.append("")
+
+ new_call_complete = index < len(
+ tool_calls) - 1 or ")]" not in added_text
+ if new_call_complete:
+ self.current_tool_index += 1
+
+ withheld_suffix = (added_text[:-2]
+ if not new_call_complete else "")
+ if not new_call_complete and added_text[-2] == ")":
+ # Function call is incomplete. Withhold the closing bracket.
+ withheld_suffix = withheld_suffix + "}"
+ # Strings get single quotes in the model-produced string.
+ # JSON requires double quotes.
+ withheld_suffix = withheld_suffix.replace("'", '"')
+ delta = _compute_tool_delta(self.streamed_args_for_tool[index],
+ new_call, index, withheld_suffix)
+
+ if delta is not None:
+ tool_deltas.append(delta)
+ if (delta.function is not None
+ and delta.function.arguments is not None):
+ self.streamed_args_for_tool[
+ index] += delta.function.arguments
+
+ # HACK: serving_chat.py inspects the internal state of tool parsers
+ # when determining it's final streaming delta, automatically
+ # adding autocompleted JSON.
+ # These two lines avoid that nonsense while ensuring finish_reason
+ # is set to tool_calls when at least one tool is called.
+ if tool_deltas and not self.prev_tool_call_arr:
+ self.prev_tool_call_arr = [{"arguments": {}}]
+
+ if tool_deltas:
+ return DeltaMessage(tool_calls=tool_deltas)
+ elif not added_text and self.current_tool_id > 0:
+ # Return an empty DeltaMessage once the tool calls are all done
+ # so that finish_reason gets set.
+ return DeltaMessage(content='')
+ else:
+ return None
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ logger.debug(
+ "Skipping chunk as a result of tool streaming extraction "
+ "error")
+ return None
+
+
+def _get_parameter_value(val: ast.expr) -> Any:
+ if isinstance(val, ast.Constant):
+ return val.value
+ elif isinstance(val, ast.Dict):
+ if not all(isinstance(k, ast.Constant) for k in val.keys):
+ raise _UnexpectedAstError(
+ "Dict tool call arguments must have literal keys")
+ return {
+ k.value: _get_parameter_value(v) # type: ignore
+ for k, v in zip(val.keys, val.values)
+ }
+ elif isinstance(val, ast.List):
+ return [_get_parameter_value(v) for v in val.elts]
+ else:
+ raise _UnexpectedAstError("Tool call arguments must be literals")
+
+
+def _handle_single_tool(call: ast.Call) -> ToolCall:
+ if not isinstance(call.func, ast.Name):
+ raise _UnexpectedAstError("Invalid tool call name")
+ function_name = call.func.id
+ arguments = {}
+ for keyword in call.keywords:
+ arguments[keyword.arg] = _get_parameter_value(keyword.value)
+ return ToolCall(type="function",
+ function=FunctionCall(name=function_name,
+ arguments=json.dumps(arguments)))
+
+
+def _make_valid_python(text: str) -> Union[Tuple[str, str], None]:
+ bracket_stack = []
+ for index, char in enumerate(text):
+ if char in {"[", "(", "{"}:
+ bracket_stack.append(char)
+ elif char == "]":
+ if not bracket_stack or bracket_stack.pop() != "[":
+ raise _UnexpectedAstError("Mismatched square brackets")
+ elif char == ")":
+ if not bracket_stack or bracket_stack.pop() != "(":
+ raise _UnexpectedAstError("Mismatched parentheses")
+ elif char == "}":
+ if not bracket_stack or bracket_stack.pop() != "{":
+ raise _UnexpectedAstError("Mismatched curly braces")
+ elif char in {"'", '"'}:
+ if bracket_stack and bracket_stack[-1] == char:
+ if index > 0 and text[index - 1] == "\\":
+ # Treat an escaped quote as a regular character
+ pass
+ else:
+ bracket_stack.pop()
+ elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
+ # Double quote within a single quote string or vice versa.
+ pass
+ else:
+ bracket_stack.append(char)
+
+ text = text.rstrip()
+ if text.endswith("=") or text.endswith(":"):
+ # Since we have no type information for this property/parameter value,
+ # we can't fill in a valid value.
+ return None
+ if bracket_stack and bracket_stack[-1] == "{":
+ trailing_dict_text = text[:text.rfind("{")]
+ num_keys = trailing_dict_text.count(":")
+ num_values = trailing_dict_text.count(",")
+ if num_keys <= num_values:
+ return None # Incomplete property name within parameter value
+ if bracket_stack and bracket_stack[-1] == "(":
+ trailing_params_text = text[:text.rfind("(")]
+ num_full_param_names = trailing_params_text.count("=")
+ num_full_param_values = trailing_params_text.count(",")
+ if num_full_param_names <= num_full_param_values:
+ return None # Incomplete parameter name
+ if text.endswith(","):
+ text = text[:-1]
+ if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
+ "[") and not text.endswith(")"):
+ return None # Incomplete function name
+
+ added_text = ""
+ for char in reversed(bracket_stack):
+ if char == "[":
+ added_text += "]"
+ elif char == "(":
+ added_text += ")"
+ elif char == "{":
+ added_text += "}"
+ elif char == "'":
+ added_text += "'"
+ elif char == '"':
+ added_text += '"'
+
+ return text + added_text, added_text
+
+
+def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
+ index: int,
+ withheld_suffix: str) -> Union[DeltaToolCall, None]:
+ new_call_args = new_call.function.arguments
+ if withheld_suffix:
+ assert new_call_args.endswith(withheld_suffix)
+ new_call_args = new_call_args[:-len(withheld_suffix)]
+ if not previously_sent_args:
+ return DeltaToolCall(id=new_call.id,
+ index=index,
+ function=DeltaFunctionCall(
+ name=new_call.function.name,
+ arguments=new_call_args,
+ ))
+
+ arg_diff = new_call_args[len(previously_sent_args):]
+ return DeltaToolCall(
+ id="", index=index, function=DeltaFunctionCall(
+ arguments=arg_diff)) if arg_diff else None
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915fc6623398e2aa2ff67723aa3770d35b4aa1db
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
+
+__all__ = [
+ "PunicaWrapperBase",
+ "get_punica_wrapper",
+]
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fcac7ee56f284efe9c5f7b35c4c9780554e66b2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49ef3b15f0064b8f6131b947e45659006426b81e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d37cbc3d89ee519acca8e7283a229c0d670745b6
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35e257a075e948d1eb0a311ed593481a99a86c7f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ec82ad99ee725e3babf55d6bc4152029e2d406e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a2282ae9accd5ca26691f64d4e4d7ca6dce0d49
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py
@@ -0,0 +1,483 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Based on:
+Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
+Punica: Multi-Tenant LoRA Serving.
+https://arxiv.org/abs/2310.18547
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from .utils import compute_meta, convert_mapping
+
+if TYPE_CHECKING:
+ # avoid circular import
+ from vllm.lora.layers import LoRAMapping
+ from vllm.lora.models import LongContextLoRAContext
+
+
+class PunicaWrapperABC(ABC):
+ """
+ Abstract interface for Punica (Multi-LoRA) kernel wrappers.
+ """
+
+ @abstractmethod
+ def update_metadata(
+ self,
+ mapping: "LoRAMapping",
+ lora_index_to_id: List[Optional[int]],
+ max_loras: int,
+ vocab_size: int,
+ extra_vocab_size: int,
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Update the LoRA-related metadata for the current batch.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_shrink(
+ self,
+ y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ scale: float,
+ **kwargs,
+ ) -> None:
+ """
+ Performs GEMM for multiple slices of lora_a.
+ """
+
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_expand(
+ self,
+ y: torch.Tensor,
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ output_slices: Tuple[int, ...],
+ offset_start: int = 0,
+ add_inputs=True,
+ **kwargs,
+ ) -> None:
+ """
+ Performs GEMM and bias addition for multiple slices of lora_b.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_embedding(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ add_inputs: bool = True,
+ **kwargs,
+ ) -> None:
+ """
+ Applies LoRA specifically for VocabParallelEmbeddingWithLoRA;
+ this layer only requires the expand operation.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_linear(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ scale: float,
+ output_slices: Tuple[int, ...],
+ *,
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
+ **kwargs) -> None:
+ """
+ Applicable to linear-related LoRA (shrink followed by expand).
+ """
+
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_logits(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ scale,
+ *,
+ buffer: Optional[torch.Tensor] = None,
+ **kwargs) -> None:
+ """
+ Applies LoRA specifically for LogitsProcessorWithLoRA.
+ """
+ raise NotImplementedError
+
+
+class PunicaWrapperBase(PunicaWrapperABC):
+ """
+ PunicaWrapperBase is designed to manage and provide metadata for the punica
+ kernel. The main function is to maintain the state information for
+ Multi-LoRA, and to provide the interface for the punica.
+ """
+
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
+ device: Union[torch.device, str], **kwargs):
+ self._token_lora_indices = torch.empty(max_num_batched_tokens,
+ dtype=torch.long,
+ device=device)
+ self._sampler_indices = torch.empty(max_num_batched_tokens,
+ dtype=torch.long,
+ device=device)
+ self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
+ dtype=torch.long,
+ device=device)
+ self._embeddings_indices = torch.empty(2,
+ max_num_batched_tokens,
+ dtype=torch.long,
+ device=device)
+ self._long_lora_indices = torch.empty(max_num_batched_tokens,
+ dtype=torch.long,
+ device=device)
+
+ # 5 is the number of indices tensors:
+ # base_indices, sampler_indices, sampler_indices_padded,
+ # embeddings_indices, long_lora_indices
+ self.indices_len: List[Optional[int]] = [None] * 5
+ # these attributes are the information required for the sgmv kernel
+ self._seq_start_locs = torch.empty(max_batches,
+ dtype=torch.long,
+ device=device)
+ self._seq_lengths = torch.empty(max_batches,
+ dtype=torch.long,
+ device=device)
+ self._lora_indices_per_batch = torch.empty(max_batches,
+ dtype=torch.long,
+ device=device)
+ self.device: torch.device = device
+ self.max_length: int = 0
+ self.token_nums: int = 0
+ self.batch_size: int = -1
+ self.is_prefill = False
+ self.no_lora = False
+
+ def _update_base_metadata(
+ self,
+ mapping: "LoRAMapping",
+ lora_index_to_id: List[Optional[int]],
+ max_loras: int,
+ vocab_size: int,
+ extra_vocab_size: int,
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
+ ):
+ (
+ base_indices,
+ sampler_indices,
+ sampler_indices_padded,
+ embeddings_indices,
+ long_lora_offsets_tensor,
+ indices_len,
+ ) = convert_mapping(
+ mapping,
+ lora_index_to_id,
+ max_loras,
+ vocab_size,
+ extra_vocab_size,
+ self.device,
+ long_lora_context,
+ )
+ self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+ self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+ self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+ sampler_indices_padded)
+ self._embeddings_indices[:embeddings_indices.
+ shape[0], :embeddings_indices.shape[1]].copy_(
+ embeddings_indices)
+ if long_lora_offsets_tensor is not None:
+ self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+ long_lora_offsets_tensor)
+ else:
+ self._long_lora_indices.zero_()
+ self.indices_len[:] = indices_len
+
+ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
+ # NOTE: the method name keeps the upstream typo ("metada").
+ (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
+ batch_size, max_length, token_nums,
+ no_lora) = compute_meta(token_lora_tensor)
+
+ self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
+ b_seq_start_tensor)
+ self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
+ self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
+ lora_indices_tensor)
+ self.batch_size = batch_size
+ self.max_length = max_length
+ self.token_nums = token_nums
+ self.no_lora = no_lora
+
+ def _apply_bias(
+ self,
+ indices: torch.Tensor,
+ output: torch.Tensor,
+ output_slices: Tuple[int, ...],
+ lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
+ ):
+ """Applies bias to output
+
+ Input shapes:
+ lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
+ indices: (batch_size)
+ output: (batch_size, q_slice_size + 2*kv_slice_size)
+ output_slices: n-1 element tuple of (slice_size...),
+ where n is number of slices
+ """
+ org_output = output
+ output = output.view(-1, output.shape[-1])
+ indices = indices.view(-1)
+
+ offset_left = 0
+ for slice_idx, slice in enumerate(output_slices):
+ bias = lora_bias_stacked[slice_idx]
+ if bias is not None:
+ bias = bias.view(-1, bias.shape[-1])
+ bias = bias[indices]
+ bias[indices == -1] = 0
+ output[:, offset_left:offset_left + slice] += bias
+ offset_left += slice
+
+ return output.view_as(org_output)
+
+ @property
+ def prefill_metadata(
+ self
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
+ """
+ This property provides a convenient way to access the necessary
+ metadata for prefill-related kernel computations.
+ 1. seq_start_locs: Tensor of sequence start positions.
+ 2. seq_lengths: Tensor of sequence lengths.
+ 3. lora_indices_per_batch: Tensor of lora indices, and an index of
+ -1 means no lora should be applied.
+ 4. batch_size: Batch size after clustering identical lora indices.
+ 5. max_length: The maximum sequence length in the batch.
+ 6. token_nums: The token numbers in the batch.
+ """
+ return (self._seq_start_locs[:self.batch_size],
+ self._seq_lengths[:self.batch_size],
+ self._lora_indices_per_batch[:self.batch_size],
+ self.batch_size, self.max_length, self.token_nums)
+
+ @property
+ def token_lora_indices(self) -> torch.Tensor:
+ """
+ This property provides the lora indices corresponding to each token
+ in the batch. An index of -1 means no lora should be applied.
+ """
+ token_lora_len = self.indices_len[0]
+ return self._token_lora_indices[:token_lora_len]
+
+ @property
+ def sampler_indices(self) -> torch.Tensor:
+ """
+ This property is used to access the lora indices specifically for
+ LogitsProcessorWithLoRA.
+ """
+ sampler_indices_len = self.indices_len[1]
+ return self._sampler_indices[:sampler_indices_len]
+
+ @property
+ def sampler_indices_padded(self) -> torch.Tensor:
+ """
+ This property provides access to padded sampler indices.
+ """
+ indices_padded_len = self.indices_len[2]
+ return self._sampler_indices_padded[:indices_padded_len]
+
+ @property
+ def embeddings_indices(self) -> torch.Tensor:
+ """
+ This property provides access to the indices used for lora embeddings,
+ specifically for VocabParallelEmbeddingWithLoRA.
+ """
+ embeddings_indices_len = self.indices_len[3]
+ return self._embeddings_indices[:, :embeddings_indices_len]
+
+ @property
+ def long_lora_indices(self) -> torch.Tensor:
+ """
+ This property provides access to the indices used for long context
+ lora, specifically for LinearScalingRotaryEmbeddingWithLora.
+ """
+ long_lora_len = self.indices_len[4]
+ return self._long_lora_indices[:long_lora_len]
+
+ def update_metadata(
+ self,
+ mapping: "LoRAMapping",
+ lora_index_to_id: List[Optional[int]],
+ max_loras: int,
+ vocab_size: int,
+ extra_vocab_size: int,
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
+ **kwargs):
+ # Refresh the base index tensors first, then the sgmv metadata.
+ self._update_base_metadata(mapping, lora_index_to_id, max_loras,
+ vocab_size, extra_vocab_size,
+ long_lora_context)
+ if mapping.is_prefill:
+ # Update metadata required for prefill-related operators.
+ self._update_prefill_metada(self.token_lora_indices)
+ self.is_prefill = True
+ else:
+ self.is_prefill = False
+
+ @abstractmethod
+ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
+ scale: float, **kwargs) -> None:
+ """
+ Performs GEMM for multiple slices of lora_a.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (x @ lora_a_stacked[i]) * scale
+
+ Args:
+ y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
+ x (torch.Tensor): Input tensor
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
+ scale (float): Scaling factor for the operation
+
+ """
+ # TODO: implement it based on torch ops
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_expand(self,
+ y: torch.Tensor,
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ output_slices: Tuple[int, ...],
+ offset_start: int = 0,
+ add_inputs=True,
+ **kwargs) -> None:
+ """
+ Performs GEMM and bias addition for multiple slices of lora_b.
+
+ Semantics:
+ offset = offset_start
+ for i in range(len(lora_b_stacked)):
+ slice = output_slices[i]
+ y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
+ lora_bias_stacked[i]
+ offset += slice
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
+ bias's weight
+ output_slices (Tuple[int, ...]): Every slice's size
+ offset_start (int): The starting position of y, defaults to 0
+ add_inputs (bool): Defaults to True.
+
+ """
+ # TODO: implement it based on torch ops
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_embedding(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ add_inputs: bool = True,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+ and this layer only requires the expand operation.
+ Semantics:
+ y += x @ lora_b_stacked
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (torch.Tensor): Input tensor.
+ lora_b_stacked (torch.Tensor): lora_b's weights.
+ add_inputs (bool): Default to True.
+ """
+ # TODO: implement it based on torch ops
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_linear(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ scale: float,
+ output_slices: Tuple[int, ...],
+ *,
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
+ **kwargs) -> None:
+ """
+ Applicable to linear-related lora.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (
+ x[i].unsqueeze(0)
+ @ lora_a_stacked[indices[i], layer_idx, :, :]
+ @ lora_b_stacked[indices[i], layer_idx, :, :]
+ * scale
+ ).squeeze(0)+lora_bias_stacked[i]
+
+ Args:
+ y (torch.Tensor): Output tensor. Will be changed in-place.
+ x (torch.Tensor): Input tensor
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
+ scale (float): Scaling factor.
+ output_slices (Tuple[int, ...]): Every slice's size.
+ buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
+ """
+ # TODO: implement it based on torch ops
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_lora_logits(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ scale,
+ *,
+ buffer: Optional[torch.Tensor] = None,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for LogitsProcessorWithLoRA.
+
+ Semantics:
+ buffer = (x @ lora_a_stacked) * scale
+ y += buffer @ lora_b_stacked
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (torch.Tensor): Input tensor.
+ lora_a_stacked (torch.Tensor): lora_a's weights.
+ lora_b_stacked (torch.Tensor): lora_b's weights.
+ scale (float): Scaling factor.
+ buffer (Optional[torch.Tensor]): Default to None.
+ """
+ # TODO: implement it based on torch ops
+ raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..29428f4cfff3175e782618cbd16c727d56771798
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py
@@ -0,0 +1,348 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+
+from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
+ bgmv_shrink, sgmv_expand,
+ sgmv_expand_slice, sgmv_shrink)
+
+from .punica_base import PunicaWrapperBase
+
+
+# The platforms that are compatible with the PyTorch-native implementation can
+# inherit this class
+class PunicaWrapperCPU(PunicaWrapperBase):
+ """
+ PunicaWrapperCPU is designed to manage and provide metadata for the punica
+ kernel. The main function is to maintain the state information for
+ Multi-LoRA, and to provide the interface for the pytorch punica ops.
+ """
+
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
+ device: Union[torch.device, str], **kwargs):
+ PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+ device)
+
+ def _shrink_prefill(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ scale: float,
+ ):
+ # No LoRA request, so return directly
+ if self.no_lora:
+ return
+ sgmv_shrink(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ scale,
+ )
+
+ def _shrink_decode(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ scale: float,
+ ):
+ bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
+
+ def _expand_prefill(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ add_inputs: bool,
+ ):
+ # No LoRA request, so return directly
+ if self.no_lora:
+ return
+ sgmv_expand(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ add_inputs,
+ )
+
+ def _expand_decode(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ add_inputs: bool,
+ ):
+ bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
+
+ def _expand_slice_prefill(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ y_offset: int,
+ y_slice_size: int,
+ add_inputs: bool,
+ ):
+ # No LoRA request, so return directly
+ if self.no_lora:
+ return
+ sgmv_expand_slice(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ y_offset,
+ y_slice_size,
+ add_inputs,
+ )
+
+ def _expand_slice_decode(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ y_offset: int,
+ y_slice_size: int,
+ add_inputs: bool,
+ ):
+ bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
+ y_slice_size, add_inputs)
+
+ def _apply_expand(
+ self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ y_offset: int,
+ y_slice_size: int,
+ add_inputs: bool = True,
+ ):
+ """
+ Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
+ computation, which is suitable for the
+ GEMM of lora_b.
+ """
+
+ expand_slice_fun: Callable = (self._expand_slice_prefill
+ if self.is_prefill else
+ self._expand_slice_decode)
+ expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
+
+ def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
+ w_t_all: torch.Tensor, scale: float):
+ """
+ Perform the `y += x @ w_t_all` computation, which is suitable for the
+ GEMM of lora_a.
+ When `is_prefill` is True, it indicates that it is currently the
+ prefill stage, and the `_shrink_prefill` function should be called.
+ Otherwise, it is the decode stage, and the `_shrink_decode` function
+ should be called.
+ """
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ shrink_fun: Callable = (self._shrink_prefill
+ if self.is_prefill else self._shrink_decode)
+ shrink_fun(y, x, w_t_all, scale)
+ y = y.view_as(y_org)
+
+ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
+ scale: float, **kwargs):
+ """
+ Performs GEMM for multiple slices of lora_a.
+ When `is_prefill` is True, it indicates that it is currently the
+ prefill stage, and the `_shrink_prefill` function should be called.
+ Otherwise, it is the decode stage, and the `_shrink_decode` function
+ should be called.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (x @ lora_a_stacked[i]) * scale
+
+ Args:
+ y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
+ x (torch.Tensor): Input tensor
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
+ scale (float): Scaling factor for the operation
+ """
+
+ x = x.view(-1, x.shape[-1])
+ # TODO fuse these kernels
+ for slice_idx in range(len(lora_a_stacked)):
+ self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
+ scale)
+
+ def add_expand(self,
+ y: torch.Tensor,
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ output_slices: Tuple[int, ...],
+ offset_start: int = 0,
+ add_inputs=True,
+ **kwargs) -> None:
+ """
+ Performs GEMM and bias addition for multiple slices of lora_b.
+
+ Semantics:
+ for i in range(len(lora_b_stacked)):
+ slice = output_slices[i]
+ y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
+ lora_bias_stacked[i]
+ offset += slice
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
+ bias's weight
+ output_slices (Tuple[int, ...]): Every slice's size
+ add_inputs (bool): Defaults to True.
+ """
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ offset_left = offset_start
+ if lora_bias_stacked is not None:
+ self._apply_bias(self.token_lora_indices, y, output_slices,
+ lora_bias_stacked)
+ for slice_idx in range(len(lora_b_stacked)):
+ self._apply_expand(
+ y,
+ x[slice_idx],
+ lora_b_stacked[slice_idx],
+ offset_left,
+ output_slices[slice_idx],
+ add_inputs=add_inputs,
+ )
+ offset_left += output_slices[slice_idx]
+ y = y.view_as(y_org)
+
+ def add_lora_embedding(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ add_inputs: bool = True,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+ Semantics:
+ y += x @ lora_b_stacked
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (torch.Tensor): Input tensor.
+ lora_b_stacked (torch.Tensor): lora_b's weights.
+ add_inputs (bool): Default to True.
+ """
+
+ # Embedding layer only needs the expand op
+ expand_fun: Callable = (self._expand_prefill
+ if self.is_prefill else self._expand_decode)
+ expand_fun(y, x, lora_b_stacked, add_inputs)
+
+ def add_lora_linear(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ scale: float,
+ output_slices: Tuple[int, ...],
+ *,
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
+ **kwargs) -> None:
+ """
+ Applicable to linear-related lora.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (
+ x[i].unsqueeze(0)
+ @ lora_a_stacked[indices[i], layer_idx, :, :]
+ @ lora_b_stacked[indices[i], layer_idx, :, :]
+ * scale
+ ).squeeze(0)+lora_bias_stacked[i]
+
+ Args:
+ y (torch.Tensor): Output tensor. Will be changed in-place.
+ x (torch.Tensor): Input tensor
+ lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
+ lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
+ lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
+ scale (float): Scaling factor.
+ output_slices (Tuple[int, ...]): Every slice's size.
+ buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
+ """
+
+ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+ if lora_bias_stacked is not None:
+ assert len(lora_bias_stacked) == len(output_slices)
+ y = self._apply_bias(self.token_lora_indices, y, output_slices,
+ lora_bias_stacked)
+
+ if buffer is None:
+ r = lora_b_stacked[0].size(-1)
+ # We set the buffer to be float32 by default, consistent with the
+ # triton op
+ buffer = tuple(
+ torch.zeros(
+ (x.size(0), r), dtype=torch.float32, device=x.device)
+ for _ in range(len(output_slices)))
+ self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
+ self.add_expand(y,
+ buffer,
+ lora_b_stacked,
+ None,
+ output_slices,
+ add_inputs=True,
+ **kwargs)
+
+ def add_lora_logits(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ scale,
+ *,
+ buffer: Optional[torch.Tensor] = None,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for LogitsProcessorWithLoRA.
+
+ Semantics:
+ buffer = (x @ lora_a_stacked) * scale
+ y += buffer @ lora_b_stacked
+
+ Args:
+ y (torch.Tensor): Output tensor.
+ x (torch.Tensor): Input tensor.
+ lora_a_stacked (torch.Tensor): lora_a's weights.
+ lora_b_stacked (torch.Tensor): lora_b's weights.
+ scale (float): Scaling factor.
+ buffer (Optional[torch.Tensor]): Default to None.
+ """
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ x = x.view(-1, x.shape[-1])
+ r = lora_b_stacked.size(-1)
+ if buffer is None:
+ # We set the buffer to be float32 by default, consistent with the
+ # triton op
+ buffer = torch.zeros((x.size(0), r),
+ dtype=torch.float32,
+ device=x.device)
+ # LogitsProcessorWithLoRA always uses bgmv.
+ bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+ bgmv_expand(buffer,
+ lora_b_stacked,
+ y,
+ self.sampler_indices,
+ add_inputs=True)
+ y = y.view_as(y_org)
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e1bfab3f5136ab17732b2578d864fe3e0043d7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple, Union, final
+
+import torch
+from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
+ dispatch_bgmv_linear)
+
+from .punica_base import PunicaWrapperBase
+
+
+@final
+class PunicaWrapperHPU(PunicaWrapperBase):
+ """PunicaWrapperBase implementation backed by vllm_hpu_extension bgmv ops."""
+ def __init__(self, max_num_batched_tokens: int, max_batches: int,
+ device: Union[torch.device, str], **kwargs):
+ # Increasing max_num_batched_tokens by 3x to handle increase in
+ # tensor size due to padding.
+ PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
+ max_batches, device)
+
+ def add_lora_embedding(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ add_inputs: bool = True,
+ **kwargs) -> None:
+ dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
+
+ def add_lora_linear(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ scale: float,
+ output_slices: Tuple[int, ...],
+ *,
+ buffer: Optional[Tuple[torch.Tensor, ...]] = None,
+ **kwargs) -> None:
+ y_org = y
+ x = x.view(-1, x.shape[-1])
+ y = y.view(-1, y.shape[-1])
+ offset_left = 0
+
+ for slice_idx in range(len(output_slices)):
+ dispatch_bgmv_linear(
+ y[:, offset_left:offset_left + output_slices[slice_idx]], x,
+ lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
+ offset_left += output_slices[slice_idx]
+ y = y.view_as(y_org)
+
+ def add_lora_logits(self,
+ y: torch.Tensor,
+ x: torch.Tensor,
+ lora_a_stacked: torch.Tensor,
+ lora_b_stacked: torch.Tensor,
+ scale,
+ *,
+ buffer: Optional[torch.Tensor] = None,
+ **kwargs) -> None:
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ x = x.view(-1, x.shape[-1])
+ dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
+ y = y.view_as(y_org)
+
+ # Shrink/expand are intentionally unimplemented: the dispatch ops above
+ # perform the fused shrink+expand in one call.
+ def add_shrink(
+ self,
+ y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ x: torch.Tensor,
+ lora_a_stacked: Tuple[torch.Tensor, ...],
+ scale: float,
+ **kwargs,
+ ) -> None:
+ raise NotImplementedError
+
+ def add_expand(
+ self,
+ y: torch.Tensor,
+ x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+ lora_b_stacked: Tuple[torch.Tensor, ...],
+ lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+ output_slices: Tuple[int, ...],
+ offset_start: int = 0,
+ add_inputs=True,
+ **kwargs,
+ ) -> None:
+ raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5d4b788ec435a970f35d1625e127e3d3812bcd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname
+
+from .punica_base import PunicaWrapperBase
+
+logger = init_logger(__name__)
+
+
+def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
+ """Instantiate the PunicaWrapper implementation chosen by the platform."""
+ punica_wrapper_qualname = current_platform.get_punica_wrapper()
+ punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
+ punica_wrapper = punica_wrapper_cls(*args, **kwargs)
+ assert punica_wrapper is not None, \
+ "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
+ logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
+ ".")
+ return punica_wrapper
diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc2d27c597f20c8a5aa79e87b96b8f76e9979bc
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+if TYPE_CHECKING:
+ # avoid circular import
+ from vllm.lora.layers import LoRAMapping
+ from vllm.lora.models import LongContextLoRAContext
+
+
+def compute_meta(
+ token_lora_tensor: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
+ """
+ Get the information required for the sgmv kernel. With the features:
+ 1. If consecutive requests in the batch use the same LoRA, this function
+ will combine them into a single request, improving sgmv kernel inference
+ performance.
+ 2. At the beginning of each prefill stage inference, recalculations are
+ needed based on the input, but only once.
+ """
+
+ lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
+ token_lora_tensor, return_counts=True)
+ cum_result = torch.cumsum(seq_length_tensor, dim=0)
+ b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
+ b_seq_start_tensor[1:].copy_(cum_result[:-1])
+ max_length = seq_length_tensor.max().item()
+ token_nums = seq_length_tensor.sum().item()
+ batch_size = lora_indices_tensor.size(0)
+ no_lora = False
+ # -1 means no LoRA should be applied; `no_lora` lets the prefill stage
+ # skip launching the triton kernel. The tensor-to-bool comparison below is
+ # safe only because batch_size == 1 guarantees a one-element tensor.
+ if batch_size == 1 and lora_indices_tensor == -1:
+ no_lora = True
+ return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
+ batch_size, max_length, token_nums, no_lora)
+
+
+# TODO see if this can be vectorized
+def convert_mapping(
+ mapping: "LoRAMapping",
+ lora_index_to_id: List[Optional[int]],
+ max_loras: int,
+ vocab_size: int,
+ extra_vocab_size: int,
+ device: torch.device,
+ long_lora_context: Optional["LongContextLoRAContext"] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
+ Optional[torch.Tensor], List[int]]:
+ """Converts LoRAMapping to index tensors.
+
+ Args:
+ mapping: LoRAMapping mapping rows in a batch to LoRA ids.
+ lora_index_to_id: List mapping LoRA ids to LoRA indices.
+ max_loras: Maximum number of LoRAs.
+ vocab_size: Model vocab size.
+ extra_vocab_size: Extra vocab size each LoRA can have.
+ long_lora_context: Passed if there are long context lora in a batch.
+
+ Returns:
+ A tuple of tensors:
+ base_indices: Tensor of shape [batch_size] mapping batch rows to
+ LoRA indices.
+ sampler_indices: Tensor of shape [batch_size] mapping requests to
+ LoRA indices for sampler. For generation, this will be the
+ same as base_indices. For prefill, this will map requests
+ to LoRA indices.
+ sampler_indices_padded: Tensor of shape [batch_size] mapping
+ requests to LoRA indices for sampler with padding.
+ Same as sampler_indices, but -1 is replaced with
+ max_loras.
+ embeddings_indices: Tensor of shape [2, batch_size] mapping
+ requests to embedding indices. First row is for embeddings
+ added by the LoRAs, second row is for the LoRA.lora_a
+ embeddings.
+ long_lora_indices: Tensor of shape [batch_size] mapping
+ requests to RoPE offsets and rot dims for long LoRAs.
+ None if long context lora doesn't exist.
+ indices_len: List of lengths of the above tensors. It contains
+ (base_indices, sampler_indices, sampler_indices_padded,
+ embeddings_indices, long_lora_indices).
+ """
+ index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
+ embedding_indices = index_mapping_indices.copy()
+ lora_indices = index_mapping_indices.copy()
+ long_lora_offsets: Optional[torch.Tensor] = None
+ if long_lora_context:
+ long_lora_offsets = torch.zeros(len(index_mapping_indices),
+ device=device,
+ dtype=torch.long)
+ prompt_mapping: List[int] = [
+ lora_index_to_id.index(x) if x > 0 else -1
+ for x in mapping.prompt_mapping
+ ]
+ lora_idx = None
+ for i in range(len(index_mapping_indices)):
+ # TODO index can be slow. optimize
+ lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
+ if index_mapping_indices[i] > 0 else -1)
+ embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
+ lora_indices[i] = lora_idx
+ if long_lora_context:
+ assert long_lora_offsets is not None
+ lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+ index_mapping_indices[i], 0)
+ long_lora_offsets[i] = lora_offset
+
+ indices_list: List[Union[List[int], torch.Tensor]] = [
+ index_mapping_indices,
+ lora_indices,
+ embedding_indices,
+ ]
+ if long_lora_context:
+ assert long_lora_offsets is not None
+ indices_list.append(long_lora_offsets)
+ indices = torch.tensor(indices_list, dtype=torch.long, device=device)
+ prompt_mapping_tensor = torch.tensor(prompt_mapping,
+ dtype=torch.long,
+ device=device)
+ embeddings_indices = torch.stack([
+ indices[2] * extra_vocab_size,
+ indices[2] * (vocab_size + extra_vocab_size),
+ ])
+ embeddings_indices[embeddings_indices == -1] = max_loras - 1
+ base_indices = indices[1]
+ sampler_indices = prompt_mapping_tensor
+ sampler_indices_padded = sampler_indices.clone()
+ sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
+ sampler_indices_padded = torch.arange(
+ 0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
+ sampler_indices_padded * len(sampler_indices_padded))
+ long_lora_indices = None
+ long_lora_indices_len: Optional[int] = None
+ if long_lora_context:
+ long_lora_indices = indices[3]
+ long_lora_indices_len = long_lora_indices.shape[-1]
+ # Contain length of indices tensors. Used to index into each tensor.
+ indices_len = [
+ base_indices.shape[-1],
+ sampler_indices.shape[-1],
+ sampler_indices_padded.shape[-1],
+ embeddings_indices.shape[-1],
+ ]
+ if long_lora_indices_len is not None:
+ indices_len.append(long_lora_indices_len)
+ else:
+ # If long_lora doesn't exist, append None
+ indices_len.append(None)
+
+ return (
+ base_indices,
+ sampler_indices,
+ sampler_indices_padded,
+ embeddings_indices,
+ long_lora_indices,
+ indices_len,
+ )
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f7fd323d020b42e7204c9401bbc4efad189d96d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b8e0fe2e27d36a6fa957f4c2756541c54b02a8a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8bf496dbe933abe40bb4ba3f9e820b133028d9a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..158d653ed7fc8784d1e267d508f718f288dc7874
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cf3b59c57980942abb3548df723f10555e3d4c0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74ebbd8250d645b616fa48a83b9a63761d2c2750
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83e8f94d4e26c0adda6039a49dd2ff818447c231
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23a658bc6705719367fddd5f70625a617675f258
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34e6b16b6700b685dbfaf5569f47f27a6385ee97
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..837d7faf43708dbc2ece2eaa60c7283293f5f7c1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py
@@ -0,0 +1,459 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Attention layer with FlashAttention."""
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+ AttentionMetadata, AttentionType)
+from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import cdiv
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+ flash_attn_varlen_func,
+ is_fa_version_supported)
+
+logger = init_logger(__name__)
+
+
class FlashAttentionBackend(AttentionBackend):
    """Registry-facing description of the v1 FlashAttention backend."""

    # The impl writes directly into a caller-provided output buffer.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the paged KV-cache layout: (K/V, blocks, block, heads, dim)."""
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        # Delegate to the module-level heuristic defined in this file.
        return use_cascade_attention(*args, **kwargs)
+
+
@dataclass
class FlashAttentionMetadata:
    """Per-batch metadata consumed by FlashAttentionImpl.forward."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    # |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    # Longest per-request query span; passed as max_seqlen_q to flash-attn.
    max_query_len: int
    # Cumulative query start offsets; passed as cu_seqlens_q to flash-attn.
    query_start_loc: torch.Tensor
    max_seq_len: int
    # Per-request KV lengths; passed as seqused_k to flash-attn.
    seq_lens: torch.Tensor
    # Per-request physical block ids into the paged KV cache.
    block_table: torch.Tensor
    # Flat cache-slot index per token, used by reshape_and_cache_flash.
    slot_mapping: torch.Tensor

    # For cascade attention. The Optional fields are only consumed when
    # use_cascade is True.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
+
+
class FlashAttentionImpl(AttentionImpl):
    """Decoder-only attention implementation backed by vllm_flash_attn."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and select a flash-attn version.

        Raises:
            ValueError: if block-sparse params are given or ``head_size``
                is not supported by FlashAttention.
            NotImplementedError: for any ``attn_type`` other than DECODER.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # flash-attn takes a (left, right) window; (-1, -1) disables it.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        # use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            self.fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            self.fa_version = 2

        # An explicit env override wins over the auto-detected default.
        if VLLM_FLASH_ATTN_VERSION is not None:
            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
            self.fa_version = VLLM_FLASH_ATTN_VERSION

        if not is_fa_version_supported(self.fa_version):
            # BUGFIX: the original message was garbled English ("Cannot use
            # FA version %d is not supported due to %s").
            logger.error(
                "Cannot use FA version %d: it is not supported. Reason: %s",
                self.fa_version,
                fa_version_unsupported_reason(self.fa_version))

        assert is_fa_version_supported(self.fa_version)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` indicates a
                profiling run; the (preallocated) output is returned as-is.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
            # Regular attention (common case).
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                seqused_k=attn_metadata.seq_lens,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                fa_version=self.fa_version,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.fa_version,
        )
        return output
+
+
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention for this batch.

    Returns True only when cascade attention is both supported for the
    given configuration and heuristically expected to be faster than the
    regular attention path.
    """
    # Fast bail-outs, cheapest first. A short common prefix (arbitrary
    # 256-token threshold, TODO: tune) is the common case, so it is checked
    # before anything else. ALiBi and sliding window are unsupported by the
    # cascade kernels, and very small batches (< 8 requests, TODO: tune)
    # gain too little to bother.
    if common_prefix_len < 256:
        return False
    if use_alibi or use_sliding_window:
        return False
    batch_size = len(query_lens)
    if batch_size < 8:
        return False

    # If the regular path would NOT use FlashDecoding, cascade attention is
    # likely faster since it saves memory bandwidth. FlashDecoding criteria:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    heads_per_kv = num_query_heads // num_kv_heads
    would_flash_decode = (heads_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not would_flash_decode:
        return True

    # Otherwise compare the two with a rough CTA/wave performance model.
    # NOTE(woosuk): The model is very rough and may not be accurate.
    # 128 is the default flash-attn tile size; the kernel may actually pick
    # different tiles (e.g., 64 or 256) depending on the configuration.
    q_tile = 128
    kv_tile = 128
    prefix_tiles = cdiv(common_prefix_len, kv_tile)

    cascade_waves = cdiv(num_query_heads * cdiv(batch_size, q_tile), num_sms)
    cascade_time = cascade_waves * prefix_tiles

    decode_ctas = (batch_size * num_kv_heads *
                   cdiv(heads_per_kv, q_tile)) * prefix_tiles
    decode_time = cdiv(decode_ctas, num_sms)

    # Use cascade attention only when the model predicts it is faster.
    return cascade_time < decode_time
+
+
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: Tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
) -> None:
    """Compute attention for a batch of queries sharing a common KV prefix.

    Runs flash-attn twice: once over the shared prefix (non-causal) and once
    over each request's own suffix (causal), then merges the two partial
    results into ``output`` using their log-sum-exp values. The result is
    written in place; nothing is returned.

    Requires ``common_prefix_len`` to be a positive multiple of the KV block
    size; ALiBi and sliding window are not supported.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    # Block size sits third-from-last in the paged cache layout
    # [2, num_blocks, block_size, num_kv_heads, head_size].
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix.
    # NOTE(review): block_table[:1] passes only the first request's row —
    # presumably all requests share the prefix blocks; confirm with caller.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
    )

    # Process suffix per query. The shared blocks are sliced off the block
    # table so each request attends only to its own KV.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
+
+
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
) -> None:
    """Merge prefix and suffix partial attention results into ``output``.

    Launches one Triton program per (token, head) pair; the kernel combines
    the two partial outputs using their log-sum-exp values.
    """
    n_tokens = output.shape[0]
    n_heads = output.shape[1]
    head_dim = output.shape[2]
    # Triton needs a power-of-2 inner extent; the kernel masks the padding.
    padded_head_dim = triton.next_power_of_2(head_dim)

    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    grid = (n_tokens, n_heads)
    merge_attn_states_kernel[grid](output, prefix_output, prefix_lse,
                                   suffix_output, suffix_lse, head_dim,
                                   padded_head_dim)
+
+
# Merge two partial attention results using their log-sum-exp (LSE) values.
# One program instance handles one (token, head) pair.
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    # LSE tensors are laid out [NUM_HEADS, NUM_TOKENS].
    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
    # Subtract the max before exponentiating for numerical stability.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse

    # PADDED_HEAD_SIZE is the next power of two >= HEAD_SIZE; mask off the
    # padding lanes on every load/store.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    # PERF: hoist the exponentials so each is computed once (the original
    # evaluated tl.exp(p_lse) and tl.exp(s_lse) twice each); math-identical.
    p_exp = tl.exp(p_lse)
    s_exp = tl.exp(s_lse)
    denom = p_exp + s_exp
    p_scale = p_exp / denom
    s_scale = s_exp / denom
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae72f748fce6398f894b583a84bce5aebe4b532a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0c75b9c22950cfac9e9ec452c27717728366017
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e49d1dfd6d5b0f63c89be37bfc352f6994178ff
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e4da6e587877383733293274bf661ebde4dfb82
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a12e04dd4c53298069c8cb84b724ee46b584620
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..651bc01aa5cf665c46bb0def62694499ae79d793
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, Dict, List, Set, Tuple
+
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.v1.request import Request
+
+if TYPE_CHECKING:
+ from vllm.config import ModelConfig, SchedulerConfig
+
+logger = init_logger(__name__)
+
+
class EncoderCacheManager:
    """Tracks which encoder inputs are cached, in a pool of token slots.

    Capacity is measured in encoder tokens: each (request, input_id) entry
    consumes request.get_num_encoder_tokens(input_id) slots while cached.
    """

    def __init__(self, cache_size: int):
        # Total capacity of the cache, in encoder tokens.
        self.cache_size = cache_size
        # Remaining capacity; decremented on allocate, incremented on free.
        self.num_free_slots = cache_size
        # req_id -> cached input ids
        self.cached: Dict[str, Set[int]] = {}
        # List of [req_id, input_id] pairs freed since last get_freed_ids().
        self.freed: List[Tuple[str, int]] = []

    def has_cache(self, request: "Request", input_id: int) -> bool:
        """Return True if this encoder input is currently cached."""
        req_id = request.request_id
        return req_id in self.cached and input_id in self.cached[req_id]

    def can_allocate(self, request: "Request", input_id: int) -> bool:
        """Return True if there is room to cache this encoder input."""
        num_tokens = request.get_num_encoder_tokens(input_id)
        return num_tokens <= self.num_free_slots

    def allocate(self, request: "Request", input_id: int) -> None:
        """Reserve slots for this encoder input.

        Does not validate capacity; callers should check can_allocate()
        first.
        """
        req_id = request.request_id
        if req_id not in self.cached:
            self.cached[req_id] = set()
        self.cached[req_id].add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: "Request") -> Set[int]:
        """Return the set of cached encoder input ids for the request."""
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: "Request", input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return

        self.cached[req_id].discard(input_id)
        if len(self.cached[req_id]) == 0:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: "Request") -> None:
        """Free all cached input ids for the request."""
        # BUGFIX: get_cached_input_ids returns the *live* set, and
        # free_encoder_input mutates that same set; copy it first so we do
        # not raise "Set changed size during iteration".
        input_ids = list(self.get_cached_input_ids(request))
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> List[Tuple[str, int]]:
        """Return and clear the (req_id, input_id) pairs freed so far."""
        freed = self.freed
        self.freed = []
        return freed
+
+
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder budgets from the model and scheduler configs.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        A (compute_budget, cache_size) pair, both measured in number of
        tokens in the input sequence. Text-only models get (0, 0).
    """
    if not model_config.is_multimodal_model:
        # No encoder inputs at all, so no budget is needed.
        return 0, 0

    # TODO: handle encoder-decoder models once we support them.
    return _compute_encoder_budget_multimodal(model_config, scheduler_config)
+
+
def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder budgets for a multimodal model.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        A (compute_budget, cache_size) pair, both measured in number of
        tokens in the input sequence; (0, 0) when every non-text modality
        has been disabled.
    """
    max_tokens_by_modality = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
        model_config)

    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    # Both budgets must fit at least one instance of the largest
    # multimodal item, whichever modality it comes from.
    largest_item_tokens = max(max_tokens_by_modality.values())

    compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                         largest_item_tokens)
    cache_size = max(scheduler_config.encoder_cache_size, largest_item_tokens)

    return compute_budget, cache_size
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..de349ec12099931b81729d45274454e6b6f73c27
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py
@@ -0,0 +1,500 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import defaultdict
+from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv
+from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
+ KVCacheBlock,
+ generate_block_hash_extra_keys,
+ hash_block_tokens,
+ hash_request_tokens)
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class KVCacheManager:
+
+ def __init__(
+ self,
+ block_size: int,
+ num_gpu_blocks: int,
+ max_model_len: int,
+ sliding_window: Optional[int] = None,
+ enable_caching: bool = True,
+ num_preallocate_tokens: int = 64,
+ ) -> None:
+ self.block_size = block_size
+ self.num_gpu_blocks = num_gpu_blocks
+ self.max_model_len = max_model_len
+ self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
+ self.sliding_window = sliding_window
+ self.enable_caching = enable_caching
+ # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
+ # blocks for each request. For example, when a request reaches the end
+ # of its block table, we preallocate N blocks in advance. This way, we
+ # reduce the overhead of updating free_block_ids and ref_cnts for each
+ # request every step (at the cost of some memory waste).
+ # NOTE(woosuk): This is different from the "lookahead" slots since this
+ # does not guarantee that the request always has N empty blocks. After
+ # the request gets N empty blocks, it starts to use the blocks without
+ # further allocation. When it uses up all the N empty blocks, it gets
+ # N new empty blocks.
+ self.num_preallocate_tokens = num_preallocate_tokens
+ self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
+
+ # A Block pool of all kv-cache blocks.
+ self.block_pool: List[KVCacheBlock] = [
+ KVCacheBlock(idx) for idx in range(num_gpu_blocks)
+ ]
+ # Free block queue that constructs and manipulates a doubly linked
+ # list of free blocks (including eviction candidates when caching is
+ # enabled).
+ self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
+
+ # {block_hash: {block ID: block}}. A cached block is
+ # a full block with a block hash that can be used for prefix caching.
+ # The cached block may be used by running requests or in the
+ # free_block_queue that could potentially be evicted.
+ # NOTE: We currently don't de-duplicate the blocks in the cache,
+ # meaning that if a block becomes full and is cached, we don't check
+ # if there is already an identical block in the cache. This is because
+ # we want to make sure the allocated block IDs won't change so that
+ # block tables are append-only.
+ self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
+ int, KVCacheBlock]] = defaultdict(dict)
+
+ # Mapping from request ID to blocks to track the blocks allocated
+ # for each request, so that we can free the blocks when the request
+ # is finished.
+ self.req_to_blocks: DefaultDict[str,
+ List[KVCacheBlock]] = defaultdict(list)
+
+ @property
+ def usage(self) -> float:
+ return 1.0 - (self.free_block_queue.num_free_blocks /
+ self.num_gpu_blocks)
+
+ def get_computed_blocks(
+ self, request: Request) -> Tuple[List[KVCacheBlock], int]:
+ """Get the computed (cached) blocks for the request.
+ Note that the computed blocks must be full.
+
+ Args:
+ request: The request to get the computed blocks.
+
+ Returns:
+ A tuple containing:
+ - A list of blocks that are computed for the request.
+ - The number of computed tokens.
+ """
+ if not self.enable_caching:
+ # Prefix caching is disabled.
+ return [], 0
+
+ computed_blocks = []
+
+ # The block hashes for the request may already be computed
+ # if the request was preempted and resumed.
+ if not request.kv_block_hashes:
+ request.set_kv_block_hashes(
+ hash_request_tokens(self.block_size, request))
+ block_hashes = request.kv_block_hashes
+
+ for block_hash in block_hashes:
+ # block_hashes is a chain of block hashes. If a block hash is not
+ # in the cached_block_hash_to_id, the following block hashes are
+ # not computed yet for sure.
+ if cached_block := self._get_cached_block(block_hash):
+ computed_blocks.append(cached_block)
+ else:
+ break
+
+ # NOTE(woosuk): Since incomplete blocks are not eligible for
+ # sharing, `num_computed_tokens` is always a multiple of
+ # `block_size`.
+ num_computed_tokens = len(computed_blocks) * self.block_size
+ return computed_blocks, num_computed_tokens
+
+ def allocate_slots(
+ self,
+ request: Request,
+ num_tokens: int,
+ new_computed_blocks: Optional[List[KVCacheBlock]] = None
+ ) -> Optional[List[KVCacheBlock]]:
+ """Add slots for a request with new tokens to append.
+
+ Args:
+ request: The request to allocate slots.
+ num_tokens: The number of tokens to allocate. Note that this does
+ not include the tokens that have already been computed.
+ new_computed_blocks: A list of new computed blocks just hitting the
+ prefix caching.
+
+ Blocks layout:
+ -----------------------------------------------------------------------
+ | < computed > | < new computed > | < new > | < pre-allocated > |
+ -----------------------------------------------------------------------
+ | < required > |
+ --------------------------------------------------
+ | < full > |
+ ------------------------------------------------
+ | |
+ --------------
+ The following *_blocks are illustrated in this layout.
+
+ Returns:
+ A list of new allocated blocks.
+ """
+ if num_tokens == 0:
+ raise ValueError("num_tokens must be greater than 0")
+
+ new_computed_blocks = new_computed_blocks or []
+
+ # The number of computed tokens is the number of computed tokens plus
+ # the new prefix caching hits
+ num_computed_tokens = (request.num_computed_tokens +
+ len(new_computed_blocks) * self.block_size)
+ num_required_blocks = cdiv(num_computed_tokens + num_tokens,
+ self.block_size)
+ req_blocks = self.req_to_blocks[request.request_id]
+ num_new_blocks = (num_required_blocks - len(req_blocks) -
+ len(new_computed_blocks))
+
+ # If a computed block of a request is an eviction candidate (in the
+ # free queue and ref_cnt == 0), it cannot be counted as a free block
+ # when allocating this request.
+ num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
+ if blk.ref_cnt == 0)
+ if (num_new_blocks > self.free_block_queue.num_free_blocks -
+ num_evictable_computed_blocks):
+ # Cannot allocate new blocks
+ return None
+
+ # Touch the computed blocks to make sure they won't be evicted.
+ if self.enable_caching:
+ self._touch(new_computed_blocks)
+ else:
+ assert not new_computed_blocks, (
+ "Computed blocks should be empty when "
+ "prefix caching is disabled")
+
+ # Append the new computed blocks to the request blocks until now to
+ # avoid the case where the new blocks cannot be allocated.
+ req_blocks.extend(new_computed_blocks)
+
+ # Start to handle new blocks
+
+ if num_new_blocks <= 0:
+ # No new block is needed.
+ new_blocks = []
+ else:
+ # Get new blocks from the free block pool considering
+ # preallocated blocks.
+ num_new_blocks = min(
+ num_new_blocks + self.num_preallocate_blocks,
+ self.free_block_queue.num_free_blocks,
+ # Should not exceed the maximum number of blocks per request.
+ # This is especially because the block table has the shape
+ # [..., max_num_blocks_per_req].
+ # TODO(woosuk): Check and reject requests if
+ # num_prompt_tokens + max_tokens > max_model_len.
+ self.max_num_blocks_per_req - len(req_blocks),
+ )
+ assert num_new_blocks > 0
+
+ # Concatenate the computed block IDs and the new block IDs.
+ new_blocks = self._get_new_blocks(num_new_blocks)
+ req_blocks.extend(new_blocks)
+
+ if not self.enable_caching:
+ return new_blocks
+
+ # NOTE(rickyx): We are assuming the `num_tokens` are actual
+ # tokens rather than lookahead slots (e.g. for speculative decoding).
+ # TODO(rickyx): When supporting speculative decoding, we will need to
+ # differentiate between them so that we can know how many blocks are
+ # full after appending the actual tokens.
+ num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
+ num_computed_full_blocks = num_computed_tokens // self.block_size
+ new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
+ if new_full_blocks:
+ self._cache_full_blocks(
+ request=request,
+ blk_start_idx=num_computed_full_blocks,
+ # The new full blocks are the full blocks that are not computed.
+ full_blocks=new_full_blocks,
+ prev_block=(req_blocks[num_computed_full_blocks - 1]
+ if num_computed_full_blocks > 0 else None))
+
+ return new_blocks
+
+ def free(self, request: Request) -> None:
+ """Free the blocks allocated for the request.
+ When caching is enabled, we free the blocks in reverse order so that
+ the tail blocks are evicted first.
+
+ Args:
+ request: The request to free the blocks.
+ """
+ # Default to [] in case a request is freed (aborted) before alloc.
+ blocks = self.req_to_blocks.pop(request.request_id, [])
+ ordered_blocks: Iterable[KVCacheBlock] = blocks
+ if self.enable_caching:
+ # Free blocks in reverse order so that the tail blocks are
+ # freed first.
+ ordered_blocks = reversed(blocks)
+
+ for block in ordered_blocks:
+ block.decr_ref()
+ if block.ref_cnt == 0:
+ self.free_block_queue.append(block)
+
+ def reset_prefix_cache(self) -> bool:
+ """Reset prefix cache. This function may be used in RLHF
+ flows to invalid prefix caching after the weights are updated,
+ or used for resetting prefix caching status for benchmarking.
+
+ Returns:
+ bool: True if the prefix cache is successfully reset,
+ False otherwise.
+ """
+ num_used_blocks = (self.num_gpu_blocks -
+ self.free_block_queue.num_free_blocks)
+ if num_used_blocks > 0:
+ logger.warning(
+ "Failed to reset prefix cache because some "
+ "blocks (%d) are not freed yet", num_used_blocks)
+ return False
+
+ # Remove all hashes so that no new blocks will hit.
+ self.cached_block_hash_to_block = defaultdict(dict)
+
+ # Remove all hashes from all blocks.
+ for block in self.block_pool:
+ block.reset_hash()
+
+ logger.info("Successfully reset prefix cache")
+ return True
+
+ def get_num_common_prefix_blocks(
+ self,
+ request: Request,
+ num_running_requests: int,
+ ) -> int:
+ """Calculate the number of common prefix blocks shared by all requests
+ in the RUNNING state.
+
+ The function determines this by selecting any request and iterating
+ through its blocks. A block is considered a common prefix block if its
+ `ref_cnt` equals the total number of requests in the RUNNING state.
+
+ NOTE(woosuk): The number of requests in the RUNNING state is **greater
+ than or equal to** the number of requests scheduled in the current step.
+ This is because the RUNNING state only indicates that:
+ 1. The request has not yet finished, and
+ 2. The request holds its blocks unfreed.
+
+ While all scheduled requests must be in the RUNNING state, the inverse
+ is not necessarily true. There may be RUNNING requests that are not
+ scheduled in the current step. As of 1/1/2025, the scheduler does not
+ allow this case, but it is possible in the future, as we allow more
+ flexible scheduling.
+
+ This can result in an edge case where the number of common prefix blocks
+ is 0, even though all scheduled requests share a common prefix. This
+ occurs because there may be unscheduled RUNNING requests that do not
+ share the common prefix. Currently, this case cannot be easily detected,
+ so the function returns 0 in such cases.
+
+ Args:
+ request: Any request in the RUNNING state, used to identify the
+ common prefix blocks.
+ num_running_requests: The total number of requests in the RUNNING
+ state. This can be different from the number of scheduled
+ requests in the current step.
+
+ Returns:
+ int: The number of common prefix blocks.
+ """
+ assert request.status == RequestStatus.RUNNING
+ blocks = self.req_to_blocks[request.request_id]
+ num_common_blocks = 0
+ for block in blocks:
+ if block.ref_cnt == num_running_requests:
+ num_common_blocks += 1
+ else:
+ break
+ return num_common_blocks
+
+ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
+ """Get new blocks from the free block pool.
+
+ Note that we do not check block cache in this function.
+
+ Args:
+ num_blocks: The number of blocks to allocate.
+
+ Returns:
+ A list of new block.
+ """
+ if num_blocks > self.free_block_queue.num_free_blocks:
+ raise ValueError(
+ f"Cannot get {num_blocks} free blocks from the pool")
+
+ ret: List[KVCacheBlock] = []
+ idx = 0
+ while idx < num_blocks:
+ # First allocate blocks.
+ curr_block = self.free_block_queue.popleft()
+ assert curr_block.ref_cnt == 0
+
+ # If the block is cached, evict it.
+ if self.enable_caching:
+ self._maybe_evict_cached_block(curr_block)
+
+ curr_block.incr_ref()
+ ret.append(curr_block)
+ idx += 1
+
+ return ret
+
+ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
+ """
+ If a block is cached in `cached_block_hash_to_block`, we reset its hash
+ metadata and evict it from the cache.
+
+ Args:
+ block: The block to evict.
+
+ Returns:
+ True if the block is evicted, False otherwise.
+ """
+ block_hash = block.block_hash
+ if block_hash and block_hash in self.cached_block_hash_to_block:
+ block.reset_hash()
+ del self.cached_block_hash_to_block[block_hash][block.block_id]
+
+ if len(self.cached_block_hash_to_block[block_hash]) == 0:
+ del self.cached_block_hash_to_block[block_hash]
+
+ return True
+ return False
+
+ def _get_cached_block(self,
+ block_hash: BlockHashType) -> Optional[KVCacheBlock]:
+ """Get a cached block by the block hash, or None if cache miss.
+ If there are duplicated blocks, we return the first block in the cache.
+
+ Args:
+ block_hash: The hash value of the block.
+
+ Returns:
+ The cached block if it exists, or None.
+ """
+ if block_hash in self.cached_block_hash_to_block:
+ first_block_id = list(
+ self.cached_block_hash_to_block[block_hash].keys())[0]
+ return self.cached_block_hash_to_block[block_hash][first_block_id]
+ return None
+
+ def _touch(self, blocks: List[KVCacheBlock]) -> None:
+ """Touch a block increases its reference count by 1, and may remove
+ the block from the free queue. This is used when a block is hit by
+ another request with the same prefix.
+
+ Args:
+ blocks: A list of blocks to touch.
+ """
+ for block in blocks:
+ # ref_cnt=0 means this block is in the free list (i.e. eviction
+ # candidate), so remove it.
+ if block.ref_cnt == 0:
+ self.free_block_queue.remove(block)
+ block.incr_ref()
+
+ def _cache_full_blocks(
+ self,
+ request: Request,
+ blk_start_idx: int,
+ full_blocks: List[KVCacheBlock],
+ prev_block: Optional[KVCacheBlock],
+ ) -> None:
+ """Cache a list of full blocks for prefix caching.
+
+ This function takes a list of blocks that will have their block hash
+ metadata to be updated and cached. Given a request, it computes the
+ block hashes for the blocks starting from `blk_start_idx` to the end
+ of the request's full blocks, updating the metadata for each block
+ and caching them in the `cached_block_hash_to_block`.
+
+ Args:
+ request: The request to cache the blocks.
+ blk_start_idx: The index of the first block in the request's blocks
+ to cache.
+ full_blocks: The list of blocks to update hash metadata.
+ prev_block: The previous block in the chain.
+ """
+ num_cached_block_hashes = len(request.kv_block_hashes)
+
+ # Update the new blocks with the block hashes through the chain.
+ prev_block_hash_value = None
+ if prev_block is not None:
+ # Previous block must have a block hash because it must be
+ # a full, cached block.
+ assert prev_block.block_hash is not None
+ prev_block_hash_value = prev_block.block_hash.hash_value
+
+ # Find the first uncached block. This case should only happen when
+ # speculative decoding is used.
+ offset = 0
+ for blk in full_blocks:
+ if blk.block_hash is None:
+ break
+ else:
+ prev_block_hash_value = blk.block_hash.hash_value
+ offset += 1
+ else:
+ # All blocks are cached.
+ return
+
+ for i, blk in enumerate(full_blocks[offset:]):
+ blk_idx = blk_start_idx + offset + i
+ assert blk.block_hash is None
+
+ if blk_idx < num_cached_block_hashes:
+ # The block hash may already be computed in
+ # "get_computed_blocks" if the tokens are not generated by
+ # this request (either the prompt tokens or the previously
+ # generated tokens with preemption). In this case we simply
+ # reuse the block hash.
+ block_hash = request.kv_block_hashes[blk_idx]
+ else:
+ # Otherwise compute the block hash and cache it in the request
+ # in case it will be preempted in the future.
+ start_token_idx = blk_idx * self.block_size
+ end_token_idx = (blk_idx + 1) * self.block_size
+ block_tokens = request.all_token_ids[
+ start_token_idx:end_token_idx]
+ assert len(block_tokens) == self.block_size, (
+ f"Expected {self.block_size} tokens, got "
+ f"{len(block_tokens)} at {blk_idx}th block for request "
+ f"{request.request_id}({request})")
+
+ # Generate extra keys for multi-modal inputs. Note that since
+ # we reach to this branch only when the block is completed with
+ # generated tokens, we only need to consider the last mm input.
+ extra_keys, _ = generate_block_hash_extra_keys(
+ request, start_token_idx, end_token_idx, -1)
+
+ # Compute the hash of the current block.
+ block_hash = hash_block_tokens(prev_block_hash_value,
+ block_tokens, extra_keys)
+ request.append_kv_block_hashes(block_hash)
+
+ # Update and added the full block to the cache.
+ blk.block_hash = block_hash
+ self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
+ prev_block_hash_value = block_hash.hash_value
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0976ba8577b9ac4e02037c79d7e0914a6fc9675
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py
@@ -0,0 +1,447 @@
+# SPDX-License-Identifier: Apache-2.0
+"""KV-Cache Utilities."""
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any, List, NamedTuple, Optional, Tuple
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
+ KVCacheTensor)
+from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class BlockHashType(NamedTuple):
+ """Hash value of a block (int), the token IDs in the block, and extra keys.
+ We keep a tuple of token IDs and extra keys to reduce the likelihood of
+ hash collisions when the hash value is the same. But please note that
+ hash collisions can still theoretically occur, albeit with an extremely
+ low probability.
+ """
+ # Hash value of the block in an integer.
+ hash_value: int
+ # Token IDs in the block.
+ token_ids: Tuple[int, ...]
+ # Extra keys for the block.
+ extra_keys: Optional[Any] = None
+
+
+@dataclass
+class KVCacheBlock:
+ """KV-cache block metadata."""
+ # Block ID, ranging from 0 to num_gpu_blocks - 1.
+ block_id: int
+ # Reference count.
+ ref_cnt: int = 0
+ # The hash of the block composed of (block hash, tuple of token IDs).
+ # It is only available when the block is full.
+ _block_hash: Optional[BlockHashType] = None
+
+ # Used to construct a doubly linked list for free blocks.
+ # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
+ prev_free_block: Optional["KVCacheBlock"] = None
+ next_free_block: Optional["KVCacheBlock"] = None
+
+ def incr_ref(self):
+ self.ref_cnt += 1
+
+ def decr_ref(self):
+ self.ref_cnt -= 1
+
+ @property
+ def block_hash(self) -> Optional[BlockHashType]:
+ return self._block_hash
+
+ @block_hash.setter
+ def block_hash(self, block_hash: BlockHashType):
+ assert self.block_hash is None, (
+ "The block already has a hash. This should not happen.")
+ self._block_hash = block_hash
+
+ def reset_hash(self):
+ """Reset the block hash when the block is evicted."""
+ self._block_hash = None
+
+
+class FreeKVCacheBlockQueue:
+ """This class organizes a list of KVCacheBlock objects to a doubly linked
+ list of free blocks. We implement this class instead of using Python
+ builtin deque to support removing a block in the middle of the queue
+ in O(1) time. To close the performance gap to the builtin deque which is
+ implemented in C++, this class does not allocate any Python objects when
+ manipulating the linked list. Instead, this class manipulates the
+ prev_free_block and next_free_block attributes of the given blocks.
+
+ The queue is ordered by block ID in the beginning. When a block is allocated
+ and then freed, it will be appended back with the eviction order:
+ 1. The least recent used block is at the front (LRU).
+ 2. If two blocks have the same last accessed time (allocated by the
+ same sequence), the one with more hash tokens (the tail of a block
+ chain) is at the front.
+ Note that we maintain this order by reversing the block order when free
+ blocks of a request. This operation is outside of this class.
+
+ Args:
+ blocks: A list of KVCacheBlock objects.
+ """
+
+ def __init__(self, blocks: List[KVCacheBlock]) -> None:
+ self.num_free_blocks = len(blocks)
+
+ # Initialize the doubly linked list of free blocks.
+ self.free_list_head: Optional[KVCacheBlock] = blocks[0]
+ self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
+ for i in range(self.num_free_blocks):
+ if i > 0:
+ blocks[i].prev_free_block = blocks[i - 1]
+ if i < self.num_free_blocks - 1:
+ blocks[i].next_free_block = blocks[i + 1]
+
+ def popleft(self) -> KVCacheBlock:
+ """Pop the first free block and reduce num_free_blocks by 1.
+
+ Returns:
+ The first free block.
+ """
+ if not self.free_list_head:
+ raise ValueError("No free blocks available")
+
+ block = self.free_list_head
+ self.remove(block)
+ return block
+
+ def remove(self, block: KVCacheBlock) -> None:
+ """Remove a block in the free list and reduce num_free_blocks by 1.
+
+ Args:
+ block: The block to remove.
+ """
+ if block.prev_free_block is not None:
+ # Link the previous block to the next block.
+ block.prev_free_block.next_free_block = block.next_free_block
+ if block.next_free_block is not None:
+ # Link the next block to the previous block.
+ block.next_free_block.prev_free_block = block.prev_free_block
+
+ if block == self.free_list_head:
+ # Update the head if the block is the head.
+ self.free_list_head = block.next_free_block
+ if block == self.free_list_tail:
+ # Update the tail if the block is the tail.
+ self.free_list_tail = block.prev_free_block
+
+ # Remove the block from the linked list.
+ block.prev_free_block = block.next_free_block = None
+ self.num_free_blocks -= 1
+
+ def append(self, block: KVCacheBlock) -> None:
+ """Put a block back into the free list and increase
+ num_free_blocks by 1.
+
+ Args:
+ block: The block to append.
+ """
+ if self.free_list_tail is not None:
+ # Link the last block to the new block.
+ self.free_list_tail.next_free_block = block
+ block.prev_free_block = self.free_list_tail
+ self.free_list_tail = block
+ else:
+ # The free list is empty.
+ assert self.free_list_head is None
+ self.free_list_head = self.free_list_tail = block
+
+ block.next_free_block = None
+ self.num_free_blocks += 1
+
+ def get_all_free_blocks(self) -> List[KVCacheBlock]:
+ """Get all free blocks in the free list. Mainly used for testing.
+
+ Returns:
+ A list of free blocks.
+ """
+ ret = []
+ curr_block = self.free_list_head
+ while curr_block is not None:
+ ret.append(curr_block)
+ curr_block = curr_block.next_free_block
+ return ret
+
+
+def generate_block_hash_extra_keys(
+ request: Request, start_token_idx: int, end_token_idx: int,
+ start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
+ """Generate extra keys for the block hash. The extra keys can come from
+ the multi-modal inputs and request specific metadata (e.g., LoRA ID).
+ For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
+ indicate a mm input contained in the block and its starting offset in
+ the block tokens.
+
+ Args:
+ request: The request object.
+ start_token_idx: The start token index of the block.
+ end_token_idx: The end token index of the block.
+ start_mm_idx: The start multi-modal index of the block.
+
+ Returns:
+ A tuple of extra keys and the next multi-modal index.
+ """
+
+ mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
+ if not mm_positions:
+ return None, start_mm_idx
+
+ if mm_positions and len(mm_positions) != len(mm_hashes):
+ raise ValueError(
+ "The number of multi-modal positions and hashes must match. This "
+ "is likely because you do not enable MM preprocessor hashing. "
+ "Please set disable_mm_preprocessor_cache=False.")
+
+ # Note that we assume mm_positions is sorted by offset.
+ # We do not need to check all mm inputs if the start token index is out of
+ # range. This usually happens in the late prefill phase and decoding phase.
+ if mm_positions[-1]["offset"] + mm_positions[-1][
+ "length"] < start_token_idx:
+ return None, start_mm_idx
+
+ # Support start_mm_idx == -1 to indicate the last mm input.
+ if start_mm_idx < 0:
+ assert -start_mm_idx <= len(mm_positions)
+ start_mm_idx = len(mm_positions) + start_mm_idx
+
+ extra_keys = []
+ curr_mm_idx = start_mm_idx
+ while mm_positions and curr_mm_idx < len(mm_positions):
+ assert mm_hashes[curr_mm_idx] is not None
+ offset = mm_positions[curr_mm_idx]["offset"]
+ length = mm_positions[curr_mm_idx]["length"]
+ if end_token_idx > offset:
+ if start_token_idx > offset + length:
+ # This block has passed the current mm input.
+ curr_mm_idx += 1
+ continue
+
+ # The block contains the current mm input.
+ extra_keys.append(mm_hashes[curr_mm_idx])
+
+ if end_token_idx >= offset + length:
+ # If this block contains the end of the current mm input,
+ # move to the next mm input as this block may also contain
+ # the next mm input.
+ curr_mm_idx += 1
+ else:
+ # Otherwise this block is done with mm inputs.
+ break
+ else:
+ # This block has not reached the current mm input.
+ break
+ return tuple(extra_keys), curr_mm_idx
+
+
+def hash_block_tokens(
+ parent_block_hash: Optional[int],
+ curr_block_token_ids: Sequence[int],
+ extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
+ """Computes a hash value corresponding to the contents of a block and
+ the contents of the preceding block(s). The hash value is used for
+ prefix caching. We use LRU cache for this function to avoid recomputing
+ hash values for the same block contents.
+
+ TODO: Support arbitrary metadata so that we could support more
+ features such as LoRA adapter.
+
+ Args:
+ parent_block_hash: The hash of the parent block. None
+ if this is the first block.
+ curr_block_token_ids: A list of token ids in the current
+ block. The current block is assumed to be full.
+ extra_keys: Extra keys for the block.
+
+ Returns:
+ The hash value of the block and the token ids in the block.
+ The entire tuple is used as the hash key of the block.
+ """
+ if not parent_block_hash:
+ # Note that we use 'None' as a string here instead of None because
+ # as of Python 3.12, hash(None) returns a constant predictable value.
+ # This could possibly make it easier to find and exploit hash
+ # collisions. 'None' as a string will be hashed differently per process,
+ # but consistently within the same process. This is the same as the
+ # behavior of None prior to Python 3.12.
+ parent_block_hash = hash('None')
+
+ curr_block_token_ids_tuple = tuple(curr_block_token_ids)
+ return BlockHashType(
+ hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
+ curr_block_token_ids_tuple, extra_keys)
+
+
+def hash_request_tokens(block_size: int,
+ request: Request) -> List[BlockHashType]:
+ """Computes hash values of a chain of blocks given a sequence of
+ token IDs. The hash value is used for prefix caching.
+
+ Args:
+ block_size: The size of each block.
+ request: The request object.
+
+ Returns:
+ The list of computed hash values.
+ """
+ token_ids = request.all_token_ids
+ mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
+ if mm_positions and len(mm_positions) != len(mm_hashes):
+ raise ValueError(
+ "The number of multi-modal positions and hashes must match.")
+
+ # TODO: Extend this to support other features such as LoRA.
+ need_extra_keys = bool(mm_positions)
+ extra_keys = None
+ curr_mm_idx = 0
+
+ ret = []
+ parent_block_hash_value = None
+ for start in range(0, len(token_ids), block_size):
+ end = start + block_size
+ block_token_ids = token_ids[start:end]
+ # Do not hash the block if it is not full.
+ if len(block_token_ids) < block_size:
+ break
+
+ # Add extra keys if the block is a multi-modal block.
+ if need_extra_keys:
+ extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+ request, start, end, curr_mm_idx)
+
+ block_hash = hash_block_tokens(parent_block_hash_value,
+ block_token_ids, extra_keys)
+ ret.append(block_hash)
+ parent_block_hash_value = block_hash.hash_value
+ return ret
+
+
+def check_enough_kv_cache_memory(vllm_config: VllmConfig,
+ kv_cache_spec: KVCacheSpec,
+ available_memory: int):
+ """
+ Checks whether `available_memory` is enough for the KV cache to hold at
+ least one request with the model's max_model_len.
+
+ Args:
+ vllm_config: The global VllmConfig
+ kv_cache_spec: The kv cache spec of the model
+ available_memory: Memory available for KV cache in bytes.
+
+ Raises:
+ ValueError: If there is not enough memory available for the KV cache.
+ """
+
+ if available_memory <= 0:
+ raise ValueError("No available memory for the cache blocks. "
+ "Try increasing `gpu_memory_utilization` when "
+ "initializing the engine.")
+
+ max_model_len = vllm_config.model_config.max_model_len
+ needed_memory = 0
+ for layer_spec in kv_cache_spec.values():
+ needed_memory += layer_spec.bytes_for_tokens(max_model_len)
+
+ if needed_memory > available_memory:
+ raise ValueError(
+ f"To serve at least one request with the models's max seq len "
+ f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
+ f"cache is needed, which is larger than the available KV cache "
+ f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
+ f"increasing `gpu_memory_utilization` or decreasing "
+ f"`max_model_len` when initializing the engine.")
+
+
+def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
+ """
+ Whether all layers in the given KVCacheSpec have the same type of KV cache.
+
+ Args:
+ kv_cache_spec: The KVCacheSpec of the model
+
+ Returns:
+ True if all layers have the same type, False otherwise.
+ """
+
+ layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
+ return len(layer_keys) == 1
+
+
+def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
+ kv_cache_spec: KVCacheSpec,
+ available_memory: int) -> KVCacheConfig:
+ """
+ Generates the KV cache configuration for a model with one type of KV cache.
+ Divide the available memory equally among all layers.
+
+ Args:
+ vllm_config: The global VllmConfig
+ kv_cache_spec: The kv cache spec of the model
+ available_memory: Memory available for KV cache in bytes.
+
+ Returns:
+ The generated KVCacheConfig
+ """
+
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+ assert len(page_sizes) == 1
+ page_size = page_sizes.pop()
+
+ num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+ num_blocks = max(num_blocks, 0)
+
+ if vllm_config.cache_config.num_gpu_blocks_override is not None:
+ num_gpu_blocks_override = \
+ vllm_config.cache_config.num_gpu_blocks_override
+ logger.info(
+ "Overriding num_gpu_blocks=%d with "
+ "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
+ num_blocks = num_gpu_blocks_override
+
+ logger.info("# GPU blocks: %d", num_blocks)
+ max_concurrency = (num_blocks * vllm_config.cache_config.block_size /
+ vllm_config.model_config.max_model_len)
+ logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+ vllm_config.model_config.max_model_len, max_concurrency)
+
+ per_layer_size = page_size * num_blocks
+
+ kv_cache_config = KVCacheConfig(
+ num_blocks=num_blocks,
+ tensors={
+ layer_name: KVCacheTensor(size=per_layer_size)
+ for layer_name in kv_cache_spec
+ },
+ groups=[[layer_name for layer_name in kv_cache_spec]],
+ kv_cache_spec=kv_cache_spec)
+ return kv_cache_config
+
+
+def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
+ available_memory: int) -> KVCacheConfig:
+ """
+ Generates the KV cache configuration for a model
+ TODO: support hybrid models with more than one type of KV cache.
+
+ Args:
+ vllm_config: The global VllmConfig
+ kv_cache_spec: The kv cache spec of the model
+ available_memory: Memory available for KV cache in bytes.
+
+ Returns:
+ The generated KVCacheConfig
+ """
+ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+ if is_kv_cache_type_uniform(kv_cache_spec):
+ # KV cache of all layers are the same, which is true for most models.
+ # Allocate the same amount of memory for each layer.
+ return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+ available_memory)
+ else:
+ raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb5e83fe062747ee85abf83243938c3763c1a2eb
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py
@@ -0,0 +1,631 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import deque
+from dataclasses import dataclass
+from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
+ Tuple, Union)
+
+from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
+ compute_encoder_budget)
+from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.metrics.stats import SchedulerStats
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+
+if TYPE_CHECKING:
+ from vllm.multimodal import MultiModalKwargs
+ from vllm.multimodal.base import PlaceholderRange
+
+logger = init_logger(__name__)
+
+
class Scheduler:
    """V1 scheduler.

    Tracks requests in WAITING (deque) and RUNNING (list) queues, and at each
    step assigns a token budget and KV-cache blocks to requests. There is no
    separate prefill/decode phase: each request simply advances its
    num_computed_tokens toward num_tokens (see `schedule`).
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        """Initialize queues, the KV cache manager, and the encoder cache.

        Args:
            scheduler_config: Batching/concurrency limits.
            model_config: Used to size the encoder budget/cache.
            cache_config: KV-cache block size, block count, prefix caching.
            lora_config: Must be None; LoRA is not supported in V1 yet.
        """
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        # TODO: Support LoRA.
        assert lora_config is None, "V1 does not support LoRA yet."

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = \
            self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            max_model_len=self.max_model_len,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)
        self.block_size = self.cache_config.block_size

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> CachedRequestData
        self._cached_reqs_data: Dict[str, CachedRequestData] = {}

        # Encoder-related.
        # Calculate encoder cache size if applicable
        # NOTE: For now we use the same budget for both compute and space.
        # This can be changed when we make encoder cache for embedding caching
        # across requests.
        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )

        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
        # projector if needed). Currently, we assume that the encoder also
        # has the Transformer architecture (e.g., ViT).
        self.max_num_encoder_input_tokens = encoder_compute_budget
        # NOTE: For the models without encoder (e.g., text-only models),
        # the encoder cache will not be initialized because cache size is 0
        # for these models.
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)

    def schedule(self) -> "SchedulerOutput":
        """Run one scheduling step and return the resulting SchedulerOutput.

        RUNNING requests are served first (preempting the lowest-priority
        request when KV blocks run out), then WAITING requests are admitted
        while budget, block, and concurrency limits allow.
        """
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and num_tokens,
        # which is equal to len(prompt_token_ids) + len(output_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens. This is general enough to cover chunked prefills,
        # prefix caching, and the "jump decoding" optimization in the future.

        scheduled_new_reqs: List[Request] = []
        scheduled_resumed_reqs: List[Request] = []
        scheduled_running_reqs: List[Request] = []
        preempted_reqs: List[Request] = []

        req_to_new_block_ids: Dict[str, List[int]] = {}
        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Encoder-related.
        scheduled_encoder_inputs: Dict[str, List[int]] = {}
        encoder_budget = self.max_num_encoder_input_tokens

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            num_new_tokens = request.num_tokens - request.num_computed_tokens
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0

            # Schedule encoder inputs.
            encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
                self._try_schedule_encoder_inputs(request,
                                                  request.num_computed_tokens,
                                                  num_new_tokens,
                                                  encoder_budget))
            if num_new_tokens == 0:
                # The request cannot be scheduled because the encoder budget
                # or the encoder cache is exhausted.
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
                req_index += 1
                continue

            while True:
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    preempted_req = self.running.pop()
                    self.kv_cache_manager.free(preempted_req)
                    preempted_req.status = RequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0

                    self.waiting.appendleft(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt.
                        can_schedule = False
                        break
                else:
                    # The request can be scheduled.
                    can_schedule = True
                    break
            if not can_schedule:
                break
            assert new_blocks is not None

            # Schedule the request.
            scheduled_running_reqs.append(request)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in new_blocks
            ]
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_budget = new_encoder_budget

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break

                request = self.waiting[0]
                # Get already-cached tokens.
                computed_blocks, num_computed_tokens = \
                    self.kv_cache_manager.get_computed_blocks(request)
                # Number of tokens to be scheduled.
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed requests,
                # which have output tokens.
                num_new_tokens = request.num_tokens - num_computed_tokens
                if num_new_tokens == 0:
                    # This happens when prompt length is divisible by the block
                    # size and all blocks are cached. Now we force to recompute
                    # the last block. Note that we have to re-compute an entire
                    # block because allocate_slots() assumes num_computed_tokens
                    # is always a multiple of the block size. This limitation
                    # can potentially be removed in the future to slightly
                    # improve the performance.
                    num_computed_tokens -= self.block_size
                    num_new_tokens = self.block_size
                    computed_blocks.pop()
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0

                # Schedule encoder inputs.
                (encoder_inputs_to_schedule, num_new_tokens,
                 new_encoder_budget) = self._try_schedule_encoder_inputs(
                     request, num_computed_tokens, num_new_tokens,
                     encoder_budget)
                if num_new_tokens == 0:
                    # The request cannot be scheduled.
                    break

                new_blocks = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, computed_blocks)
                if new_blocks is None:
                    # The request cannot be scheduled.
                    break

                self.waiting.popleft()
                self.running.append(request)
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in computed_blocks + new_blocks
                ]
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens

                # Encoder-related.
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request.request_id] = (
                        encoder_inputs_to_schedule)
                    # Allocate the encoder cache.
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_budget = new_encoder_budget

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
        # this step, the total number of scheduled requests can be smaller than
        # len(self.running).
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) <= len(self.running))

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id],
                                        req.num_computed_tokens)
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            self._make_cached_request_data(
                req,
                req_to_new_block_ids[req.request_id],
                req.num_computed_tokens,
                resumed_from_preemption=True,
            ) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_cached_request_data(
                req,
                req_to_new_block_ids[req.request_id],
                req.num_computed_tokens,
                resumed_from_preemption=False,
            ) for req in scheduled_running_reqs
        ]
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
        )

        self.finished_req_ids = set()
        return scheduler_output

    def _make_cached_request_data(
        self,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
        resumed_from_preemption: bool,
    ) -> "CachedRequestData":
        """Return a (possibly reused) CachedRequestData for this request."""
        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
        # them at each scheduling step.
        if request.request_id in self._cached_reqs_data:
            req_data = self._cached_reqs_data[request.request_id]
            req_data.resumed_from_preemption = resumed_from_preemption
            req_data.new_block_ids = new_block_ids
            req_data.num_computed_tokens = num_computed_tokens
        else:
            req_data = CachedRequestData.from_request(request,
                                                      resumed_from_preemption,
                                                      new_block_ids,
                                                      num_computed_tokens)
            self._cached_reqs_data[request.request_id] = req_data
        return req_data

    def _try_schedule_encoder_inputs(
        self,
        request: Request,
        num_computed_tokens: int,
        num_new_tokens: int,
        encoder_budget: int,
    ) -> Tuple[List[int], int, int]:
        """
        Determine which encoder inputs need to be scheduled in the current step,
        and update `num_new_tokens` and encoder token budget accordingly.

        An encoder input will be scheduled if:
        - Its output tokens overlap with the range of tokens being computed
        in this step, i.e.,
        [num_computed_tokens, num_computed_tokens + num_new_tokens).
        - It is not already computed and stored in the encoder cache.
        - There is sufficient encoder token budget to process it.
        - The encoder cache has space to store it.

        If an encoder input cannot be scheduled due to cache or budget
        limitations, the method adjusts `num_new_tokens` to schedule only the
        decoder tokens up to just before the unschedulable encoder input.

        Returns:
            Tuple of (encoder input indices to schedule, possibly reduced
            num_new_tokens, remaining encoder budget).
        """
        if not request.has_encoder_inputs():
            return [], num_new_tokens, encoder_budget

        encoder_inputs_to_schedule: List[int] = []
        mm_positions = request.mm_positions
        assert mm_positions is not None
        assert len(mm_positions) > 0
        for i, pos_info in enumerate(mm_positions):
            start_pos = pos_info["offset"]
            num_encoder_tokens = pos_info["length"]

            # The encoder output is needed if the two ranges overlap:
            # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
            # [start_pos, start_pos + num_encoder_tokens)
            if start_pos >= num_computed_tokens + num_new_tokens:
                # The encoder input is not needed in this step.
                break
            if start_pos + num_encoder_tokens <= num_computed_tokens:
                # The encoder input is already computed and stored
                # in the decoder's KV cache.
                continue

            if self.encoder_cache_manager.has_cache(request, i):
                # The encoder input is already computed and cached.
                continue
            if (not self.encoder_cache_manager.can_allocate(request, i)
                    or num_encoder_tokens > encoder_budget):
                # The encoder cache is full or the encoder budget is exhausted.
                # NOTE(woosuk): We assume that the encoder input tokens should
                # be processed altogether, as the encoder usually uses
                # bidirectional attention.
                if num_computed_tokens < start_pos:
                    # We only schedule the decoder tokens just before the
                    # encoder input.
                    num_new_tokens = start_pos - num_computed_tokens
                else:
                    # Because of prefix caching, num_computed_tokens is greater
                    # than start_pos even though its encoder input is not
                    # available. In this case, we can't schedule any token for
                    # the request in this step.
                    num_new_tokens = 0
                break

            encoder_budget -= num_encoder_tokens
            encoder_inputs_to_schedule.append(i)
        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget

    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> EngineCoreOutputs:
        """Apply the model runner's results for one step.

        Advances num_computed_tokens, frees consumed encoder inputs, appends
        sampled tokens, finishes stopped requests, and returns the per-request
        EngineCoreOutputs for this step.
        """
        # NOTE(woosuk): This method doesn't consider speculative decoding.
        sampled_token_ids = model_runner_output.sampled_token_ids
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        new_running: List[Request] = []
        outputs: List[EngineCoreOutput] = []

        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
        # loop can be a performance bottleneck. We should do our best to avoid
        # expensive operations inside the loop.
        for request in self.running:
            req_id = request.request_id
            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
            if num_tokens_scheduled == 0:
                # The request was not scheduled in this step.
                new_running.append(request)
                continue

            request.num_computed_tokens += num_tokens_scheduled
            # When the request's num_computed_tokens catches up its num_tokens,
            # the request generates output tokens. Otherwise, we ignore the
            # sampler output for the request.
            assert request.num_computed_tokens <= request.num_tokens

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
            # OPTIMIZATION: Avoid list(set) if the set is empty.
            if cached_encoder_input_ids:
                for input_id in list(cached_encoder_input_ids):
                    start_pos = request.mm_positions[input_id]["offset"]
                    num_tokens = request.mm_positions[input_id]["length"]
                    if start_pos + num_tokens <= request.num_computed_tokens:
                        # The encoder output is already processed and stored
                        # in the decoder's KV cache.
                        self.encoder_cache_manager.free_encoder_input(
                            request, input_id)

            if request.num_computed_tokens == request.num_tokens:
                req_index = model_runner_output.req_id_to_index[req_id]
                # NOTE(woosuk): Currently, we assume that each request
                # generates at most one token at each step.
                token_id = sampled_token_ids[req_index]
                request.append_output_token_ids(token_id)
                num_new_tokens = 1
                # TODO: Update the KV cache manager for prefix caching.

                # Check for stop and update request state.
                # This must be called before we make the EngineCoreOutput.
                stopped = self._check_stop(request)
                if stopped:
                    self._free_request(request)

                # Add EngineCoreOutput for this Request.
                output = EngineCoreOutput(
                    request_id=req_id,
                    new_token_ids=request.output_token_ids[-num_new_tokens:],
                    finished=request.is_finished(),
                    finish_reason=request.get_finished_reason(),
                    stop_reason=request.stop_reason)
                outputs.append(output)

                # Breakout of the loop.
                if stopped:
                    continue

            new_running.append(request)
        self.running = new_running
        return EngineCoreOutputs(
            outputs=outputs,
            scheduler_stats=self.make_stats(),
        )

    def _check_stop(self, request: Request) -> bool:
        """Return True (and set request.status) if the request should stop.

        Stop conditions: model/length cap, EOS (unless ignore_eos), or a
        user-provided stop token id. Must be called after a token is appended.
        """
        if (request.num_tokens >= self.max_model_len
                or request.num_output_tokens >= request.max_tokens):
            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
            return True

        sampling_params = request.sampling_params
        last_token_id = request.output_token_ids[-1]
        if (not sampling_params.ignore_eos
                and last_token_id == request.eos_token_id):
            request.status = RequestStatus.FINISHED_STOPPED
            return True

        if last_token_id in (sampling_params.stop_token_ids or ()):
            request.status = RequestStatus.FINISHED_STOPPED
            request.stop_reason = last_token_id
            return True
        return False

    def add_request(self, request: Request) -> None:
        """Enqueue a new request into the WAITING queue."""
        self.waiting.append(request)
        self.requests[request.request_id] = request

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
            else:
                self.waiting.remove(request)
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        """Release all scheduler state held for a finished request."""
        assert request.is_finished()
        self.kv_cache_manager.free(request)
        self.encoder_cache_manager.free(request)
        self._cached_reqs_data.pop(request.request_id, None)
        del self.requests[request.request_id]
        self.finished_req_ids.add(request.request_id)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests that are waiting or running."""
        return len(self.waiting) + len(self.running)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still waiting or running."""
        return self.get_num_unfinished_requests() > 0

    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache; returns whether the reset succeeded."""
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(self) -> SchedulerStats:
        """Snapshot current queue lengths and KV-cache usage."""
        return SchedulerStats(
            num_running_reqs=len(self.running),
            num_waiting_reqs=len(self.waiting),
            gpu_cache_usage=self.kv_cache_manager.usage,
        )
+
+
@dataclass
class NewRequestData:
    """Immutable-style payload sent to workers for a newly scheduled request.

    Carries everything a worker needs to start a request it has never seen:
    prompt, multimodal inputs/positions, sampling params, and the initial
    KV-cache block assignment.
    """

    # Unique request identifier.
    req_id: str
    # Tokenized prompt.
    prompt_token_ids: List[int]
    # Raw prompt text, if available.
    prompt: Optional[str]
    # Multimodal inputs, their hashes, and their placeholder positions.
    mm_inputs: List["MultiModalKwargs"]
    mm_hashes: List[str]
    mm_positions: List["PlaceholderRange"]
    # Sampling parameters for this request.
    sampling_params: SamplingParams
    # KV-cache block IDs assigned to this request.
    block_ids: List[int]
    # Tokens already computed (e.g. via prefix-cache hits).
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        """Build a NewRequestData by copying fields from `request`."""
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            prompt=request.prompt,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
        )
+
+
@dataclass
class CachedRequestData:
    """Payload sent to workers for a request they already know about.

    Instances are mutated and reused across scheduling steps (see
    Scheduler._cached_reqs_data) to avoid per-step allocations.
    """

    # Unique request identifier.
    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_block_ids: List[int]
    # Tokens computed so far for this request.
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "CachedRequestData":
        """Build a CachedRequestData for `request` with the given step data."""
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )
+
+
@dataclass
class SchedulerOutput:
    """Result of one Scheduler.schedule() step, consumed by the model runner."""

    # Requests scheduled for the first time this step.
    scheduled_new_reqs: List[NewRequestData]
    # Requests already known to workers (running or resumed from preemption).
    scheduled_cached_reqs: List[CachedRequestData]

    # req_id -> number of tokens scheduled for it this step.
    num_scheduled_tokens: Dict[str, int]
    # Sum of num_scheduled_tokens values.
    total_num_scheduled_tokens: int
    # req_id -> indices of encoder inputs to run this step.
    scheduled_encoder_inputs: Dict[str, List[int]]
    # Longest common KV-block prefix among running requests (for cascade
    # attention).
    num_common_prefix_blocks: int

    # Requests finished since the previous step (workers free their state).
    finished_req_ids: Set[str]
    # (req_id, encoder input index) pairs whose encoder cache can be freed.
    free_encoder_input_ids: List[Tuple[str, int]]
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48c0457236bee710090a550d1e9ee8b35edfaf00
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4052f3fbbd50f8599924f75cb20169790f43b35d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f65c16b53bb69b4c9acf5fd75fc0fac79e96eff
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac10d43eb0d54d634af49964234cdcdc00c3c8d9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Type
+
+from vllm.config import VllmConfig
+from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.ray_distributed_executor import ( # noqa
+ RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.executor.uniproc_executor import ( # noqa
+ ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
+from vllm.executor.uniproc_executor import ( # noqa
+ UniProcExecutor as UniProcExecutorV0)
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+from vllm.v1.outputs import ModelRunnerOutput
+
+
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        """Pick the concrete Executor class for the given configuration.

        Defaults to multiprocessing when world_size > 1, otherwise a single
        in-process executor; an explicit backend setting always wins.
        """
        parallel_config = vllm_config.parallel_config
        backend = parallel_config.distributed_executor_backend
        if backend is None:
            # No explicit choice from the user: infer from the world size.
            backend = "mp" if parallel_config.world_size > 1 else "uni"

        executor_class: Type[Executor]
        if backend == "ray":
            executor_class = RayDistributedExecutor
        elif backend == "mp":
            # Imported lazily to avoid pulling in multiprocessing machinery
            # unless it is actually used.
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif backend == "uni":
            executor_class = UniProcExecutor
        elif backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{backend}")
        return executor_class

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> int:  # in bytes
        """Return the KV-cache memory budget, in bytes."""
        per_worker = self.collective_rpc("determine_available_memory")
        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(per_worker)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV cache spec, asserting all workers agree on it."""
        specs = self.collective_rpc("get_kv_cache_spec")
        for spec in specs:
            assert spec == specs[0]
        return specs[0]

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on all workers; return rank-0's output."""
        results = self.collective_rpc("execute_model",
                                      args=(scheduler_output, ))
        return results[0]

    def profile(self, is_start: bool = True):
        """Start (or stop, if is_start is False) profiling on all workers."""
        self.collective_rpc("profile", args=(is_start, ))
+
+
class UniProcExecutor(UniProcExecutorV0, Executor):
    """Single-process v1 executor: v0 UniProcExecutor plus the v1 Executor API."""
    pass
+
+
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 executor for externally launched processes (e.g. torchrun)."""
    pass
+
+
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray-based distributed v1 executor: v0 implementation plus v1 API."""
    pass
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f07172d8cd9bc280740c1d92caa2e6ca8f0607
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py
@@ -0,0 +1,378 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import pickle
+import signal
+import sys
+import time
+import weakref
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import partial
+from multiprocessing.process import BaseProcess
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cloudpickle
+import psutil
+import zmq
+
+from vllm.config import VllmConfig
+from vllm.distributed import (destroy_distributed_environment,
+ destroy_model_parallel)
+from vllm.distributed.device_communicators.shm_broadcast import (Handle,
+ MessageQueue)
+from vllm.executor.multiproc_worker_utils import (
+ _add_prefix, set_multiprocessing_worker_envs)
+from vllm.logger import init_logger
+from vllm.utils import (get_distributed_init_method, get_mp_context,
+ get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
+from vllm.v1.executor.abstract import Executor
+from vllm.worker.worker_base import WorkerWrapperBase
+
+logger = init_logger(__name__)
+
+POLLING_TIMEOUT_MS = 5000
+POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
+
+
+class MultiprocExecutor(Executor):
+
+ def _init_executor(self) -> None:
+ # Call self.shutdown at exit to clean up
+ # and ensure workers will be terminated.
+ self._finalizer = weakref.finalize(self, self.shutdown)
+
+ # The child processes will send SIGUSR1 when unrecoverable
+ # errors happen.
+ def sigusr1_handler(signum, frame):
+ logger.fatal(
+ "MultiprocExecutor got fatal signal from worker processes, "
+ "shutting down. See stack trace above for root cause issue.")
+ # Propagate error up to parent process.
+ parent_process = psutil.Process().parent()
+ parent_process.send_signal(signal.SIGUSR1)
+ self.shutdown()
+
+ signal.signal(signal.SIGUSR1, sigusr1_handler)
+
+ self.world_size = self.parallel_config.world_size
+ tensor_parallel_size = self.parallel_config.tensor_parallel_size
+ assert self.world_size == tensor_parallel_size, (
+ f"world_size ({self.world_size}) must be equal to the "
+ f"tensor_parallel_size ({tensor_parallel_size}). "
+ f"Pipeline parallelism is not yet implemented in v1")
+
+ # Set multiprocessing envs that are common to V0 and V1
+ set_multiprocessing_worker_envs(self.parallel_config)
+
+ # Multiprocessing-based executor does not support multi-node setting.
+ # Since it only works for single node, we can use the loopback address
+ # 127.0.0.1 for communication.
+ distributed_init_method = get_distributed_init_method(
+ "127.0.0.1", get_open_port())
+
+ # Initialize worker and set up message queues for SchedulerOutputs
+ # and ModelRunnerOutputs
+ self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
+ scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
+
+ # Create workers
+ self.workers: List[WorkerProcHandle] = []
+ for rank in range(self.world_size):
+ worker = WorkerProc.make_worker_process(self.vllm_config, rank,
+ rank,
+ distributed_init_method,
+ scheduler_output_handle)
+ self.workers.append(worker)
+
+ # Ensure message queues are ready. Will deadlock if re-ordered
+ # Must be kept consistent with the WorkerProc
+ self.rpc_broadcast_mq.wait_until_ready()
+ for w in self.workers:
+ w.worker_response_mq.wait_until_ready()
+
+ def collective_rpc(self,
+ method: Union[str, Callable],
+ timeout: Optional[float] = None,
+ args: Tuple = (),
+ kwargs: Optional[Dict] = None) -> List[Any]:
+ start_time = time.monotonic()
+ kwargs = kwargs or {}
+
+ # NOTE: If the args are heterogeneous, then we pack them into a list,
+ # and unpack them in the method of every worker, because every worker
+ # knows their own rank.
+ try:
+ if isinstance(method, str):
+ send_method = method
+ else:
+ send_method = cloudpickle.dumps(
+ method, protocol=pickle.HIGHEST_PROTOCOL)
+ self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
+
+ responses = [None] * self.world_size
+ for w in self.workers:
+ dequeue_timeout = timeout - (time.monotonic() - start_time
+ ) if timeout is not None else None
+ status, result = w.worker_response_mq.dequeue(
+ timeout=dequeue_timeout)
+
+ if status != WorkerProc.ResponseStatus.SUCCESS:
+ if isinstance(result, Exception):
+ raise result
+ else:
+ raise RuntimeError("Worker failed")
+
+ responses[w.rank] = result
+
+ return responses
+ except TimeoutError as e:
+ raise TimeoutError(f"RPC call to {method} timed out.") from e
+ except Exception as e:
+ # Re-raise any other exceptions
+ raise e
+
+ def _ensure_worker_termination(self):
+ """Ensure that all worker processes are terminated. Assumes workers have
+ received termination requests. Waits for processing, then sends
+ termination and kill signals if needed."""
+
+ def wait_for_termination(procs, timeout):
+ if not time:
+ # If we are in late stage shutdown, the interpreter may replace
+ # `time` with `None`.
+ return all(not proc.is_alive() for proc in procs)
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ if all(not proc.is_alive() for proc in procs):
+ return True
+ time.sleep(0.1)
+ return False
+
+ # Send SIGTERM if still running
+ active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
+ for p in active_procs:
+ p.terminate()
+ if not wait_for_termination(active_procs, 4):
+ # Send SIGKILL if still running
+ active_procs = [p for p in active_procs if p.is_alive()]
+ for p in active_procs:
+ p.kill()
+
+ self._cleanup_sockets()
+
+ def _cleanup_sockets(self):
+ for w in self.workers:
+ # Remove the zmq ipc socket file
+ socket_path = w.ready_path.replace("ipc://", "")
+ if os and os.path.exists(socket_path):
+ os.remove(socket_path)
+
+ def shutdown(self):
+ """Properly shut down the executor and its workers"""
+ if not getattr(self, 'shutting_down', False):
+ self.shutting_down = True
+ for w in self.workers:
+ w.worker_response_mq = None
+ self._ensure_worker_termination()
+
+ self.rpc_broadcast_mq = None
+
+ def check_health(self) -> None:
+ self.collective_rpc("check_health", timeout=10)
+ return
+
+
+@dataclass
+class WorkerProcHandle:
+ proc: BaseProcess
+ rank: int
+ ready_path: str
+ worker_response_mq: MessageQueue # The worker process writes to this MQ
+
+
+class WorkerProc:
+ """Wrapper that runs one Worker in a separate process."""
+
+ READY_STR = "READY"
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ local_rank: int,
+ rank: int,
+ distributed_init_method: str,
+ input_shm_handle: Handle,
+ ready_path: str,
+ ):
+ self.rank = rank
+ wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
+ # TODO: move `init_worker` to executor level as a collective rpc call
+ all_kwargs: List[Dict] = [
+ {} for _ in range(vllm_config.parallel_config.world_size)
+ ]
+ all_kwargs[rank] = {
+ "vllm_config": vllm_config,
+ "local_rank": local_rank,
+ "rank": rank,
+ "distributed_init_method": distributed_init_method,
+ }
+ wrapper.init_worker(all_kwargs)
+ self.worker = wrapper.worker
+
+ pid = os.getpid()
+ _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
+ _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
+
+ # Initialize MessageQueue for receiving SchedulerOutput
+ self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+ input_shm_handle, self.worker.rank)
+
+ # Initializes a message queue for sending the model output
+ self.worker_response_mq = MessageQueue(1, 1)
+ worker_response_mq_handle = self.worker_response_mq.export_handle()
+
+ # Send Readiness signal to EngineCore process.
+ with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
+ payload = pickle.dumps(worker_response_mq_handle,
+ protocol=pickle.HIGHEST_PROTOCOL)
+ ready_socket.send_string(WorkerProc.READY_STR)
+ ready_socket.send(payload)
+
+ self.worker.init_device()
+ self.worker.load_model()
+
+ @staticmethod
+ def make_worker_process(
+ vllm_config: VllmConfig,
+ local_rank: int,
+ rank: int,
+ distributed_init_method: str,
+ input_shm_handle, # Receive SchedulerOutput
+ ) -> WorkerProcHandle:
+ context = get_mp_context()
+
+ # ZMQ path for worker to send ready message and shm_broadcast handle
+ # back to core process.
+ ready_path = get_open_zmq_ipc_path()
+
+ process_kwargs = {
+ "vllm_config": vllm_config,
+ "local_rank": local_rank,
+ "rank": rank,
+ "distributed_init_method": distributed_init_method,
+ "input_shm_handle": input_shm_handle,
+ "ready_path": ready_path,
+ }
+ # Run EngineCore busy loop in background process.
+ proc = context.Process(target=WorkerProc.worker_main,
+ kwargs=process_kwargs,
+ daemon=True)
+ proc.start()
+
+ # Wait for startup
+ worker_response_mq_handle = WorkerProc.wait_for_startup(
+ proc, ready_path)
+
+ worker_response_mq = MessageQueue.create_from_handle(
+ worker_response_mq_handle, 0)
+
+ return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
+
+ def shutdown(self):
+ self.rpc_broadcast_mq = None
+ self.worker_response_mq = None
+ destroy_model_parallel()
+ destroy_distributed_environment()
+
+ @staticmethod
+ def worker_main(*args, **kwargs):
+ """ Worker initialization and execution loops.
+ This runs in a background process. """
+
+ # Signal handler used for graceful termination.
+ # SystemExit exception is only raised once to allow this and worker
+ # processes to terminate without error
+ shutdown_requested = False
+
+ def signal_handler(signum, frame):
+ nonlocal shutdown_requested
+ if not shutdown_requested:
+ shutdown_requested = True
+ raise SystemExit()
+
+ # Either SIGTERM or SIGINT will terminate the worker
+ signal.signal(signal.SIGTERM, signal_handler)
+ signal.signal(signal.SIGINT, signal_handler)
+
+ worker = None
+ try:
+ worker = WorkerProc(*args, **kwargs)
+
+ # Ensure message queues are ready. Will deadlock if re-ordered.
+ # Must be kept consistent with the Executor
+ worker.rpc_broadcast_mq.wait_until_ready()
+ worker.worker_response_mq.wait_until_ready()
+
+ worker.worker_busy_loop()
+
+ except SystemExit:
+ logger.debug("Worker interrupted.")
+
+ except Exception:
+ # worker_busy_loop sends exceptions to the Executor
+ # for shutdown, but if there is an error in startup or an
+ # error with IPC itself, we need to alert the parent.
+ psutil.Process().parent().send_signal(signal.SIGUSR1)
+ raise
+
+ finally:
+ # Clean up once worker exits busy loop
+ if worker is not None:
+ worker.shutdown()
+ worker = None
+
+ @staticmethod
+ def wait_for_startup(
+ proc: BaseProcess,
+ ready_path: str,
+ ) -> Optional[Handle]:
+ """Wait until the Worker is ready."""
+ with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
+
+ # Wait for Worker to send READY.
+ while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
+ logger.debug("Waiting for WorkerProc to startup.")
+
+ if not proc.is_alive():
+ raise RuntimeError("WorkerProc failed to start.")
+
+ message = socket.recv_string()
+ assert message == WorkerProc.READY_STR
+ handle_frame = socket.recv(copy=False)
+ handle = pickle.loads(handle_frame.buffer)
+ return handle
+
+ class ResponseStatus(Enum):
+ SUCCESS = auto()
+ FAILURE = auto()
+
+ def worker_busy_loop(self):
+ """Main busy loop for Multiprocessing Workers"""
+ while True:
+ method, args, kwargs = self.rpc_broadcast_mq.dequeue()
+
+ try:
+ if isinstance(method, str):
+ func = getattr(self.worker, method)
+ elif isinstance(method, bytes):
+ func = partial(cloudpickle.loads(method), self.worker)
+ output = func(*args, **kwargs)
+ except Exception as e:
+ self.worker_response_mq.enqueue(
+ (WorkerProc.ResponseStatus.FAILURE, e))
+ logger.exception("WorkerProc hit an exception: %s", e)
+ continue
+
+ self.worker_response_mq.enqueue(
+ (WorkerProc.ResponseStatus.SUCCESS, output))
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py b/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..eddfb5949ebe65c3dd5f8ae72a8aad06ee818703
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py
@@ -0,0 +1,113 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv, get_dtype_size
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class KVCacheSpecBase:
+ """
+ A base class for specifying the KV cache format of one layer.
+ """
+
+ # number of tokens in a block
+ block_size: int
+
+ @property
+ def type_id(self) -> str:
+ """
+ The type identifier of this KV cache.
+ Return different strings for layers with different KV cache type (e.g.,
+ different number of tokens like full attention vs sliding window
+ attention, different KV cache size per token like layers with different
+ number of heads)
+
+ Returns:
+ The type identifier of this KV cache.
+ """
+ raise NotImplementedError
+
+ @property
+ def page_size_bytes(self) -> int:
+ """
+ The size of a page with `block_size` tokens in bytes.
+
+ Returns:
+ The page size
+ """
+ raise NotImplementedError
+
+ def bytes_for_tokens(self, num_tokens: int) -> int:
+ """
+ The KV cache size for `num_tokens` tokens in bytes. Returns the real
+ memory size after padding `num_tokens` to full blocks.
+
+ Returns:
+ The KV cache size
+ """
+ raise NotImplementedError
+
+
+@dataclass
+class FullAttentionSpec(KVCacheSpecBase):
+ num_kv_heads: int
+ head_size: int
+ dtype: torch.dtype
+
+ @property
+ def type_id(self) -> str:
+ return f"full_attention_{self.block_size}_{self.page_size_bytes}"
+
+ @property
+ def page_size_bytes(self) -> int:
+ return 2 * self.block_size * self.num_kv_heads * self.head_size \
+ * get_dtype_size(self.dtype)
+
+ def bytes_for_tokens(self, num_tokens: int) -> int:
+ return cdiv(num_tokens, self.block_size) * self.page_size_bytes
+
+
+KVCacheSpec = Dict[str, KVCacheSpecBase]
+
+
+@dataclass
+class KVCacheTensor:
+ """
+ A dataclass for specifying how the workers should initialize the KV cache
+ for a layer. Only contains the size of KV cache for that layer for now. Will
+ be extended to support multiple layers sharing the same memory pool.
+ """
+ size: int # The size of KV cache Tensor in bytes
+
+
+@dataclass
+class KVCacheConfig:
+ """
+ The KV cache configuration of a model.
+ """
+ """The number of KV cache blocks"""
+ num_blocks: int
+ """layer_name -> how to initialize KV cache for that layer"""
+ tensors: Dict[str, KVCacheTensor]
+ """
+ A list of kv-cache groups. Each group includes a set of layers with
+ the same kv-cache spec, and the total page_size of layers inside a group
+ is same across all groups (as the KVCacheManager only supports allocating
+ pages of the same size). For example:
+ 1. A model only uses full attention: one group with all layers in the model.
+ 2. (not implemented yet) A model with the same number of full attention
+ layers and sliding window attention layers: two groups, one for full
+ attention layers and one for sliding window attention layers.
+ 3. (not implemented yet) A model with 2 full attention layers and 4 sliding
+ window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
+ """
+ groups: List[List[str]]
+ """the KVCacheSpec of the model"""
+ kv_cache_spec: KVCacheSpec
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py b/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e82bffd7e5c9dfff0a077ecb9c34a3cad4c9c53
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py
@@ -0,0 +1,41 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+
+@dataclass
+class SamplerOutput:
+
+ # [num_reqs]
+ sampled_token_ids: torch.Tensor
+
+ # [num_reqs, max_num_logprobs + 1]
+ logprob_token_ids: Optional[torch.Tensor]
+ # [num_reqs, max_num_logprobs + 1]
+ logprobs: Optional[torch.Tensor]
+
+ # TODO: Support prompt logprobs.
+ prompt_logprob_token_ids: Optional[torch.Tensor]
+ prompt_logprobs: Optional[torch.Tensor]
+
+
+# ModelRunnerOutput is serialized and sent to the scheduler process.
+# This is expensive for torch.Tensor so prefer to use List instead.
+@dataclass
+class ModelRunnerOutput:
+
+ # [num_reqs]
+ req_ids: List[str]
+ # req_id -> index
+ req_id_to_index: Dict[str, int]
+
+ # [num_reqs]
+ sampled_token_ids: List[int]
+
+ # [num_reqs, max_num_logprobs + 1]
+ logprob_token_ids_cpu: Optional[torch.Tensor]
+ # [num_reqs, max_num_logprobs + 1]
+ logprobs_cpu: Optional[torch.Tensor]
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/request.py b/.venv/lib/python3.11/site-packages/vllm/v1/request.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b39ea615d20f6a1cfb3c3b91e2d211008a85ba
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/request.py
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import RequestMetrics
+from vllm.v1.engine import EngineCoreRequest, FinishReason
+from vllm.v1.utils import ConstantList
+
+if TYPE_CHECKING:
+ from vllm.multimodal import MultiModalKwargs
+ from vllm.multimodal.inputs import PlaceholderRange
+ from vllm.v1.core.kv_cache_utils import BlockHashType
+
+
+class Request:
+
+ def __init__(
+ self,
+ request_id: str,
+ prompt: Optional[str],
+ prompt_token_ids: List[int],
+ multi_modal_inputs: Optional[List["MultiModalKwargs"]],
+ multi_modal_hashes: Optional[List[str]],
+ multi_modal_placeholders: Optional[List["PlaceholderRange"]],
+ sampling_params: SamplingParams,
+ eos_token_id: Optional[int],
+ arrival_time: float,
+ lora_request: Optional[LoRARequest] = None,
+ ) -> None:
+ self.request_id = request_id
+ self.sampling_params = sampling_params
+ # Because of LoRA, the eos token id can be different for each request.
+ self.eos_token_id = eos_token_id
+ self.metrics = RequestMetrics(arrival_time=arrival_time,
+ last_token_time=arrival_time,
+ first_scheduled_time=None,
+ first_token_time=None,
+ time_in_queue=None)
+ self.lora_request = lora_request
+
+ self.status = RequestStatus.WAITING
+ self.stop_reason: Union[int, str, None] = None
+ assert sampling_params.max_tokens is not None
+ self.max_tokens = sampling_params.max_tokens
+
+ self.prompt = prompt
+ self.prompt_token_ids = prompt_token_ids
+ self.num_prompt_tokens = len(self.prompt_token_ids)
+ self._output_token_ids: List[int] = []
+ self._all_token_ids: List[int] = self.prompt_token_ids.copy()
+ self.num_computed_tokens = 0
+
+ # Multi-modal related
+ self.mm_positions = multi_modal_placeholders or []
+ self.mm_inputs = multi_modal_inputs or []
+ self.mm_hashes: List[str] = multi_modal_hashes or []
+
+ # Sanity check
+ assert len(self.mm_inputs) == len(self.mm_positions)
+ if self.mm_hashes:
+ assert len(self.mm_inputs) == len(self.mm_hashes)
+
+ # Cache the computed kv block hashes of the request to avoid
+ # recomputing.
+ self._kv_block_hashes: List[BlockHashType] = []
+ self.kv_block_hashes = ConstantList(self._kv_block_hashes)
+
+ # Read-only views
+ # Prevent directly appending to these lists since
+ # they should also be updated simultaneously.
+ self.output_token_ids = ConstantList(self._output_token_ids)
+ self.all_token_ids = ConstantList(self._all_token_ids)
+
+ @classmethod
+ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+ return cls(
+ request_id=request.request_id,
+ prompt=request.prompt,
+ prompt_token_ids=request.prompt_token_ids,
+ multi_modal_inputs=request.mm_inputs,
+ multi_modal_hashes=request.mm_hashes,
+ multi_modal_placeholders=request.mm_placeholders,
+ sampling_params=request.sampling_params,
+ eos_token_id=request.eos_token_id,
+ arrival_time=request.arrival_time,
+ lora_request=request.lora_request,
+ )
+
+ def append_output_token_ids(
+ self,
+ token_ids: Union[int, List[int]],
+ ) -> None:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ self._output_token_ids.extend(token_ids)
+ self._all_token_ids.extend(token_ids)
+
+ @property
+ def num_tokens(self) -> int:
+ return len(self._all_token_ids)
+
+ @property
+ def num_output_tokens(self) -> int:
+ return len(self._output_token_ids)
+
+ def is_finished(self) -> bool:
+ return RequestStatus.is_finished(self.status)
+
+ def get_finished_reason(self) -> Union[FinishReason, None]:
+ return RequestStatus.get_finished_reason(self.status)
+
+ def has_encoder_inputs(self) -> bool:
+ return len(self.mm_inputs) > 0
+
+ @property
+ def num_encoder_inputs(self) -> int:
+ return len(self.mm_positions)
+
+ def get_num_encoder_tokens(self, input_id: int) -> int:
+ assert input_id < len(self.mm_positions)
+ num_tokens = self.mm_positions[input_id]["length"]
+ return num_tokens
+
+ def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
+ self._kv_block_hashes = value
+ self.kv_block_hashes = ConstantList(self._kv_block_hashes)
+
+ def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
+ self._kv_block_hashes.append(block_hash)
+
+
+class RequestStatus(enum.IntEnum):
+ """Status of a request."""
+ WAITING = 0
+ RUNNING = 1
+ PREEMPTED = 2
+ # Note: anything after PREEMPTED (2) will be considered
+ # as a finished status.
+ FINISHED_STOPPED = 3
+ FINISHED_LENGTH_CAPPED = 4
+ FINISHED_ABORTED = 5
+ FINISHED_IGNORED = 6
+
+ @staticmethod
+ def is_finished(status: "RequestStatus") -> bool:
+ return status > RequestStatus.PREEMPTED
+
+ @staticmethod
+ def get_finished_reason(
+ status: "RequestStatus") -> Union[FinishReason, None]:
+ return _FINISHED_REASON_MAP.get(status)
+
+
+# Mapping of finished statuses to their finish reasons.
+# NOTE: The ignored requests are the requests whose prompt lengths
+# are longer than the model's length cap. Therefore, the stop
+# reason should also be "length" as in OpenAI API.
+_FINISHED_REASON_MAP = {
+ RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
+ RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
+ RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
+ RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
+}
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e7683c151598a9f5317f12c93e204ecd21413b1
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3658a6631dc04da622618de15c188352232940d4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60a763e4dff05ced5882b81c4f54e5407853afc0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e54de34548ddfe8a631d78979b91c83dffc2e9f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set
+
+import torch
+
+
+@dataclass
+class SamplingMetadata:
+
+ temperature: torch.Tensor
+ all_greedy: bool
+ all_random: bool
+
+ top_p: torch.Tensor
+ top_k: torch.Tensor
+ no_top_p: bool
+ no_top_k: bool
+
+ generators: Dict[int, torch.Generator]
+
+ max_num_logprobs: int
+
+ no_penalties: bool
+ prompt_token_ids: Optional[torch.Tensor]
+ frequency_penalties: torch.Tensor
+ presence_penalties: torch.Tensor
+ repetition_penalties: torch.Tensor
+
+ output_token_ids: List[List[int]]
+ min_tokens: List[int]
+ stop_token_ids: List[Set[int]]
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99b6c0884ce0c82ad1f707dc07edccddef4aa63f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f1aa42895152c40804c6dc507e1fa412ccb5f4c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c87dec7d51f69cb463730b59c4217506cbc7b08
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba368b44ab9cc02c8cb281049f84bfd66705081a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Set, Tuple
+
+import torch
+
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+
+
+def apply_min_token_penalties(logits: torch.Tensor,
+ output_token_ids: List[List[int]],
+ stop_token_ids: List[Set[int]],
+ min_tokens: List[int]) -> None:
+ """
+ Applies minimum token penalty by setting the logits of the stop tokens
+ to -inf.
+ """
+ min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+ for index, min_token in enumerate(min_tokens):
+ if len(output_token_ids[index]) < min_token:
+ for stop_token_id in stop_token_ids[index]:
+ min_tokens_logits_to_penalize.append((index, stop_token_id))
+ if min_tokens_logits_to_penalize:
+ logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def apply_all_penalties(
+ logits: torch.Tensor,
+ prompt_token_ids: torch.Tensor,
+ presence_penalties: torch.Tensor,
+ frequency_penalties: torch.Tensor,
+ repetition_penalties: torch.Tensor,
+ output_token_ids: List[List[int]],
+) -> torch.Tensor:
+ """
+ Applies presence, frequency and repetition penalties to the logits.
+ """
+ _, vocab_size = logits.shape
+ output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+ logits.device)
+ return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+ presence_penalties, frequency_penalties,
+ repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+ device: torch.device) -> torch.Tensor:
+ """
+ Convert the different list data structures to tensors.
+ """
+ output_tokens_tensor = make_tensor_with_pad(
+ output_token_ids,
+ # Use the value of vocab_size as a pad since we don't have a
+ # token_id of this value.
+ pad=vocab_size,
+ device="cpu",
+ dtype=torch.int64,
+ pin_memory=is_pin_memory_available(),
+ )
+ return output_tokens_tensor.to(device, non_blocking=True)
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..27431001e3e7a2f0c78baa9d6c20900bb107fe3a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+try:
+ import flashinfer.sampling
+ is_flashinfer_available = True
+except ImportError:
+ is_flashinfer_available = False
+
+
+class TopKTopPSampler(nn.Module):
+    """Samples next tokens after applying top-k and/or top-p filtering.
+
+    `self.forward` is bound once, at construction time, to the fastest
+    available implementation: the FlashInfer-based `forward_cuda` on CUDA
+    when FlashInfer is installed and not disabled, otherwise the
+    PyTorch-native `forward_native`.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # NOTE(review): `is_cuda` is referenced without parentheses; if the
+        # platform object defines it as a method rather than a property,
+        # this condition is always truthy — confirm against
+        # `vllm.platforms.current_platform`.
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for top-p & top-k sampling.")
+                    self.forward = self.forward_cuda
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "top-p & top-k sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of top-p & top-k sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward = self.forward_native
+        else:
+            self.forward = self.forward_native
+
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """PyTorch-native implementation of top-k and top-p sampling.
+
+        Masks the logits first, then samples from the softmax distribution.
+        """
+        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
+    def forward_cuda(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """More optimized implementation for top-k and top-p sampling."""
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        if no_top_k and no_top_p:
+            # We prefer `random_sample` over `flashinfer_sample` when sorting is
+            # not needed. This is because `random_sample` does not require
+            # CPU-GPU synchronization while `flashinfer_sample` does.
+            return random_sample(probs, generators)
+        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    This function sorts the logits tensor, which can be slow for large batches.
+
+    Args:
+        logits: 2-D logits tensor, one row per request.
+        no_top_k: if True, skip top-k masking entirely.
+        k: per-request top-k values (used only when `no_top_k` is False).
+        no_top_p: if True, skip top-p masking entirely.
+        p: per-request top-p values (used only when `no_top_p` is False).
+    """
+    if no_top_k and no_top_p:
+        return logits
+    # Sort ascending so the largest logits sit at the end of each row;
+    # both masks below rely on this orientation.
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if not no_top_k:
+        # Apply top-k.
+        # Position of the k-th largest element in the ascending order.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        # Mask out everything strictly below the k-th largest value.
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if not no_top_p:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        # Mask the low-probability tail whose cumulative mass is <= 1 - p.
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    # (Scatter the masked, sorted logits back into vocabulary order.)
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def random_sample(
+    probs: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+
+    Uses the exponential-race trick: argmax(probs / q) with i.i.d.
+    q ~ Exp(1) is distributed according to `probs` (equivalent to the
+    Gumbel-max trick).
+
+    Args:
+        probs: 2-D probability tensor, one row per request.
+        generators: per-request `torch.Generator`s for seeded requests,
+            keyed by row index.
+    """
+    q = torch.empty_like(probs)
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    if len(generators) != probs.shape[0]:
+        q.exponential_()
+    if generators:
+        # TODO(woosuk): This can be slow because we handle each request
+        # one by one. Optimize this.
+        for i, generator in generators.items():
+            q[i].exponential_(generator=generator)
+    # NOTE: `div_` mutates `probs` in place; callers must not reuse it.
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def flashinfer_sample(
+    probs: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Sample from the probabilities using FlashInfer.
+
+    Statistically, this function is equivalent to the `random_sample` function.
+    However, this function is faster because it avoids sorting the logits tensor
+    via rejection sampling.
+
+    NOTE: The outputs of this function do not necessarily match the outputs of
+    the `random_sample` function. It only guarantees that the outputs are
+    statistically equivalent.
+
+    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
+    does not. Call this function at the end of the forward pass to minimize
+    the synchronization overhead.
+    """
+    # The no-filter case is handled by the caller via `random_sample`.
+    assert not (no_top_k and no_top_p)
+    # Number of rejection-sampling rounds worth of uniform noise to
+    # pre-generate; the fallback below uses only the first row.
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    # Same strategy as `random_sample`: fill everything assuming no seeds,
+    # then overwrite the columns belonging to seeded requests.
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)
+
+    if no_top_k:
+        # Top-p only.
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
+    elif no_top_p:
+        # Top-k only.
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
+    else:
+        # Both top-k and top-p.
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        # Rejection sampling failed for some requests: renormalize the
+        # filtered distribution explicitly and draw a plain sample instead.
+        if not no_top_k:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if not no_top_p:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
+    return next_token_ids.view(-1)
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da7498e0dae5d671e81b9fcace11ed992a8478d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A layer that samples the next tokens from the model's outputs."""
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from vllm.v1.outputs import SamplerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.penalties import (apply_all_penalties,
+ apply_min_token_penalties)
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
+
+_SAMPLING_EPS = 1e-5
+
+
+class Sampler(nn.Module):
+    """Samples next tokens from logits according to `SamplingMetadata`.
+
+    Pipeline in `forward`: (optionally) top-k logprobs from the raw logits,
+    then penalties, temperature scaling, and greedy/random sampling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        needs_logprobs = sampling_metadata.max_num_logprobs > 0
+        if needs_logprobs:
+            # NOTE(woosuk): Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs.
+            # This is different from the V0 sampler, which uses the logits that
+            # is used for sampling (after penalties and temperature scaling).
+            # NOTE: We compute logprobs first because the below ops may
+            # modify the logits tensor in-place (and we don't want to clone
+            # the logits tensor for memory efficiency).
+            topk_logprobs, topk_indices = self.get_topk_logprobs(
+                logits, sampling_metadata)
+        else:
+            topk_logprobs = None
+            topk_indices = None
+
+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Apply penalties (e.g., min_tokens, freq_penalties).
+        logits = self.apply_penalties(logits, sampling_metadata)
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
+
+        # Prompt logprobs are not computed here (left as None).
+        sampler_output = SamplerOutput(
+            sampled_token_ids=sampled,
+            logprob_token_ids=topk_indices,
+            logprobs=topk_logprobs,
+            prompt_logprob_token_ids=None,
+            prompt_logprobs=None,
+        )
+        return sampler_output
+
+    def apply_temperature(
+        self,
+        logits: torch.Tensor,
+        temp: torch.Tensor,
+    ) -> torch.Tensor:
+        """Divide the logits by per-request temperatures, in place."""
+        # Avoid division by zero.
+        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(temp.unsqueeze(dim=1))
+        return logits
+
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        """Pick the argmax token id for every request."""
+        return logits.argmax(dim=-1).view(-1)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        """Sample token ids, mixing greedy and random results per request.
+
+        Greedy requests are identified by a near-zero temperature.
+        """
+        assert not (sampling_metadata.all_greedy
+                    and sampling_metadata.all_random)
+        if sampling_metadata.all_greedy:
+            return self.greedy_sample(logits)
+
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.no_top_k,
+            sampling_metadata.top_k,
+            sampling_metadata.no_top_p,
+            sampling_metadata.top_p,
+        )
+        if sampling_metadata.all_random:
+            return random_sampled
+
+        # Mixed batch: take the greedy result where temperature ~ 0.
+        greedy_sampled = self.greedy_sample(logits)
+        sampled = torch.where(
+            sampling_metadata.temperature < _SAMPLING_EPS,
+            greedy_sampled,
+            random_sampled,
+        )
+        return sampled
+
+    def get_topk_logprobs(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Return (logprobs, token ids) of the top-k tokens per request."""
+        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
+        # FIXME: Mask the sampled token_id, get topk logprobs,
+        # and concatenate the topk with the sampled token_id.
+        topk_logprobs, topk_indices = torch.topk(
+            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+        # Use int32 to reduce the tensor size.
+        topk_indices = topk_indices.to(torch.int32)
+        return topk_logprobs, topk_indices
+
+    def apply_penalties(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        """Apply min-token and presence/frequency/repetition penalties."""
+        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                  sampling_metadata.stop_token_ids,
+                                  sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            logits = apply_all_penalties(
+                logits, sampling_metadata.prompt_token_ids,
+                sampling_metadata.presence_penalties,
+                sampling_metadata.frequency_penalties,
+                sampling_metadata.repetition_penalties,
+                sampling_metadata.output_token_ids)
+        return logits
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1791dfa2b6325f2f41c34cd68fa86152aa9a7c06
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pickle
+
+
+class PickleEncoder:
+    """Encoder/decoder pair that (de)serializes objects via pickle.
+
+    WARNING: `pickle.loads` can execute arbitrary code; only decode data
+    received from trusted sources.
+    """
+
+    def encode(self, obj):
+        # Serialize an arbitrary picklable object to bytes.
+        return pickle.dumps(obj)
+
+    def decode(self, data):
+        # Deserialize bytes produced by `encode` back into an object.
+        return pickle.loads(data)
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a652930fd79c2041e96cfd05393790bea4d0992
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69377f715885e9b6b39d56cf04dd5286c7b8509b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py b/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d382638bffd881c9dbe3ef5ec5a55c6fb17d7d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py
@@ -0,0 +1,453 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import time
+from dataclasses import dataclass
+from dataclasses import field as dataclass_field
+from enum import IntEnum
+from typing import ClassVar, Dict, List, Optional, Set
+
+import msgspec
+from msgspec import field as msgspec_field
+
+from vllm.sampling_params import SamplingParams
+
+
+class RequestStatsUpdate(
+        msgspec.Struct,  # type: ignore
+        array_like=True,
+        omit_defaults=True,
+        gc=False):
+    """
+    An update to the request stats.
+
+    This represents a stats update at a specific timestamp with metadata
+    associated with the update.
+
+    NOTE: since there might be multiple processes generating updates at
+    different parts of the engine (e.g. input processor, scheduler, engine core,
+    etc.), we use the monotonic timestamp to record the update to compute any
+    intervals, and explicit wall-clock timestamp should be used for timestamps.
+
+    WARNING: This assumes stats are generated in a single machine. If there are
+    potentially multiple machines, one should always generate the stats updates
+    on one single machine or use something else.
+    """
+
+    class Type(IntEnum):
+        """See `RequestStats` for the lifecycle of a request."""
+
+        # Request arrived at the engine frontend.
+        ARRIVED = 0
+        # Input processed by the input processor.
+        INPUT_PROCESSED = 1
+        # Queued on the engine core.
+        QUEUED = 2
+        # Scheduled running prefill by the scheduler.
+        # A request could be running a new prefill on the prompt tokens or
+        # a resumed prefill on the original prefill tokens + generated output
+        # tokens before preemption.
+        PREFILLING = 3
+        # Preempted by the scheduler.
+        PREEMPTED = 4
+        # Output token is generated by the engine core.
+        DECODING = 5
+        # Token detokenized by the detokenizer.
+        # We will record the timestamp for each output token, as well as the
+        # finish reason.
+        DETOKENIZED = 6
+        # Request finishes (or aborts).
+        FINISHED = 7
+
+    """
+    Valid state updates:
+        ARRIVED
+        │
+        ├──────► INPUT_PROCESSED ──────► QUEUED ──────► PREFILLING ◄────┐
+        │              │                   │                │           │
+        │              │                   │                ▼           │
+        │              │                   │          ┌──► DECODING     │
+        │              │                   │          │      │          │
+        │              │                   │          │      ▼          │
+        │              │                   │          └─ DETOKENIZED    │
+        │              │                   │                │           │
+        │              │                   │                ▼           │
+        │              ▼                   ▼            PREEMPTED ◄─────┘
+        │              │                   │                │
+        └──────────────┴───────────────────┴────────────────┴
+                                 │
+                                 ▼
+                             FINISHED (All could go to FINISHED)
+    """
+    # Allowed successor update types for each update type; enforced by
+    # `check_valid_update`.
+    _VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = {
+        Type.ARRIVED: {
+            Type.INPUT_PROCESSED,
+            Type.FINISHED,
+        },
+        Type.INPUT_PROCESSED: {
+            Type.QUEUED,
+            Type.FINISHED,
+        },
+        Type.QUEUED: {
+            Type.PREFILLING,
+            Type.FINISHED,
+        },
+        Type.PREFILLING: {
+            Type.DECODING,
+            Type.PREEMPTED,
+            Type.FINISHED,
+        },
+        Type.DECODING: {
+            Type.DETOKENIZED,
+            Type.FINISHED,
+        },
+        Type.DETOKENIZED: {
+            Type.DECODING,
+            Type.PREEMPTED,
+            Type.FINISHED,
+        },
+        Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED},
+        Type.FINISHED: set(),
+    }
+
+    request_id: str
+
+    type: Type
+
+    # Timestamp when the update is recorded. This is used to record time
+    # intervals between events rather than wall clock time.
+    monotonic_ts_s: float = msgspec_field(
+        default_factory=lambda: time.monotonic())
+
+    ############################################################
+    # Metadata associated with the update.
+    ############################################################
+    # For input_processed. Metadata needed for stats logging.
+    num_prompt_tokens: Optional[int] = None
+    sampling_params: Optional[SamplingParams] = None
+
+    # For running.
+    # Number of tokens computed when scheduled to run.
+    num_computed_tokens: Optional[int] = None
+    # Number of cached tokens when scheduled to run.
+    num_cached_tokens: Optional[int] = None
+
+    # For decoded.
+    # The number of new output tokens generated.
+    num_new_tokens: Optional[int] = None
+
+    # For both detokenized and decoded.
+    # Finished reason.
+    finish_reason: Optional[str] = None
+
+    # Non-optional fields for each update type.
+    _REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = {
+        Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
+        Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
+        Type.DETOKENIZED: ["num_new_tokens"],
+        Type.FINISHED: ["finish_reason"],
+    }
+
+    def __post_init__(self):
+        """Validate that the metadata required by `self.type` is present."""
+        required_fields = self._REQUIRED_FIELDS.get(self.type, [])
+        for field in required_fields:
+            if getattr(self, field) is None:
+                raise ValueError(
+                    f"Field {field} is required for update type {self.type}.")
+
+    @staticmethod
+    def check_valid_update(
+        update: "RequestStatsUpdate",
+        last_update_type: Optional[Type],
+        last_updated_ts_s: Optional[float],
+    ):
+        """Assert that `update` is a legal successor of the last update.
+
+        Checks both the state-machine transition (see the diagram above) and
+        that timestamps are monotonically non-decreasing.
+        """
+        if last_update_type is None:
+            assert update.type == RequestStatsUpdate.Type.ARRIVED
+        else:
+            valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[
+                last_update_type]
+            assert update.type in valid_cur_update_types, (
+                f"Invalid update type: {update.type} for last_update_type: "
+                f"{last_update_type}.")
+
+        if last_updated_ts_s is not None:
+            assert update.monotonic_ts_s >= last_updated_ts_s, (
+                "Update timestamp must be monotonically increasing, but "
+                f"last_updated_ts_s={last_updated_ts_s} and "
+                f"update.monotonic_ts_s={update.monotonic_ts_s}.")
+
+
+@dataclass
+class RequestStats:
+    """Stats associated with a request (`Request`).
+
+    Updated incrementally via `update_from` as `RequestStatsUpdate`s arrive;
+    the derived properties compute latencies from the recorded timestamps.
+    """
+
+    ############################################################
+    # Metadata
+    ############################################################
+    request_id: str
+    sampling_params: Optional[SamplingParams] = None
+    num_prompt_tokens: Optional[int] = None
+
+    ############################################################
+    # Metrics and Stats
+    ############################################################
+    # Timestamp when the request was last updated.
+    last_updated_ts_s: Optional[float] = None
+
+    # Last update stats type.
+    last_update_type: Optional[RequestStatsUpdate.Type] = None
+
+    # Timestamp when the request arrived at the llm engine.
+    arrival_ts_s: Optional[float] = None
+
+    # Number of tokens cached. When part of the request prefix is cached,
+    # this will be set.
+    num_cached_tokens: int = 0
+
+    # Number of tokens computed.
+    num_computed_tokens: int = 0
+
+    # The timestamp when the request became waiting in the queue.
+    queued_ts_s: Optional[float] = None
+
+    # When the input processor is completed.
+    input_processor_end_ts_s: Optional[float] = None
+
+    # A sorted list of timestamps when the request was scheduled to prefill.
+    # This could be when:
+    # 1. the request is newly scheduled, so it's a new prefill.
+    # 2. the request was preempted and resumed. It is equivalent to running
+    #    a prefill of the original prefill tokens + generated output tokens
+    #    before preemption.
+    prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list)
+
+    # A list of timestamps when a token is decoded by the engine core.
+    decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list)
+
+    # A sorted list of timestamps for each output token.
+    output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list)
+
+    # First token's timestamp.
+    first_token_ts_s: Optional[float] = None
+
+    # TODO(rickyx): we need model runner to surface these.
+    model_forward_duration_s: float = 0.0
+    # Includes model forward, block/sync across workers, cpu-gpu sync time
+    # and sampling time.
+    model_execute_duration_s: float = 0.0
+
+    # A sorted list of timestamps when the request was preempted at the
+    # scheduler.
+    # TODO(rickyx): right now, we don't actually have a good high-level
+    # metric to measure the impact of preemption other than observation of
+    # large P99 TPOT. Ideally we could quantify the impact of preemption by
+    # measuring the number of tokens re-computed due to preemption.
+    preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list)
+
+    # Timestamp when the request was finished at the engine core.
+    finished_ts_s: Optional[float] = None
+
+    # Finish reason.
+    finish_reason: Optional[str] = None
+
+    ############################################################
+    # Derived properties.
+    ############################################################
+    @property
+    def prefill_ts_s(self) -> Optional[float]:
+        """The timestamp when the request started prefilling.
+        Since a request could be preempted in decoding and later resumed
+        to prefill the decoded tokens, we use the first prefill start timestamp.
+        """
+        return (self.prefill_start_ts_s_lst[0]
+                if self.prefill_start_ts_s_lst else None)
+
+    @property
+    def e2e_latency_s(self) -> Optional[float]:
+        """End-to-end latency: arrival to finish. None until finished."""
+        if self.finished_ts_s is None or self.arrival_ts_s is None:
+            return None
+        assert self.finished_ts_s >= self.arrival_ts_s
+        return self.finished_ts_s - self.arrival_ts_s
+
+    @property
+    def queue_duration_s(self) -> Optional[float]:
+        """How long the request was waiting to run."""
+        if self.queued_ts_s is None or self.prefill_ts_s is None:
+            # Either not queued or not running yet.
+            return None
+        assert self.queued_ts_s <= self.prefill_ts_s
+        return self.prefill_ts_s - self.queued_ts_s
+
+    @property
+    def inference_latency_s(self) -> Optional[float]:
+        """How long the request was running inference
+        (prefill and decode)."""
+        if self.finished_ts_s is None or self.prefill_ts_s is None:
+            return None
+        assert self.finished_ts_s >= self.prefill_ts_s
+        return self.finished_ts_s - self.prefill_ts_s
+
+    @property
+    def first_token_latency_s(self) -> Optional[float]:
+        """Time from arrival to the first output token."""
+        if self.first_token_ts_s is None or self.arrival_ts_s is None:
+            return None
+        assert self.first_token_ts_s >= self.arrival_ts_s
+        return self.first_token_ts_s - self.arrival_ts_s
+
+    @property
+    def prefill_latency_s(self) -> Optional[float]:
+        """Time from the first prefill start to the first output token."""
+        if self.first_token_ts_s is None or self.prefill_ts_s is None:
+            return None
+        assert self.first_token_ts_s >= self.prefill_ts_s
+        return self.first_token_ts_s - self.prefill_ts_s
+
+    @property
+    def decode_latency_s(self) -> Optional[float]:
+        """Time spent decoding: e2e latency minus first-token latency."""
+        if self.e2e_latency_s is None or self.first_token_latency_s is None:
+            return None
+        assert self.e2e_latency_s >= self.first_token_latency_s
+        return self.e2e_latency_s - self.first_token_latency_s
+
+    @property
+    def output_token_latency_s_lst(self) -> List[float]:
+        """Inter-arrival latencies between consecutive output tokens."""
+        if len(self.output_token_ts_s_lst) == 0:
+            return []
+        latency_s_lst = []
+        for i in range(1, len(self.output_token_ts_s_lst)):
+            assert (self.output_token_ts_s_lst[i]
+                    >= self.output_token_ts_s_lst[i - 1])
+            latency_s = (self.output_token_ts_s_lst[i] -
+                         self.output_token_ts_s_lst[i - 1])
+            latency_s_lst.append(latency_s)
+        return latency_s_lst
+
+    @property
+    def num_output_tokens(self) -> int:
+        """Number of output tokens recorded so far."""
+        return len(self.output_token_ts_s_lst)
+
+    @property
+    def is_finished(self) -> bool:
+        """Whether a FINISHED update has been recorded."""
+        return self.finished_ts_s is not None
+
+    def update_from(self, update: "RequestStatsUpdate"):
+        """Fold a single update into this stats object.
+
+        Validates the state transition first, then records the timestamp and
+        the update-type-specific metadata.
+        """
+        RequestStatsUpdate.check_valid_update(update, self.last_update_type,
+                                              self.last_updated_ts_s)
+        ts = update.monotonic_ts_s
+        self.last_updated_ts_s = ts
+        self.last_update_type = update.type
+        if update.type == RequestStatsUpdate.Type.ARRIVED:
+            self.arrival_ts_s = ts
+        elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED:
+            self.input_processor_end_ts_s = ts
+            self.sampling_params = update.sampling_params
+            self.num_prompt_tokens = update.num_prompt_tokens
+        elif update.type == RequestStatsUpdate.Type.QUEUED:
+            self.queued_ts_s = ts
+        elif update.type == RequestStatsUpdate.Type.PREFILLING:
+            self.prefill_start_ts_s_lst.append(ts)
+            self.num_cached_tokens = update.num_cached_tokens or 0
+            self.num_computed_tokens = update.num_computed_tokens or 0
+        elif update.type == RequestStatsUpdate.Type.PREEMPTED:
+            self._reset_for_preemption(ts)
+        elif update.type == RequestStatsUpdate.Type.DECODING:
+            self.decoding_ts_s_lst.append(ts)
+        elif update.type == RequestStatsUpdate.Type.DETOKENIZED:
+            self._record_detokenized_output(
+                ts,
+                update.num_new_tokens or 0,
+            )
+        elif update.type == RequestStatsUpdate.Type.FINISHED:
+            self.finished_ts_s = ts
+            self.finish_reason = update.finish_reason
+        else:
+            raise ValueError(f"Unknown update type: {update.type}")
+
+    def _record_detokenized_output(
+        self,
+        ts_s: float,
+        num_new_tokens: int,
+    ):
+        """Record `num_new_tokens` detokenized tokens at timestamp `ts_s`."""
+        # Update if first output token is generated.
+        if len(self.output_token_ts_s_lst) == 0:
+            self.first_token_ts_s = ts_s
+            assert (
+                self.prefill_ts_s is not None
+            ), "Request must be running before generating output tokens."
+
+        # Some X new tokens were generated at the ts.
+        self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens)
+
+    def _reset_for_preemption(self, ts_s: float):
+        """Reset progress counters when the request is preempted."""
+        self.preempted_ts_s_lst.append(ts_s)
+        # Reset the computed tokens since it might restart the prefill.
+        self.num_computed_tokens = 0
+        # Cached token count might also change when resumed.
+        self.num_cached_tokens = 0
+        # These stats don't change since they happen before request running.
+        # - arrival_ts_s
+        # - input_processor_end_ts_s
+        # - sampling_params
+        # - num_prompt_tokens
+        # - first_token_ts_s
+        #
+        # These stats are accumulated over preemptions:
+        # - output_token_ts_s_lst
+        # - prefill_start_ts_s_lst (after preemption, it will prefill the
+        #   original prefill tokens and any output tokens generated before
+        #   preemption.)
+
+
+@dataclass
+class KVCacheStats:
+    """GPU KV cache utilization stats."""
+
+    # KV Cache Usage in %
+    gpu_cache_usage_sys: float = 0.0
+    # GPU prefix cache hit rate.
+    gpu_prefix_cache_hit_rate: float = 0.0
+
+
+@dataclass
+class SchedulerStats:
+    """Stats associated with the scheduler."""
+
+    # Number of requests currently running.
+    num_running_reqs: int = 0
+    # Number of requests currently waiting.
+    num_waiting_reqs: int = 0
+
+    # Snapshot of the KV cache utilization.
+    kv_cache_stats: KVCacheStats = dataclass_field(
+        default_factory=KVCacheStats)
+
+
+@dataclass
+class EngineCoreProcessStats:
+    """Stats associated with the engine core process."""
+
+    # Number of requests currently in the input queue. None if the engine core
+    # is not running in multiprocess mode.
+    input_queue_size: Optional[int] = None
+    # Number of outputs currently in the output queue. None if the engine core
+    # is not running in multiprocess mode.
+    output_queue_size: Optional[int] = None
+
+
+class EngineCoreStatsSnapshot(
+        msgspec.Struct,  # type: ignore
+        array_like=True,
+        omit_defaults=True,
+        gc=False):
+    """
+    A snapshot of the EngineCore's current stats over a period of time.
+    """
+
+    # Snapshot of the scheduler stats.
+    scheduler_stats: SchedulerStats = msgspec_field(
+        default_factory=SchedulerStats)
+
+    # Per request stats updates.
+    requests_stats_updates: List[RequestStatsUpdate] = msgspec_field(
+        default_factory=list)
+
+    # Engine core's queue stats.
+    engine_core_process_stats: EngineCoreProcessStats = msgspec_field(
+        default_factory=EngineCoreProcessStats)
+
+    # TODO(rickyx): Add other components' stats,
+    # e.g. model runner/worker and etc.
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5494542c181d7843db9cbdf9051a1ad55229ae9f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/utils.py
@@ -0,0 +1,190 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import multiprocessing
+import os
+import weakref
+from collections import defaultdict
+from collections.abc import Sequence
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
+ Optional, TypeVar, Union, overload)
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.models.utils import extract_layer_index
+from vllm.utils import get_mp_context, kill_process_tree
+
+if TYPE_CHECKING:
+ from vllm.attention.layer import Attention
+
+logger = init_logger(__name__)
+
+T = TypeVar("T")
+
+
+class ConstantList(Generic[T], Sequence):
+
+ def __init__(self, x: List[T]) -> None:
+ self._x = x
+
+ def append(self, item):
+ raise Exception("Cannot append to a constant list")
+
+ def extend(self, item):
+ raise Exception("Cannot extend a constant list")
+
+ def insert(self, item):
+ raise Exception("Cannot insert into a constant list")
+
+ def pop(self, item):
+ raise Exception("Cannot pop from a constant list")
+
+ def remove(self, item):
+ raise Exception("Cannot remove from a constant list")
+
+ def clear(self):
+ raise Exception("Cannot clear a constant list")
+
+ def index(self,
+ item: T,
+ start: int = 0,
+ stop: Optional[int] = None) -> int:
+ return self._x.index(item, start,
+ stop if stop is not None else len(self._x))
+
+ @overload
+ def __getitem__(self, item: int) -> T:
+ ...
+
+ @overload
+ def __getitem__(self, s: slice, /) -> List[T]:
+ ...
+
+ def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
+ return self._x[item]
+
+ @overload
+ def __setitem__(self, item: int, value: T):
+ ...
+
+ @overload
+ def __setitem__(self, s: slice, value: T, /):
+ ...
+
+ def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
+ raise Exception("Cannot set item in a constant list")
+
+ def __delitem__(self, item):
+ raise Exception("Cannot delete item from a constant list")
+
+ def __iter__(self):
+ return iter(self._x)
+
+ def __contains__(self, item):
+ return item in self._x
+
+ def __len__(self):
+ return len(self._x)
+
+
+class BackgroundProcHandle:
+    """
+    Utility class to handle creation, readiness, and shutdown
+    of background processes used by the AsyncLLM and LLMEngine.
+    """
+
+    def __init__(
+        self,
+        input_path: str,
+        output_path: str,
+        process_name: str,
+        target_fn: Callable,
+        process_kwargs: Dict[Any, Any],
+    ):
+        context = get_mp_context()
+        # One-way pipe used only for the child's readiness handshake.
+        reader, writer = context.Pipe(duplex=False)
+
+        # These keys are owned by this class; the caller must not set them.
+        assert ("ready_pipe" not in process_kwargs
+                and "input_path" not in process_kwargs
+                and "output_path" not in process_kwargs)
+        process_kwargs["ready_pipe"] = writer
+        process_kwargs["input_path"] = input_path
+        process_kwargs["output_path"] = output_path
+
+        # Run busy loop in background process.
+        self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
+        # The finalizer guarantees cleanup (terminate + ipc socket removal)
+        # even if `shutdown` is never called explicitly.
+        self._finalizer = weakref.finalize(self, shutdown, self.proc,
+                                           input_path, output_path)
+        self.proc.start()
+
+        # Wait for startup.
+        if reader.recv()["status"] != "READY":
+            raise RuntimeError(f"{process_name} initialization failed. "
+                               "See root cause above.")
+
+    def shutdown(self):
+        # Idempotent: weakref.finalize invokes its callback at most once.
+        self._finalizer()
+
+
+# Note(rob): shutdown function cannot be a bound method,
+# else the gc cannot collect the object.
+def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
+    """Terminate a background process and remove its zmq ipc socket files.
+
+    Attempts a graceful terminate() with a 5-second join first, then kills
+    the whole process tree if the process is still alive.
+    """
+    # Shutdown the process.
+    if proc.is_alive():
+        proc.terminate()
+        proc.join(5)
+
+        if proc.is_alive():
+            kill_process_tree(proc.pid)
+
+    # Remove zmq ipc socket files.
+    ipc_sockets = [output_path, input_path]
+    for ipc_socket in ipc_sockets:
+        socket_file = ipc_socket.replace("ipc://", "")
+        # `os` may already be None during interpreter shutdown.
+        if os and os.path.exists(socket_file):
+            os.remove(socket_file)
+
+
+def bind_kv_cache(
+    kv_caches: Dict[str, torch.Tensor],
+    forward_context: Dict[str, "Attention"],
+    runner_kv_caches: List[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+
+    Raises:
+        NotImplementedError: if two layer names map to the same layer index
+            (e.g. cross- and self-attention in an encoder-decoder model).
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f520ee9586c5c909a303d9470c8f16627480e68b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import numpy as np
+import torch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
class BlockTable:
    """GPU/CPU-mirrored table of KV-cache block IDs, one row per request.

    The table is mutated on the CPU side (through a zero-copy numpy view
    of an optionally-pinned CPU tensor) and transferred to the device in
    a single `commit` call, avoiding many small host-to-device copies.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        pin_memory: bool,
        device: torch.device,
    ):
        """
        Args:
            max_num_reqs: Maximum number of concurrent requests (rows).
            max_model_len: Maximum model sequence length. Stored for
                reference; not used directly by this class.
            max_num_blocks_per_req: Maximum blocks per request (columns).
            pin_memory: Whether to allocate the CPU staging tensor in
                pinned memory for faster async host-to-device copies.
            device: Target device for the device-side table.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.pin_memory = pin_memory
        self.device = device

        # Device-side copy; only refreshed by commit().
        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        # CPU staging buffer; block_table_np is a zero-copy numpy view
        # of this tensor, so writes through either alias are shared.
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        # Number of valid (filled) block IDs in each row.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(
        self,
        row_idx: int,
        start: int,
        block_ids: List[int],
    ) -> None:
        """Write `block_ids` into row `row_idx` starting at column `start`.

        Sets the row's valid length to `start + len(block_ids)`; callers
        pass the row's current length as `start` to append. No-op for an
        empty `block_ids`.
        """
        if not block_ids:
            return
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks

    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
        """Initialize row `row_idx` with `block_ids`, writing from column 0."""
        self.append_row(row_idx, 0, block_ids)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the valid prefix of row `src` into row `tgt` (src is kept)."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def commit(self, num_reqs: int) -> None:
        """Asynchronously copy the first `num_reqs` CPU rows to the device."""
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def clear(self) -> None:
        """Zero both the device and CPU tables (row lengths are untouched)."""
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..39708f833fd58340a160eef20150c977ce17506f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py
@@ -0,0 +1,440 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Datastructures defining an input batch
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
+import numpy as np
+import torch
+
+from vllm.multimodal import MultiModalKwargs
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+ from vllm.multimodal.inputs import PlaceholderRange
+
+
@dataclass
class CachedRequestState:
    """Per-request state cached by the runner across scheduling steps."""

    # Request identity and raw inputs.
    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    # Multimodal inputs and their placeholder positions in the prompt.
    mm_inputs: List[MultiModalKwargs]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    # Per-request RNG; None when the request has no dedicated generator.
    generator: Optional[torch.Generator]

    # KV-cache block IDs assigned to this request, and decode progress.
    block_ids: List[int]
    num_computed_tokens: int
    output_token_ids: List[int]

    # NOTE(review): presumably M-RoPE (multimodal rotary position) state;
    # left unset for requests that do not use it.
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    @property
    def num_tokens(self) -> int:
        """Total token count: prompt tokens plus generated output tokens."""
        prompt_len = len(self.prompt_token_ids)
        output_len = len(self.output_token_ids)
        return prompt_len + output_len
+
+
class InputBatch:
    """Densely-packed batch of per-request state for the model runner.

    Rows 0..num_reqs-1 hold active requests. Token IDs and sampling
    parameters are staged in CPU buffers (numpy views over optionally
    pinned CPU tensors) and copied to the device when building
    SamplingMetadata. Per-parameter request-ID sets track which rows
    actually need each sampling feature so device copies can be skipped.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """
        Args:
            max_num_reqs: Maximum number of concurrent requests (rows).
            max_model_len: Maximum sequence length per request.
            max_num_blocks_per_req: Maximum KV-cache blocks per request.
            device: Target device for the device-side tensors.
            pin_memory: Whether CPU staging tensors use pinned memory.
            vocab_size: Vocabulary size; also used as the padding value
                for the prompt-token-ids tensor (see
                _make_prompt_token_ids_tensor).
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Row index -> request id (None for empty rows), and the inverse
        # mapping for O(1) lookup/removal.
        self.req_ids: List[Optional[str]] = [None] * max_num_reqs
        self.req_id_to_index: Dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        # Zero-copy numpy view of token_ids_cpu_tensor.
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Per-row counters: total tokens, prompt tokens, computed tokens.
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        # Sampling-related.
        # Each parameter has a device tensor, a CPU staging tensor, a
        # numpy view of the CPU tensor, and a set of request ids that use
        # a non-default value of the parameter.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: Set[str] = set()
        self.random_reqs: Set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: Set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: Set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: Set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = \
            self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: Set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: Set[str] = set()

        # Per-row min-token constraints and stop-token sets (plain Python
        # containers; these are small and never copied to the device).
        self.min_tokens: List[int] = [0] * max_num_reqs
        self.stop_token_ids: List[Set[int]] = [
            set() for _ in range(max_num_reqs)
        ]
        # Lazily-built device tensor of padded prompt token ids; only
        # needed when penalties are applied during sampling.
        self.prompt_token_ids: Optional[torch.Tensor] = None

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: Dict[int, torch.Generator] = {}

        # req_id -> number of requested logprobs.
        self.num_logprobs: Dict[str, int] = {}
        self.prompt_logprob_reqs: Set[str] = set()

    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Insert `request` at row `req_index` (default: first free tail
        slot) and register it in all sampling bookkeeping structures."""
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        self.req_ids[req_index] = req_id
        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        self.num_tokens[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(req_index, request.block_ids)

        # Stage sampling parameters in the CPU buffers and classify the
        # request into the per-parameter request-id sets.
        sampling_params = request.sampling_params
        self.temperature_cpu[req_index] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            self.greedy_reqs.add(req_id)
        else:
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        self.top_k_cpu[req_index] = sampling_params.top_k
        if sampling_params.top_k > 0:
            self.top_k_reqs.add(req_id)
        self.frequency_penalties_cpu[req_index] = \
            sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[req_index] = \
            sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[req_index] = \
            sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        self.min_tokens[req_index] = sampling_params.min_tokens
        self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        num_logprobs = sampling_params.logprobs
        if num_logprobs is not None and num_logprobs > 0:
            self.num_logprobs[req_id] = num_logprobs
        if sampling_params.prompt_logprobs:
            self.prompt_logprob_reqs.add(req_id)

    def remove_request(self, req_id: str) -> Optional[int]:
        """Unregister `req_id` and free its row.

        Returns the freed row index, or None if the id was not present.
        NOTE(review): `min_tokens`/`stop_token_ids` for the row are left
        stale here — presumably safe because add_request overwrites them
        on slot reuse; verify against callers.
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.req_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.prompt_logprob_reqs.discard(req_id)
        return req_index

    def clear(self) -> None:
        """Remove all requests and reset the bookkeeping structures.

        The staged numeric buffers (token ids, penalties, etc.) are not
        zeroed; they are overwritten on subsequent add_request calls.
        """
        self.req_ids = [None] * self.max_num_reqs
        self.req_id_to_index.clear()
        self.greedy_reqs.clear()
        self.random_reqs.clear()
        self.top_p_reqs.clear()
        self.top_k_reqs.clear()
        self.frequency_penalties_reqs.clear()
        self.presence_penalties_reqs.clear()
        self.repetition_penalties_reqs.clear()
        self.generators.clear()
        self.num_logprobs.clear()
        self.prompt_logprob_reqs.clear()

    def condense(self, empty_req_indices: List[int]) -> None:
        """Compact the batch by moving tail requests into empty rows so
        that rows 0..num_reqs-1 are contiguous again after removals."""
        if self.num_reqs == 0:
            # The batched states are empty.
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = self.num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                # All remaining empty slots are at/after the tail; done.
                break

            # Swap the states.
            req_id = self.req_ids[last_req_index]
            assert req_id is not None
            self.req_ids[empty_index] = req_id
            self.req_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Only the valid prefix of the token buffer needs to move.
            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_prompt_tokens[empty_index] = \
                self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = \
                self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[empty_index] = \
                self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[empty_index] = \
                self.repetition_penalties_cpu[last_req_index]
            self.min_tokens[empty_index] = self.min_tokens[last_req_index]
            self.stop_token_ids[empty_index] = \
                self.stop_token_ids[last_req_index]
            # Generators are keyed by row index, so re-key on move.
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

    def make_sampling_metadata(
        self,
        req_id_output_token_ids: Dict[str, List[int]],
        skip_copy: bool = False,
    ) -> SamplingMetadata:
        """Build SamplingMetadata for the current batch.

        Args:
            req_id_output_token_ids: Mapping from request id to that
                request's output token ids so far.
            skip_copy: When True, skip the CPU->device copies (caller
                guarantees the device tensors are already up to date).
        """
        if not skip_copy:
            self.temperature[:self.num_reqs].copy_(
                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_p[:self.num_reqs].copy_(
                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_k[:self.num_reqs].copy_(
                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
            if not self.no_penalties:
                # Since syncing these tensors is expensive only copy them
                # if necessary i.e. if there are requests which require
                # penalties to be applied during sampling.
                self.frequency_penalties[:self.num_reqs].copy_(
                    self.frequency_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                self.presence_penalties[:self.num_reqs].copy_(
                    self.presence_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                self.repetition_penalties[:self.num_reqs].copy_(
                    self.repetition_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                # The prompt tokens are used only for applying penalties during
                # the sampling process. Hence copy these tensors only when
                # there are requests which need penalties to be applied.
                self.prompt_token_ids = self._make_prompt_token_ids_tensor()

        output_token_ids: List[List[int]] = []

        for req_id in self.req_ids[:self.num_reqs]:
            assert req_id is not None
            # Currently we create a tensor for output_token_ids from scratch
            # at each step. However, for the penalties computation what we
            # need is stats about the token ids present in the output. This
            # stats can be maintained incrementally instead of computing it
            # from scratch at each step.
            # TODO - Replace this with incremental update to output token
            # statistics.
            output_token_ids.append(req_id_output_token_ids[req_id])

        return SamplingMetadata(
            temperature=self.temperature[:self.num_reqs],
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=self.top_p[:self.num_reqs],
            top_k=self.top_k[:self.num_reqs],
            no_top_p=self.no_top_p,
            no_top_k=self.no_top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=self.prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:self.num_reqs],
            presence_penalties=self.presence_penalties[:self.num_reqs],
            repetition_penalties=self.repetition_penalties[:self.num_reqs],
            output_token_ids=output_token_ids,
            min_tokens=self.min_tokens[:self.num_reqs],
            stop_token_ids=self.stop_token_ids[:self.num_reqs],
            no_penalties=self.no_penalties,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a device tensor of prompt token ids, right-padded with
        `vocab_size` to the longest prompt in the batch."""
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory)
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = (
            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

    @property
    def num_reqs(self) -> int:
        """Number of active requests in the batch."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when no request uses random sampling."""
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        """True when no request uses greedy sampling."""
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        """True when no request applies top-p filtering."""
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        """True when no request applies top-k filtering."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when no request needs any sampling penalty."""
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> int:
        """Largest logprob count requested by any request (0 if none)."""
        return max(self.num_logprobs.values()) if self.num_logprobs else 0

    @property
    def no_logprob(self) -> bool:
        """True when no request asked for sample logprobs."""
        return len(self.num_logprobs) == 0

    @property
    def no_prompt_logprob(self) -> bool:
        """True when no request asked for prompt logprobs."""
        return len(self.prompt_logprob_reqs) == 0