diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80354d69b50afe64a38df6532af78e10cc6858fa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager +from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser + +__all__ = [ + "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser" +] diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a288d2682124cf8137c644ea5d91b8bbc217d9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e1fee820f81efd899377c804f6ca490c44f7c84 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/abs_reasoning_parsers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed067b2e1de1e3c8853a6eb0680be94ec84945a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/__pycache__/deepseek_r1_reasoning_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..b5df7e47446b7acc74cd77c4f905f87a95581716 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from functools import cached_property +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ReasoningParser: + """ + Abstract reasoning parser class that should not be used directly. + Provided and methods should be used in derived classes. + + It is used to extract reasoning content from the model output. 
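+
+    A typical (illustrative) call sequence, assuming a parser registered
+    under a name such as "deepseek_r1":
+
+        parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
+        parser = parser_cls(tokenizer)
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)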
+ """ + + def __init__(self, tokenizer: AnyTokenizer): + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> Dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> Tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from a complete model-generated string. + + Used for non-streaming responses where we have the entire model response + available before sending to the client. + + Parameters: + model_output: str + The model-generated string to extract reasoning content from. + + request: ChatCompletionRequest + The request object that was used to generate the model_output. + + Returns: + Tuple[Optional[str], Optional[str]] + A tuple containing the reasoning content and the content. + """ + + raise NotImplementedError( + "AbstractReasoningParser.extract_reasoning_calls " + "has not been implemented!") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting reasoning + from an incomplete response; for use when handling reasoning calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractReasoningParser.extract_reasoning_content_streaming " + "has not been implemented!") + + +class ReasoningParserManager: + reasoning_parsers: Dict[str, Type] = {} + + @classmethod + def get_reasoning_parser(cls, name) -> Type: + """ + Get reasoning parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.reasoning_parsers: + return cls.reasoning_parsers[name] + + raise KeyError(f"reasoning helper: '{name}' not found in " + "reasoning_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ReasoningParser): + raise TypeError("module must be subclass of ReasoningParser, " + f"but got {type(module)}") + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.reasoning_parsers: + existed_module = cls.reasoning_parsers[name] + raise KeyError(f"{name} is already registered " + f"at {existed_module.__module__}") + cls.reasoning_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). 
+ """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + "name must be None, an instance of str, or a sequence of str, " + f"but got {type(name)}") + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_reasoning_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined reasoning parser by the path + of the reasoning parser define file. + """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + import_from_path(module_name, plugin_path) + except Exception: + logger.exception("Failed to load module '%s' from %s.", + module_name, plugin_path) + return diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5c19888d4540137fb7d07150720b0ad5f5e849b9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Optional, Sequence, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( + ReasoningParser, ReasoningParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("deepseek_r1") +class DeepSeekR1ReasoningParser(ReasoningParser): + """ + Reasoning parser for DeepSeek R1 model. + + The DeepSeek R1 model uses ... tokens to denote reasoning + text. This parser extracts the reasoning content from the model output. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_start_token = "" + self.think_end_token = "" + + self.reasoning_regex = re.compile( + rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + self.think_start_token_id = self.vocab.get(self.think_start_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + if (self.think_start_token_id is None + or self.think_end_token_id is None): + raise RuntimeError( + "DeepSeek R1 reasoning parser could not locate think start/end " + "tokens in the tokenizer!") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. 
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.think_start_token_id, self.think_end_token_id
+        ]):
+            return None
+
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # content continues after the reasoning block
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.think_start_token_id in delta_token_ids:
+            logger.info(delta_text)
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta, extract reasoning content
+                start_index = delta_text.find(self.think_start_token)
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.think_start_token
+                                                   ):end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # No <think> in previous or delta; treat the text as regular content.
+            return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> Tuple[Optional[str], Optional[str]]:
+
+        # Check if the model output contains the <think> and </think> tokens.
+        if (self.think_start_token not in model_output
+                or self.think_end_token not in model_output):
+            return None, model_output
+        else:
+            # Use a regex to find the reasoning content
+            reasoning_content = self.reasoning_regex.findall(model_output)[0]
+
+            # Remove the reasoning content from the model output.
+            # Although DeepSeek's <think> token is always at the
+            # beginning of the line, we cannot guarantee that
+            # other models will follow this convention.
+            # Therefore, we need to add the :start_index slice.
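+            # For example (illustrative): for the output
+            # "preamble<think>abc</think>xyz", this method returns
+            # ("abc", "preamblexyz").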
+ start_index = model_output.find(self.think_start_token) + if start_index != -1: + end_index = start_index + len( + f"{self.think_start_token}{reasoning_content}{self.think_end_token}" + ) + model_output = model_output[:start_index] + \ + model_output[end_index:] + + if len(model_output) == 0: + return reasoning_content, None + + return reasoning_content, model_output diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..002bf173883086f80bedcd61477ce9a0501e28fc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from json import JSONDecoder +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite-20b-fc") +class Granite20bFCToolParser(ToolParser): + """ + Tool call parser for the granite-20b-functioncalling model intended + for use with the examples/tool_chat_template_granite20b_fc.jinja + template. 
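+
+    An illustrative expected model output (the leading tag is assumed to
+    be the parser's bot_token; names and arguments are made up):
+
+        <function_call> {"name": "get_weather", "arguments": {"city": "Boston"}}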
+ + Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.bot_token = "" + self.tool_start_token = self.bot_token + self.tool_call_regex = re.compile(r"\s*") + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + if self.tool_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + dec = JSONDecoder() + try: + matches = list(self.tool_call_regex.finditer(model_output)) + logger.debug("Found %d tool call matches", len(matches)) + + raw_function_calls = [] + + for i, match in enumerate(matches): + # position after the tag + start_of_json = match.end() + # end_index == the start of the next function call + # (if exists) + next_function_call_start = (matches[i + 1].start() if i + + 1 < len(matches) else None) + + raw_function_calls.append( + dec.raw_decode( + model_output[start_of_json:next_function_call_start]) + [0]) + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + content = model_output[:model_output.find(self.bot_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if len(current_text) < len( + self.bot_token) and self.bot_token.startswith(current_text): + return None + + if not current_text.startswith(self.bot_token): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + start_idx = len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + start_idx = consume_space(start_idx, current_text) + start_idx += len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c948ed78f503b9bb9f760463846e5d459f11c21b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite") +class GraniteToolParser(ToolParser): + """ + Tool call parser for the granite 3.0 models. Intended + for use with the examples/tool_chat_template_granite.jinja + template. 
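+
+    An illustrative expected model output (granite 3.0 prefixes the JSON
+    list with the <|tool_call|> token; names and arguments are made up):
+
+        <|tool_call|>[{"name": "get_weather", "arguments": {"city": "Boston"}}]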
+ + Used when --enable-auto-tool-choice --tool-call-parser granite + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + # for granite 3.0, the token `<|tool_call|>` + self.bot_token = "<|tool_call|>" + # for granite 3.1, the string `` + self.bot_string = "" + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + stripped = model_output.strip()\ + .removeprefix(self.bot_token)\ + .removeprefix(self.bot_string)\ + .lstrip() + if not stripped or stripped[0] != '[': + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + raw_function_calls = json.loads(stripped) + if not isinstance(raw_function_calls, list): + raise Exception( + f"Expected dict or list, got {type(raw_function_calls)}") + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + start_idx = consume_space(0, current_text) + if current_text[start_idx:].startswith(self.bot_token): + start_idx = consume_space(start_idx + len(self.bot_token), + current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) + if not current_text or start_idx >= len(current_text)\ + or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = None + is_complete = None + try: + tool_calls, end_idx = partial_json_loads( + current_text[start_idx:], flags) + if type(tool_calls) is list: + tool_call_arr = tool_calls + else: + return DeltaMessage(content=delta_text) + + is_complete = [True] * len(tool_calls) + if not is_complete_json( + current_text[start_idx:start_idx + end_idx]): + is_complete[-1] = False + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if not tool_call_arr: + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + delta = None + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4841b28703ee3beff672150f465577d57a17251b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("hermes") +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + 
for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. 
exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if (self.prev_tool_call_arr is None + or len(self.prev_tool_call_arr) == 0): + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = diff.encode('utf-8').decode( + 'unicode_escape') if diff is str else diff + if ('"}' not in delta_text): + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
+ if not self.current_tool_name_sent: + if (current_tool_call is None): + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + if (delta_text not in cur_arguments_json[:-2]): + return None + args_delta_start_loc = cur_arguments_json[:-2]. \ + rindex(delta_text) + \ + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. 
+ elif cur_arguments and prev_arguments: + if isinstance(delta_text, str) and len(delta_text.rstrip( + )) >= 1 and delta_text.rstrip()[-1] == '}': + delta_text = delta_text.rstrip()[:-1] + + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_text).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += delta_text + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b9215e7979bf534303ada53e5e7b8c9b54a89c08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_argments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call is sended, return a empty delta message + # to make sure the finish_reason will be send correctly. 
+ if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in inernlm2 + # it's not support parallel tool calls + try: + tool_call_arr: Dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_argments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_argments(tool_call_arr) + + # not arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters, send the increase parameters + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. 
if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = self.get_argments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5c282b5c2605a6cc8a9c85a377cf7f31c8aab967 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import json +import re +from typing import Any, Sequence, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("pythonic") +class PythonicToolParser(ToolParser): + """ + Tool call parser for models that produce tool calls in a pythonic style, + such as Llama 3.2 models. + + Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. 
+ @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("["): + return DeltaMessage(content=delta_text) + + try: + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. 
+ if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. 
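+        # e.g. (illustrative) '[foo(bar=' or '[foo(x={"k":' end this way
+        # and cannot be completed yet.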
+ return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id="", index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..915fc6623398e2aa2ff67723aa3770d35b4aa1db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fcac7ee56f284efe9c5f7b35c4c9780554e66b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49ef3b15f0064b8f6131b947e45659006426b81e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_cpu.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d37cbc3d89ee519acca8e7283a229c0d670745b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_hpu.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35e257a075e948d1eb0a311ed593481a99a86c7f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_selector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec82ad99ee725e3babf55d6bc4152029e2d406e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2282ae9accd5ca26691f64d4e4d7ca6dce0d49 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_base.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from .utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. 
+ """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indicies tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. 
+ """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> None: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + offset = offset_start + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + offset_start (int): The starting position of y, defaults to 0 + add_inputs (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..29428f4cfff3175e782618cbd16c727d56771798 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_cpu.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperCPU(PunicaWrapperBase): + """ + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. 
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. 
+ """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py new file mode 100644 index 0000000000000000000000000000000000000000..51e1bfab3f5136ab17732b2578d864fe3e0043d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_hpu.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple, Union, final + +import torch +from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, + dispatch_bgmv_linear) + +from .punica_base import PunicaWrapperBase + + +@final +class PunicaWrapperHPU(PunicaWrapperBase): + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + # Increasing max_num_batched_tokens by 3x to handle increase in + # tensor size due to padding. 
+ PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, + max_batches, device) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + for slice_idx in range(len(output_slices)): + dispatch_bgmv_linear( + y[:, offset_left:offset_left + output_slices[slice_idx]], x, + lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale) + y = y.view_as(y_org) + + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + raise NotImplementedError + + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5d4b788ec435a970f35d1625e127e3d3812bcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname + +from .punica_base import PunicaWrapperBase + +logger = init_logger(__name__) + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + punica_wrapper_qualname = current_platform.get_punica_wrapper() + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) + punica_wrapper = punica_wrapper_cls(*args, **kwargs) + assert punica_wrapper is not None, \ + "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." 
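# Editor's note: get_punica_wrapper above resolves the platform-specific
# wrapper class from a dotted "module.ClassName" string. A minimal stand-in
# for resolve_obj_by_qualname (assumed behaviour, not vLLM's implementation):
import importlib

def resolve_by_qualname(qualname: str):
    module_name, attr = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)

# For example, resolving a stdlib class the same way:
assert resolve_by_qualname("collections.OrderedDict").__name__ == "OrderedDict"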
+ logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] + + ".") + return punica_wrapper diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc2d27c597f20c8a5aa79e87b96b8f76e9979bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/utils.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. 
+ long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. 
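# Editor's note: what compute_meta (defined earlier in this file) extracts for
# the sgmv kernel, shown in isolation -- consecutive tokens that share a LoRA
# index are folded into one "sequence". Values are made up.
import torch

token_lora = torch.tensor([0, 0, 0, 2, 2, -1, -1, -1, -1])
lora_ids, seq_lens = torch.unique_consecutive(token_lora, return_counts=True)
starts = torch.zeros_like(seq_lens)
starts[1:] = torch.cumsum(seq_lens, dim=0)[:-1]
print(lora_ids)   # tensor([ 0,  2, -1])
print(seq_lens)   # tensor([3, 2, 4])
print(starts)     # tensor([0, 3, 5])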
+ indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f7fd323d020b42e7204c9401bbc4efad189d96d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b8e0fe2e27d36a6fa957f4c2756541c54b02a8a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/kv_cache_interface.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8bf496dbe933abe40bb4ba3f9e820b133028d9a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/outputs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..158d653ed7fc8784d1e267d508f718f288dc7874 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/request.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cf3b59c57980942abb3548df723f10555e3d4c0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/serial_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74ebbd8250d645b616fa48a83b9a63761d2c2750 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83e8f94d4e26c0adda6039a49dd2ff818447c231 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23a658bc6705719367fddd5f70625a617675f258 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34e6b16b6700b685dbfaf5569f47f27a6385ee97 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/__pycache__/flash_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..837d7faf43708dbc2ece2eaa60c7283293f5f7c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/attention/backends/flash_attn.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import torch +import triton +import triton.language as tl + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cdiv +from vllm.vllm_flash_attn import (fa_version_unsupported_reason, + flash_attn_varlen_func, + is_fa_version_supported) + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + + +@dataclass +class FlashAttentionMetadata: + # NOTE(sang): 
Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + +class FlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + self.fa_version = 3 if is_fa_version_supported(3) else 2 + else: + self.fa_version = 2 + + if VLLM_FLASH_ATTN_VERSION is not None: + assert VLLM_FLASH_ATTN_VERSION in [2, 3] + self.fa_version = VLLM_FLASH_ATTN_VERSION + + if not is_fa_version_supported(self.fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + self.fa_version, + fa_version_unsupported_reason(self.fa_version)) + + assert is_fa_version_supported(self.fa_version) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. 
+ + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + key_cache, value_cache = kv_cache.unbind(0) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # Compute attention and update output up to `num_actual_tokens`. + if not attn_metadata.use_cascade: + # Regular attention (common case). + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + seqused_k=attn_metadata.seq_lens, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + fa_version=self.fa_version, + ) + return output + + # Cascade attention (rare case). + cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + fa_version=self.fa_version, + ) + return output + + +def use_cascade_attention( + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + num_sms: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. 
+ # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. + num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window + and not use_alibi and np.all(query_lens == 1)) + if not use_flash_decoding: + # Use cascade attention. + return True + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. + # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. + q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = (num_reqs * num_kv_heads * + cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. + return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + fa_version: int, +) -> torch.Tensor: + assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window.") + + num_tokens = query.shape[0] + block_size = key_cache.shape[-3] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + + # Process shared prefix. 
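# Editor's aside on the heuristic in use_cascade_attention above: a quick
# numeric instance of its rough CTA/wave cost model (every number below is
# invented purely for illustration).
from math import ceil

num_query_heads, num_kv_heads, num_reqs = 32, 8, 64
common_prefix_len, num_sms = 4096, 108
q_tile_size, kv_tile_size = 128, 128
num_queries_per_kv = num_query_heads // num_kv_heads        # 4
num_prefix_tiles = ceil(common_prefix_len / kv_tile_size)   # 32

cascade_ctas = num_query_heads * ceil(num_reqs / q_tile_size)      # 32
cascade_time = ceil(cascade_ctas / num_sms) * num_prefix_tiles     # 1 * 32 = 32

fd_ctas = num_reqs * num_kv_heads * ceil(num_queries_per_kv / q_tile_size)
fd_time = ceil(fd_ctas * num_prefix_tiles / num_sms)               # ceil(16384/108) = 152

print(cascade_time < fd_time)   # True -> the model favours cascade attention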
+ prefix_output, prefix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + seqused_k=prefix_kv_lens, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + return_softmax_lse=True, + fa_version=fa_version, + ) + + # Process suffix per query. + suffix_output, suffix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + seqused_k=suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_common_kv_blocks:], + softcap=logits_soft_cap, + return_softmax_lse=True, + fa_version=fa_version, + ) + + # Merge prefix and suffix outputs, and store the result in output. + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, + suffix_lse) + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
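# Editor's aside at this step: the scales computed below are the standard
# log-sum-exp merge of two partial attention results, formed after subtracting
# max_lse so no exponent overflows. The same math in plain PyTorch, with toy
# single-head values invented for illustration:
import torch

p_lse_ex, s_lse_ex = torch.tensor(100.0), torch.tensor(102.0)  # exp(100) overflows float32
p_out_ex, s_out_ex = torch.randn(8), torch.randn(8)
max_lse_ex = torch.maximum(p_lse_ex, s_lse_ex)
p_w = torch.exp(p_lse_ex - max_lse_ex)
s_w = torch.exp(s_lse_ex - max_lse_ex)
merged = (p_out_ex * p_w + s_out_ex * s_w) / (p_w + s_w)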
+ p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) + s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae72f748fce6398f894b583a84bce5aebe4b532a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c75b9c22950cfac9e9ec452c27717728366017 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/encoder_cache_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e49d1dfd6d5b0f63c89be37bfc352f6994178ff Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e4da6e587877383733293274bf661ebde4dfb82 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/kv_cache_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a12e04dd4c53298069c8cb84b724ee46b584620 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/core/__pycache__/scheduler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..651bc01aa5cf665c46bb0def62694499ae79d793 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/encoder_cache_manager.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Dict, List, Set, Tuple + +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.v1.request import Request + +if TYPE_CHECKING: + from vllm.config import ModelConfig, SchedulerConfig + +logger = init_logger(__name__) + + +class EncoderCacheManager: + + def __init__(self, cache_size: int): + self.cache_size = cache_size + self.num_free_slots = cache_size + # 
req_id -> cached input ids + self.cached: Dict[str, Set[int]] = {} + # List of [req_id, input_id] + self.freed: List[Tuple[str, int]] = [] + + def has_cache(self, request: Request, input_id: int) -> bool: + req_id = request.request_id + return req_id in self.cached and input_id in self.cached[req_id] + + def can_allocate(self, request: Request, input_id: int) -> bool: + num_tokens = request.get_num_encoder_tokens(input_id) + return num_tokens <= self.num_free_slots + + def allocate(self, request: Request, input_id: int) -> None: + req_id = request.request_id + if req_id not in self.cached: + self.cached[req_id] = set() + self.cached[req_id].add(input_id) + self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + def get_cached_input_ids(self, request: Request) -> Set[int]: + return self.cached.get(request.request_id, set()) + + def free_encoder_input(self, request: Request, input_id: int) -> None: + """Free a single encoder input id for the request.""" + req_id = request.request_id + if req_id not in self.cached: + return + + self.cached[req_id].discard(input_id) + if len(self.cached[req_id]) == 0: + del self.cached[req_id] + self.num_free_slots += request.get_num_encoder_tokens(input_id) + self.freed.append((req_id, input_id)) + + def free(self, request: Request) -> None: + """Free all cached input ids for the request.""" + input_ids = self.get_cached_input_ids(request) + for input_id in input_ids: + self.free_encoder_input(request, input_id) + + def get_freed_ids(self) -> List[Tuple[str, int]]: + freed = self.freed + self.freed = [] + return freed + + +def compute_encoder_budget( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", +) -> Tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations. + + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + + Returns: + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. + """ + + if not model_config.is_multimodal_model: + return 0, 0 + + # TODO: handle encoder-decoder models once we support them. + ( + encoder_compute_budget, + encoder_cache_size, + ) = _compute_encoder_budget_multimodal(model_config, scheduler_config) + + return encoder_compute_budget, encoder_cache_size + + +def _compute_encoder_budget_multimodal( + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", +) -> Tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a multimodal model. + + Args: + model_config: Model configuration. + scheduler_config: Scheduler configuration. + + Returns: + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. + """ + + max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 + model_config) + + if not max_tokens_by_modality_dict: + logger.warning( + "All non-text modalities supported by the model have been " + "explicitly disabled via limit_mm_per_prompt. 
Encoder cache will " + "not be initialized.") + return 0, 0 + + _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), + key=lambda item: item[1]) + + encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, + max_tokens_per_mm_item) + encoder_cache_size = max(scheduler_config.encoder_cache_size, + max_tokens_per_mm_item) + + return encoder_compute_budget, encoder_cache_size diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..de349ec12099931b81729d45274454e6b6f73c27 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_manager.py @@ -0,0 +1,500 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple + +from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens, + hash_request_tokens) +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class KVCacheManager: + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + max_model_len: int, + sliding_window: Optional[int] = None, + enable_caching: bool = True, + num_preallocate_tokens: int = 64, + ) -> None: + self.block_size = block_size + self.num_gpu_blocks = num_gpu_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = cdiv(max_model_len, block_size) + self.sliding_window = sliding_window + self.enable_caching = enable_caching + # NOTE(woosuk): To avoid frequent block allocation, we preallocate some + # blocks for each request. For example, when a request reaches the end + # of its block table, we preallocate N blocks in advance. This way, we + # reduce the overhead of updating free_block_ids and ref_cnts for each + # request every step (at the cost of some memory waste). + # NOTE(woosuk): This is different from the "lookahead" slots since this + # does not guarantee that the request always has N empty blocks. After + # the request gets N empty blocks, it starts to use the blocks without + # further allocation. When it uses up all the N empty blocks, it gets + # N new empty blocks. + self.num_preallocate_tokens = num_preallocate_tokens + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) + + # A Block pool of all kv-cache blocks. + self.block_pool: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. 
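As a rough illustration of the mapping declared just below (plain Python stand-ins, not the actual vLLM objects): several blocks may carry the same hash because full blocks are not de-duplicated, and a cache lookup simply returns the first block ID registered under that hash.

from collections import defaultdict

cached_block_hash_to_block = defaultdict(dict)            # {block_hash: {block_id: block}}
cached_block_hash_to_block["hash-A"][7] = "block-7"
cached_block_hash_to_block["hash-A"][12] = "block-12"     # same contents, different ID
first_id = next(iter(cached_block_hash_to_block["hash-A"]))   # -> 7, the first registered block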
+ self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: DefaultDict[str, + List[KVCacheBlock]] = defaultdict(list) + + @property + def usage(self) -> float: + return 1.0 - (self.free_block_queue.num_free_blocks / + self.num_gpu_blocks) + + def get_computed_blocks( + self, request: Request) -> Tuple[List[KVCacheBlock], int]: + """Get the computed (cached) blocks for the request. + Note that the computed blocks must be full. + + Args: + request: The request to get the computed blocks. + + Returns: + A tuple containing: + - A list of blocks that are computed for the request. + - The number of computed tokens. + """ + if not self.enable_caching: + # Prefix caching is disabled. + return [], 0 + + computed_blocks = [] + + # The block hashes for the request may already be computed + # if the request was preempted and resumed. + if not request.kv_block_hashes: + request.set_kv_block_hashes( + hash_request_tokens(self.block_size, request)) + block_hashes = request.kv_block_hashes + + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self._get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break + + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size + return computed_blocks, num_computed_tokens + + def allocate_slots( + self, + request: Request, + num_tokens: int, + new_computed_blocks: Optional[List[KVCacheBlock]] = None + ) -> Optional[List[KVCacheBlock]]: + """Add slots for a request with new tokens to append. + + Args: + request: The request to allocate slots. + num_tokens: The number of tokens to allocate. Note that this does + not include the tokens that have already been computed. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. + + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. + + Returns: + A list of new allocated blocks. + """ + if num_tokens == 0: + raise ValueError("num_tokens must be greater than 0") + + new_computed_blocks = new_computed_blocks or [] + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, + self.block_size) + req_blocks = self.req_to_blocks[request.request_id] + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. 
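A worked example of the block arithmetic in allocate_slots and of the free-block check that follows (all numbers hypothetical):

# block_size = 16; the request has 0 computed tokens, hits 3 cached blocks,
# and wants to append 20 new tokens.
block_size = 16
num_computed_tokens = 0 + 3 * block_size                       # 48
num_required_blocks = -(-(num_computed_tokens + 20) // block_size)   # cdiv(68, 16) = 5
num_new_blocks = num_required_blocks - 0 - 3                   # 2 brand-new blocks needed

# Feasibility: cached blocks that are eviction candidates (ref_cnt == 0) cannot
# double-count as free space.
num_free_blocks = 3
num_evictable_computed_blocks = 2
can_allocate = num_new_blocks <= num_free_blocks - num_evictable_computed_blocks   # 2 <= 1 -> False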
+ num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks): + # Cannot allocate new blocks + return None + + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self._touch(new_computed_blocks) + else: + assert not new_computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.free_block_queue.num_free_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self._get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size + num_computed_full_blocks = num_computed_tokens // self.block_size + new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=(req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None)) + + return new_blocks + + def free(self, request: Request) -> None: + """Free the blocks allocated for the request. + When caching is enabled, we free the blocks in reverse order so that + the tail blocks are evicted first. + + Args: + request: The request to free the blocks. + """ + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) + ordered_blocks: Iterable[KVCacheBlock] = blocks + if self.enable_caching: + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(blocks) + + for block in ordered_blocks: + block.decr_ref() + if block.ref_cnt == 0: + self.free_block_queue.append(block) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. 
+ """ + num_used_blocks = (self.num_gpu_blocks - + self.free_block_queue.num_free_blocks) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = defaultdict(dict) + + # Remove all hashes from all blocks. + for block in self.block_pool: + block.reset_hash() + + logger.info("Successfully reset prefix cache") + return True + + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> int: + """Calculate the number of common prefix blocks shared by all requests + in the RUNNING state. + + The function determines this by selecting any request and iterating + through its blocks. A block is considered a common prefix block if its + `ref_cnt` equals the total number of requests in the RUNNING state. + + NOTE(woosuk): The number of requests in the RUNNING state is **greater + than or equal to** the number of requests scheduled in the current step. + This is because the RUNNING state only indicates that: + 1. The request has not yet finished, and + 2. The request holds its blocks unfreed. + + While all scheduled requests must be in the RUNNING state, the inverse + is not necessarily true. There may be RUNNING requests that are not + scheduled in the current step. As of 1/1/2025, the scheduler does not + allow this case, but it is possible in the future, as we allow more + flexible scheduling. + + This can result in an edge case where the number of common prefix blocks + is 0, even though all scheduled requests share a common prefix. This + occurs because there may be unscheduled RUNNING requests that do not + share the common prefix. Currently, this case cannot be easily detected, + so the function returns 0 in such cases. + + Args: + request: Any request in the RUNNING state, used to identify the + common prefix blocks. + num_running_requests: The total number of requests in the RUNNING + state. This can be different from the number of scheduled + requests in the current step. + + Returns: + int: The number of common prefix blocks. + """ + assert request.status == RequestStatus.RUNNING + blocks = self.req_to_blocks[request.request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks + + def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + + Note that we do not check block cache in this function. + + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self.free_block_queue.num_free_blocks: + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: List[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self.free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. 
+ """ + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] + + return True + return False + + def _get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. + """ + if block_hash in self.cached_block_hash_to_block: + first_block_id = list( + self.cached_block_hash_to_block[block_hash].keys())[0] + return self.cached_block_hash_to_block[block_hash][first_block_id] + return None + + def _touch(self, blocks: List[KVCacheBlock]) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0: + self.free_block_queue.remove(block) + block.incr_ref() + + def _cache_full_blocks( + self, + request: Request, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + ) -> None: + """Cache a list of full blocks for prefix caching. + + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. + """ + num_cached_block_hashes = len(request.kv_block_hashes) + + # Update the new blocks with the block hashes through the chain. + prev_block_hash_value = None + if prev_block is not None: + # Previous block must have a block hash because it must be + # a full, cached block. + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + # Find the first uncached block. This case should only happen when + # speculative decoding is used. + offset = 0 + for blk in full_blocks: + if blk.block_hash is None: + break + else: + prev_block_hash_value = blk.block_hash.hash_value + offset += 1 + else: + # All blocks are cached. + return + + for i, blk in enumerate(full_blocks[offset:]): + blk_idx = blk_start_idx + offset + i + assert blk.block_hash is None + + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = request.kv_block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. 
+ start_token_idx = blk_idx * self.block_size + end_token_idx = (blk_idx + 1) * self.block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, extra_keys) + request.append_kv_block_hashes(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0976ba8577b9ac4e02037c79d7e0914a6fc9675 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/kv_cache_utils.py @@ -0,0 +1,447 @@ +# SPDX-License-Identifier: Apache-2.0 +"""KV-Cache Utilities.""" +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, List, NamedTuple, Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, + KVCacheTensor) +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockHashType(NamedTuple): + """Hash value of a block (int), the token IDs in the block, and extra keys. + We keep a tuple of token IDs and extra keys to reduce the likelihood of + hash collisions when the hash value is the same. But please note that + hash collisions can still theoretically occur, albeit with an extremely + low probability. + """ + # Hash value of the block in an integer. + hash_value: int + # Token IDs in the block. + token_ids: Tuple[int, ...] + # Extra keys for the block. + extra_keys: Optional[Any] = None + + +@dataclass +class KVCacheBlock: + """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. + block_id: int + # Reference count. + ref_cnt: int = 0 + # The hash of the block composed of (block hash, tuple of token IDs). + # It is only available when the block is full. + _block_hash: Optional[BlockHashType] = None + + # Used to construct a doubly linked list for free blocks. + # These two attributes should only be manipulated by FreeKVCacheBlockQueue. + prev_free_block: Optional["KVCacheBlock"] = None + next_free_block: Optional["KVCacheBlock"] = None + + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 + + @property + def block_hash(self) -> Optional[BlockHashType]: + return self._block_hash + + @block_hash.setter + def block_hash(self, block_hash: BlockHashType): + assert self.block_hash is None, ( + "The block already has a hash. 
This should not happen.") + self._block_hash = block_hash + + def reset_hash(self): + """Reset the block hash when the block is evicted.""" + self._block_hash = None + + +class FreeKVCacheBlockQueue: + """This class organizes a list of KVCacheBlock objects to a doubly linked + list of free blocks. We implement this class instead of using Python + builtin deque to support removing a block in the middle of the queue + in O(1) time. To close the performance gap to the builtin deque which is + implemented in C++, this class does not allocate any Python objects when + manipulating the linked list. Instead, this class manipulates the + prev_free_block and next_free_block attributes of the given blocks. + + The queue is ordered by block ID in the beginning. When a block is allocated + and then freed, it will be appended back with the eviction order: + 1. The least recent used block is at the front (LRU). + 2. If two blocks have the same last accessed time (allocated by the + same sequence), the one with more hash tokens (the tail of a block + chain) is at the front. + Note that we maintain this order by reversing the block order when free + blocks of a request. This operation is outside of this class. + + Args: + blocks: A list of KVCacheBlock objects. + """ + + def __init__(self, blocks: List[KVCacheBlock]) -> None: + self.num_free_blocks = len(blocks) + + # Initialize the doubly linked list of free blocks. + self.free_list_head: Optional[KVCacheBlock] = blocks[0] + self.free_list_tail: Optional[KVCacheBlock] = blocks[-1] + for i in range(self.num_free_blocks): + if i > 0: + blocks[i].prev_free_block = blocks[i - 1] + if i < self.num_free_blocks - 1: + blocks[i].next_free_block = blocks[i + 1] + + def popleft(self) -> KVCacheBlock: + """Pop the first free block and reduce num_free_blocks by 1. + + Returns: + The first free block. + """ + if not self.free_list_head: + raise ValueError("No free blocks available") + + block = self.free_list_head + self.remove(block) + return block + + def remove(self, block: KVCacheBlock) -> None: + """Remove a block in the free list and reduce num_free_blocks by 1. + + Args: + block: The block to remove. + """ + if block.prev_free_block is not None: + # Link the previous block to the next block. + block.prev_free_block.next_free_block = block.next_free_block + if block.next_free_block is not None: + # Link the next block to the previous block. + block.next_free_block.prev_free_block = block.prev_free_block + + if block == self.free_list_head: + # Update the head if the block is the head. + self.free_list_head = block.next_free_block + if block == self.free_list_tail: + # Update the tail if the block is the tail. + self.free_list_tail = block.prev_free_block + + # Remove the block from the linked list. + block.prev_free_block = block.next_free_block = None + self.num_free_blocks -= 1 + + def append(self, block: KVCacheBlock) -> None: + """Put a block back into the free list and increase + num_free_blocks by 1. + + Args: + block: The block to append. + """ + if self.free_list_tail is not None: + # Link the last block to the new block. + self.free_list_tail.next_free_block = block + block.prev_free_block = self.free_list_tail + self.free_list_tail = block + else: + # The free list is empty. + assert self.free_list_head is None + self.free_list_head = self.free_list_tail = block + + block.next_free_block = None + self.num_free_blocks += 1 + + def get_all_free_blocks(self) -> List[KVCacheBlock]: + """Get all free blocks in the free list. Mainly used for testing. 
+ + Returns: + A list of free blocks. + """ + ret = [] + curr_block = self.free_list_head + while curr_block is not None: + ret.append(curr_block) + curr_block = curr_block.next_free_block + return ret + + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + For multi-modal inputs, the extra keys are (mm_hash, start_offset) that + indicate a mm input contained in the block and its starting offset in + the block tokens. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if not mm_positions: + return None, start_mm_idx + + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match. This " + "is likely because you do not enable MM preprocessor hashing. " + "Please set disable_mm_preprocessor_cache=False.") + + # Note that we assume mm_positions is sorted by offset. + # We do not need to check all mm inputs if the start token index is out of + # range. This usually happens in the late prefill phase and decoding phase. + if mm_positions[-1]["offset"] + mm_positions[-1][ + "length"] < start_token_idx: + return None, start_mm_idx + + # Support start_mm_idx == -1 to indicate the last mm input. + if start_mm_idx < 0: + assert -start_mm_idx <= len(mm_positions) + start_mm_idx = len(mm_positions) + start_mm_idx + + extra_keys = [] + curr_mm_idx = start_mm_idx + while mm_positions and curr_mm_idx < len(mm_positions): + assert mm_hashes[curr_mm_idx] is not None + offset = mm_positions[curr_mm_idx]["offset"] + length = mm_positions[curr_mm_idx]["length"] + if end_token_idx > offset: + if start_token_idx > offset + length: + # This block has passed the current mm input. + curr_mm_idx += 1 + continue + + # The block contains the current mm input. + extra_keys.append(mm_hashes[curr_mm_idx]) + + if end_token_idx >= offset + length: + # If this block contains the end of the current mm input, + # move to the next mm input as this block may also contain + # the next mm input. + curr_mm_idx += 1 + else: + # Otherwise this block is done with mm inputs. + break + else: + # This block has not reached the current mm input. + break + return tuple(extra_keys), curr_mm_idx + + +def hash_block_tokens( + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. We use LRU cache for this function to avoid recomputing + hash values for the same block contents. + + TODO: Support arbitrary metadata so that we could support more + features such as LoRA adapter. + + Args: + parent_block_hash: The hash of the parent block. None + if this is the first block. + curr_block_token_ids: A list of token ids in the current + block. The current block is assumed to be full. + extra_keys: Extra keys for the block. 
+ + Returns: + The hash value of the block and the token ids in the block. + The entire tuple is used as the hash key of the block. + """ + if not parent_block_hash: + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + parent_block_hash = hash('None') + + curr_block_token_ids_tuple = tuple(curr_block_token_ids) + return BlockHashType( + hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)), + curr_block_token_ids_tuple, extra_keys) + + +def hash_request_tokens(block_size: int, + request: Request) -> List[BlockHashType]: + """Computes hash values of a chain of blocks given a sequence of + token IDs. The hash value is used for prefix caching. + + Args: + block_size: The size of each block. + request: The request object. + + Returns: + The list of computed hash values. + """ + token_ids = request.all_token_ids + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match.") + + # TODO: Extend this to support other features such as LoRA. + need_extra_keys = bool(mm_positions) + extra_keys = None + curr_mm_idx = 0 + + ret = [] + parent_block_hash_value = None + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + # Add extra keys if the block is a multi-modal block. + if need_extra_keys: + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start, end, curr_mm_idx) + + block_hash = hash_block_tokens(parent_block_hash_value, + block_token_ids, extra_keys) + ret.append(block_hash) + parent_block_hash_value = block_hash.hash_value + return ret + + +def check_enough_kv_cache_memory(vllm_config: VllmConfig, + kv_cache_spec: KVCacheSpec, + available_memory: int): + """ + Checks whether `available_memory` is enough for the KV cache to hold at + least one request with the model's max_model_len. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory: Memory available for KV cache in bytes. + + Raises: + ValueError: If there is not enough memory available for the KV cache. + """ + + if available_memory <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_model_len = vllm_config.model_config.max_model_len + needed_memory = 0 + for layer_spec in kv_cache_spec.values(): + needed_memory += layer_spec.bytes_for_tokens(max_model_len) + + if needed_memory > available_memory: + raise ValueError( + f"To serve at least one request with the models's max seq len " + f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV " + f"cache is needed, which is larger than the available KV cache " + f"memory ({available_memory/1024/1024/1024:.2f} GB). 
Try " + f"increasing `gpu_memory_utilization` or decreasing " + f"`max_model_len` when initializing the engine.") + + +def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same type of KV cache. + + Args: + kv_cache_spec: The KVCacheSpec of the model + + Returns: + True if all layers have the same type, False otherwise. + """ + + layer_keys = set(layer.type_id for layer in kv_cache_spec.values()) + return len(layer_keys) == 1 + + +def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, + kv_cache_spec: KVCacheSpec, + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one type of KV cache. + Divide the available memory equally among all layers. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + assert len(page_sizes) == 1 + page_size = page_sizes.pop() + + num_blocks = int(available_memory // page_size // len(kv_cache_spec)) + num_blocks = max(num_blocks, 0) + + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + num_blocks = num_gpu_blocks_override + + logger.info("# GPU blocks: %d", num_blocks) + max_concurrency = (num_blocks * vllm_config.cache_config.block_size / + vllm_config.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + vllm_config.model_config.max_model_len, max_concurrency) + + per_layer_size = page_size * num_blocks + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + tensors={ + layer_name: KVCacheTensor(size=per_layer_size) + for layer_name in kv_cache_spec + }, + groups=[[layer_name for layer_name in kv_cache_spec]], + kv_cache_spec=kv_cache_spec) + return kv_cache_config + + +def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model + TODO: support hybrid models with more than one type of KV cache. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for most models. + # Allocate the same amount of memory for each layer. 
+ return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + else: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py b/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5e83fe062747ee85abf83243938c3763c1a2eb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/core/scheduler.py @@ -0,0 +1,631 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) + +from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_budget) +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.base import PlaceholderRange + +logger = init_logger(__name__) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + model_config: ModelConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.lora_config = lora_config + # TODO: Support LoRA. + assert lora_config is None, "V1 does not support LoRA yet." + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + + num_gpu_blocks = cache_config.num_gpu_blocks + assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 + # Create the KV cache manager. + self.kv_cache_manager = KVCacheManager( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + max_model_len=self.max_model_len, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) + self.block_size = self.cache_config.block_size + + # req_id -> Request + self.requests: Dict[str, Request] = {} + # Priority queues for requests. + self.waiting: Deque[Request] = deque() + self.running: List[Request] = [] + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: Set[str] = set() + + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating + # them at each scheduling step. + # Request id -> CachedRequestData + self._cached_reqs_data: Dict[str, CachedRequestData] = {} + + # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. 
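To make the encoder budget concrete (illustrative numbers only, not vLLM defaults): for a text-only model compute_encoder_budget returns (0, 0), while for a multimodal model both values are floored at the largest single multimodal item so at least one item always fits, as in the multimodal path defined earlier in this diff.

# Hypothetical multimodal model: the largest image expands to 2448 encoder tokens.
max_tokens_per_mm_item = 2448
max_num_encoder_input_tokens = 8192         # scheduler compute budget (example value)
configured_encoder_cache_size = 2048        # scheduler cache budget (example value)

encoder_compute_budget = max(max_num_encoder_input_tokens, max_tokens_per_mm_item)   # 8192
encoder_cache_size = max(configured_encoder_cache_size, max_tokens_per_mm_item)      # 2448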
+ encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + ) + + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed). Currently, we assume that the encoder also + # has the Transformer architecture (e.g., ViT). + self.max_num_encoder_input_tokens = encoder_compute_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + + def schedule(self) -> "SchedulerOutput": + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and num_tokens, + # which is equal to len(prompt_token_ids) + len(output_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens. This is general enough to cover chunked prefills, + # prefix caching, and the "jump decoding" optimization in the future. + + scheduled_new_reqs: List[Request] = [] + scheduled_resumed_reqs: List[Request] = [] + scheduled_running_reqs: List[Request] = [] + preempted_reqs: List[Request] = [] + + req_to_new_block_ids: Dict[str, List[int]] = {} + num_scheduled_tokens: Dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: Dict[str, List[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + num_new_tokens = request.num_tokens - request.num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( + self._try_schedule_encoder_inputs(request, + request.num_computed_tokens, + num_new_tokens, + encoder_budget)) + if num_new_tokens == 0: + # The request cannot be scheduled because the encoder budget + # or the encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Encoder-related. 
+ if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting[0] + # Get already-cached tokens. + computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks(request) + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed requests, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if num_new_tokens == 0: + # This happens when prompt length is divisible by the block + # size and all blocks are cached. Now we force to recompute + # the last block. Note that we have to re-compute an entire + # block because allocate_slots() assumes num_computed_tokens + # is always a multiple of the block size. This limitation + # can potentially be removed in the future to slightly + # improve the performance. + num_computed_tokens -= self.block_size + num_new_tokens = self.block_size + computed_blocks.pop() + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break + + self.waiting.popleft() + self.running.append(request) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. 
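The computation just below counts, front to back, the blocks of one running request whose ref_cnt equals the number of running requests, stopping at the first block that is not shared by everyone. A small illustration (hypothetical ref_cnt values):

# ref_cnt of the first running request's blocks, front to back:
ref_cnts = [3, 3, 3, 2, 1]
num_running_requests = 3

num_common_prefix_blocks = 0
for ref_cnt in ref_cnts:
    if ref_cnt != num_running_requests:
        break
    num_common_prefix_blocks += 1
# -> 3: only the first three blocks are shared by all running requests.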
+ num_common_prefix_blocks = 0 + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + ) + + self.finished_req_ids = set() + return scheduler_output + + def _make_cached_request_data( + self, + request: Request, + new_block_ids: List[int], + num_computed_tokens: int, + resumed_from_preemption: bool, + ) -> "CachedRequestData": + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating + # them at each scheduling step. + if request.request_id in self._cached_reqs_data: + req_data = self._cached_reqs_data[request.request_id] + req_data.resumed_from_preemption = resumed_from_preemption + req_data.new_block_ids = new_block_ids + req_data.num_computed_tokens = num_computed_tokens + else: + req_data = CachedRequestData.from_request(request, + resumed_from_preemption, + new_block_ids, + num_computed_tokens) + self._cached_reqs_data[request.request_id] = req_data + return req_data + + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_budget: int, + ) -> Tuple[List[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. 
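As a quick illustration of the overlap test described above (hypothetical numbers): with num_computed_tokens=96 and num_new_tokens=64, this step covers positions [96, 160); an encoder input placed at offset=120 with length=100 overlaps that range and is not yet in the KV cache, so it must either be scheduled now or the decoder chunk must stop at position 120.

num_computed_tokens, num_new_tokens = 96, 64        # decoder computes [96, 160)
start_pos, num_encoder_tokens = 120, 100            # encoder input covers [120, 220)

needed_now = (start_pos < num_computed_tokens + num_new_tokens
              and start_pos + num_encoder_tokens > num_computed_tokens)   # True
# If it cannot fit in the encoder budget/cache, schedule only the decoder
# tokens before it:
num_new_tokens = start_pos - num_computed_tokens     # 24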
+ """ + if not request.has_encoder_inputs(): + return [], num_new_tokens, encoder_budget + + encoder_inputs_to_schedule: List[int] = [] + mm_positions = request.mm_positions + assert mm_positions is not None + assert len(mm_positions) > 0 + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. + continue + + if self.encoder_cache_manager.has_cache(request, i): + # The encoder input is already computed and cached. + continue + if (not self.encoder_cache_manager.can_allocate(request, i) + or num_encoder_tokens > encoder_budget): + # The encoder cache is full or the encoder budget is exhausted. + # NOTE(woosuk): We assume that the encoder input tokens should + # be processed altogether, as the encoder usually uses + # bidirectional attention. + if num_computed_tokens < start_pos: + # We only schedule the decoder tokens just before the + # encoder input. + num_new_tokens = start_pos - num_computed_tokens + else: + # Because of prefix caching, num_computed_tokens is greater + # than start_pos even though its encoder input is not + # available. In this case, we can't schedule any token for + # the request in this step. + num_new_tokens = 0 + break + + encoder_budget -= num_encoder_tokens + encoder_inputs_to_schedule.append(i) + return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> EngineCoreOutputs: + # NOTE(woosuk): This method doesn't consider speculative decoding. + sampled_token_ids = model_runner_output.sampled_token_ids + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + new_running: List[Request] = [] + outputs: List[EngineCoreOutput] = [] + + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. + for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + request.num_computed_tokens += num_tokens_scheduled + # When the request's num_computed_tokens catches up its num_tokens, + # the request generates output tokens. Otherwise, we ignore the + # sampler output for the request. + assert request.num_computed_tokens <= request.num_tokens + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + start_pos = request.mm_positions[input_id]["offset"] + num_tokens = request.mm_positions[input_id]["length"] + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. 
+ self.encoder_cache_manager.free_encoder_input( + request, input_id) + + if request.num_computed_tokens == request.num_tokens: + req_index = model_runner_output.req_id_to_index[req_id] + # NOTE(woosuk): Currently, we assume that each request + # generates at most one token at each step. + token_id = sampled_token_ids[req_index] + request.append_output_token_ids(token_id) + num_new_tokens = 1 + # TODO: Update the KV cache manager for prefix caching. + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = self._check_stop(request) + if stopped: + self._free_request(request) + + # Add EngineCoreOutput for this Request. + output = EngineCoreOutput( + request_id=req_id, + new_token_ids=request.output_token_ids[-num_new_tokens:], + finished=request.is_finished(), + finish_reason=request.get_finished_reason(), + stop_reason=request.stop_reason) + outputs.append(output) + + # Breakout of the loop. + if stopped: + continue + + new_running.append(request) + self.running = new_running + return EngineCoreOutputs( + outputs=outputs, + scheduler_stats=self.make_stats(), + ) + + def _check_stop(self, request: Request) -> bool: + if (request.num_tokens >= self.max_model_len + or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + return True + + sampling_params = request.sampling_params + last_token_id = request.output_token_ids[-1] + if (not sampling_params.ignore_eos + and last_token_id == request.eos_token_id): + request.status = RequestStatus.FINISHED_STOPPED + return True + + if last_token_id in (sampling_params.stop_token_ids or ()): + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = last_token_id + return True + return False + + def add_request(self, request: Request) -> None: + self.waiting.append(request) + self.requests[request.request_id] = request + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + request_ids = set(request_ids) + + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. 
+ continue + + if request.status == RequestStatus.RUNNING: + self.running.remove(request) + else: + self.waiting.remove(request) + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> None: + assert request.is_finished() + self.kv_cache_manager.free(request) + self.encoder_cache_manager.free(request) + self._cached_reqs_data.pop(request.request_id, None) + del self.requests[request.request_id] + self.finished_req_ids.add(request.request_id) + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_unfinished_requests(self) -> bool: + return self.get_num_unfinished_requests() > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() + + def make_stats(self) -> SchedulerStats: + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + gpu_cache_usage=self.kv_cache_manager.usage, + ) + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + block_ids: List[int], + num_computed_tokens: int, + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class CachedRequestData: + + req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. 
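A concrete reading of the comment above, as a tiny illustrative helper (not part of vLLM): the worker either appends the new block IDs or replaces its whole list, depending on the flag.

def apply_block_ids(existing_block_ids, new_block_ids, resumed_from_preemption):
    # Replace on resume-from-preemption, append otherwise.
    if resumed_from_preemption:
        return list(new_block_ids)
    return existing_block_ids + list(new_block_ids)

assert apply_block_ids([0, 1, 2], [3], resumed_from_preemption=False) == [0, 1, 2, 3]
assert apply_block_ids([0, 1, 2], [5, 6], resumed_from_preemption=True) == [5, 6]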
+ resumed_from_preemption: bool + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + resumed_from_preemption: bool, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "CachedRequestData": + return cls( + req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + scheduled_new_reqs: List[NewRequestData] + scheduled_cached_reqs: List[CachedRequestData] + + num_scheduled_tokens: Dict[str, int] + total_num_scheduled_tokens: int + scheduled_encoder_inputs: Dict[str, List[int]] + num_common_prefix_blocks: int + + finished_req_ids: Set[str] + free_encoder_input_ids: List[Tuple[str, int]] diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48c0457236bee710090a550d1e9ee8b35edfaf00 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4052f3fbbd50f8599924f75cb20169790f43b35d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/abstract.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f65c16b53bb69b4c9acf5fd75fc0fac79e96eff Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/executor/__pycache__/multiproc_executor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..ac10d43eb0d54d634af49964234cdcdc00c3c8d9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Type + +from vllm.config import VllmConfig +from vllm.executor.executor_base import ExecutorBase +from vllm.executor.ray_distributed_executor import ( # noqa + RayDistributedExecutor as RayDistributedExecutorV0) +from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) +from vllm.executor.uniproc_executor import ( # noqa + UniProcExecutor as UniProcExecutorV0) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput + + +class Executor(ExecutorBase): + """ + Abstract class for v1 executors, mainly define some methods for v1. 
+ For methods shared by v0 and v1, define them in ExecutorBase""" + + @staticmethod + def get_class(vllm_config: VllmConfig) -> Type["Executor"]: + executor_class: Type[Executor] + parallel_config = vllm_config.parallel_config + distributed_executor_backend = ( + parallel_config.distributed_executor_backend) + if distributed_executor_backend is None: + # If the user does not specify the distributed executor backend, + # we will choose the backend based on the world size. + if parallel_config.world_size > 1: + distributed_executor_backend = "mp" + else: + distributed_executor_backend = "uni" + + if distributed_executor_backend == "ray": + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor + elif distributed_executor_backend == "uni": + executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # TODO: make v1 scheduling deterministic + # to support external launcher + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("Unknown distributed executor backend: " + f"{distributed_executor_backend}") + return executor_class + + def initialize(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the KV caches and begin the model execution loop of the + underlying workers. + """ + self.collective_rpc("initialize_cache", args=(kv_cache_config, )) + self.collective_rpc("compile_or_warm_up_model") + + def determine_available_memory(self) -> int: # in bytes + output = self.collective_rpc("determine_available_memory") + # Since we use a shared centralized controller, we take the minimum + # memory size across all workers to make sure all the memory + # operators can be applied to all workers. 
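+        # For example, with per-worker results of [8.0, 7.5, 7.9] GiB the
+        # executor sizes the KV cache for 7.5 GiB so that no worker
+        # over-allocates.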
+ return min(output) + + def get_kv_cache_spec(self) -> KVCacheSpec: + output = self.collective_rpc("get_kv_cache_spec") + for x in output: + assert x == output[0] + return output[0] + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + output = self.collective_rpc("execute_model", + args=(scheduler_output, )) + return output[0] + + def profile(self, is_start: bool = True): + self.collective_rpc("profile", args=(is_start, )) + + +class UniProcExecutor(UniProcExecutorV0, Executor): + pass + + +class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): + pass + + +class RayDistributedExecutor(RayDistributedExecutorV0, Executor): + pass diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py b/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f07172d8cd9bc280740c1d92caa2e6ca8f0607 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/executor/multiproc_executor.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import pickle +import signal +import sys +import time +import weakref +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial +from multiprocessing.process import BaseProcess +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cloudpickle +import psutil +import zmq + +from vllm.config import VllmConfig +from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) +from vllm.executor.multiproc_worker_utils import ( + _add_prefix, set_multiprocessing_worker_envs) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_mp_context, + get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) +from vllm.v1.executor.abstract import Executor +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 + + +class MultiprocExecutor(Executor): + + def _init_executor(self) -> None: + # Call self.shutdown at exit to clean up + # and ensure workers will be terminated. + self._finalizer = weakref.finalize(self, self.shutdown) + + # The child processes will send SIGUSR1 when unrecoverable + # errors happen. + def sigusr1_handler(signum, frame): + logger.fatal( + "MulitprocExecutor got fatal signal from worker processes, " + "shutting down. See stack trace above for root cause issue.") + # Propagate error up to parent process. + parent_process = psutil.Process().parent() + parent_process.send_signal(signal.SIGUSR1) + self.shutdown() + + signal.signal(signal.SIGUSR1, sigusr1_handler) + + self.world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + assert self.world_size == tensor_parallel_size, ( + f"world_size ({self.world_size}) must be equal to the " + f"tensor_parallel_size ({tensor_parallel_size}). " + f"Pipeline parallelism is not yet implemented in v1") + + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. 
+ distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + # Initialize worker and set up message queues for SchedulerOutputs + # and ModelRunnerOutputs + self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + + # Create workers + self.workers: List[WorkerProcHandle] = [] + for rank in range(self.world_size): + worker = WorkerProc.make_worker_process(self.vllm_config, rank, + rank, + distributed_init_method, + scheduler_output_handle) + self.workers.append(worker) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + start_time = time.monotonic() + kwargs = kwargs or {} + + # NOTE: If the args are heterogeneous, then we pack them into a list, + # and unpack them in the method of every worker, because every worker + # knows their own rank. + try: + if isinstance(method, str): + send_method = method + else: + send_method = cloudpickle.dumps( + method, protocol=pickle.HIGHEST_PROTOCOL) + self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) + + responses = [None] * self.world_size + for w in self.workers: + dequeue_timeout = timeout - (time.monotonic() - start_time + ) if timeout is not None else None + status, result = w.worker_response_mq.dequeue( + timeout=dequeue_timeout) + + if status != WorkerProc.ResponseStatus.SUCCESS: + if isinstance(result, Exception): + raise result + else: + raise RuntimeError("Worker failed") + + responses[w.rank] = result + + return responses + except TimeoutError as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e + except Exception as e: + # Re-raise any other exceptions + raise e + + def _ensure_worker_termination(self): + """Ensure that all worker processes are terminated. Assumes workers have + received termination requests. Waits for processing, then sends + termination and kill signals if needed.""" + + def wait_for_termination(procs, timeout): + if not time: + # If we are in late stage shutdown, the interpreter may replace + # `time` with `None`. 
+                return all(not proc.is_alive() for proc in procs)
+            start_time = time.time()
+            while time.time() - start_time < timeout:
+                if all(not proc.is_alive() for proc in procs):
+                    return True
+                time.sleep(0.1)
+            return False
+
+        # Send SIGTERM if still running
+        active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
+        for p in active_procs:
+            p.terminate()
+        if not wait_for_termination(active_procs, 4):
+            # Send SIGKILL if still running
+            active_procs = [p for p in active_procs if p.is_alive()]
+            for p in active_procs:
+                p.kill()
+
+        self._cleanup_sockets()
+
+    def _cleanup_sockets(self):
+        for w in self.workers:
+            # Remove the zmq ipc socket file
+            socket_path = w.ready_path.replace("ipc://", "")
+            if os and os.path.exists(socket_path):
+                os.remove(socket_path)
+
+    def shutdown(self):
+        """Properly shut down the executor and its workers"""
+        if not getattr(self, 'shutting_down', False):
+            self.shutting_down = True
+            for w in self.workers:
+                w.worker_response_mq = None
+            self._ensure_worker_termination()
+
+        self.rpc_broadcast_mq = None
+
+    def check_health(self) -> None:
+        self.collective_rpc("check_health", timeout=10)
+        return
+
+
+@dataclass
+class WorkerProcHandle:
+    proc: BaseProcess
+    rank: int
+    ready_path: str
+    worker_response_mq: MessageQueue  # The worker process writes to this MQ
+
+
+class WorkerProc:
+    """Wrapper that runs one Worker in a separate process."""
+
+    READY_STR = "READY"
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        input_shm_handle: Handle,
+        ready_path: str,
+    ):
+        self.rank = rank
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
+        # TODO: move `init_worker` to executor level as a collective rpc call
+        all_kwargs: List[Dict] = [
+            {} for _ in range(vllm_config.parallel_config.world_size)
+        ]
+        all_kwargs[rank] = {
+            "vllm_config": vllm_config,
+            "local_rank": local_rank,
+            "rank": rank,
+            "distributed_init_method": distributed_init_method,
+        }
+        wrapper.init_worker(all_kwargs)
+        self.worker = wrapper.worker
+
+        pid = os.getpid()
+        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
+        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
+
+        # Initialize MessageQueue for receiving SchedulerOutput
+        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+            input_shm_handle, self.worker.rank)
+
+        # Initializes a message queue for sending the model output
+        self.worker_response_mq = MessageQueue(1, 1)
+        worker_response_mq_handle = self.worker_response_mq.export_handle()
+
+        # Send Readiness signal to EngineCore process.
+        with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
+            payload = pickle.dumps(worker_response_mq_handle,
+                                   protocol=pickle.HIGHEST_PROTOCOL)
+            ready_socket.send_string(WorkerProc.READY_STR)
+            ready_socket.send(payload)
+
+        self.worker.init_device()
+        self.worker.load_model()
+
+    @staticmethod
+    def make_worker_process(
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        input_shm_handle,  # Receive SchedulerOutput
+    ) -> WorkerProcHandle:
+        context = get_mp_context()
+
+        # ZMQ path for worker to send ready message and shm_broadcast handle
+        # back to core process.
+        ready_path = get_open_zmq_ipc_path()
+
+        process_kwargs = {
+            "vllm_config": vllm_config,
+            "local_rank": local_rank,
+            "rank": rank,
+            "distributed_init_method": distributed_init_method,
+            "input_shm_handle": input_shm_handle,
+            "ready_path": ready_path,
+        }
+        # Run the WorkerProc busy loop in a background process.
+        proc = context.Process(target=WorkerProc.worker_main,
+                               kwargs=process_kwargs,
+                               daemon=True)
+        proc.start()
+
+        # Wait for startup
+        worker_response_mq_handle = WorkerProc.wait_for_startup(
+            proc, ready_path)
+
+        worker_response_mq = MessageQueue.create_from_handle(
+            worker_response_mq_handle, 0)
+
+        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
+
+    def shutdown(self):
+        self.rpc_broadcast_mq = None
+        self.worker_response_mq = None
+        destroy_model_parallel()
+        destroy_distributed_environment()
+
+    @staticmethod
+    def worker_main(*args, **kwargs):
+        """ Worker initialization and execution loops.
+        This runs in a background process """
+
+        # Signal handler used for graceful termination.
+        # SystemExit exception is only raised once to allow this and worker
+        # processes to terminate without error
+        shutdown_requested = False
+
+        def signal_handler(signum, frame):
+            nonlocal shutdown_requested
+            if not shutdown_requested:
+                shutdown_requested = True
+                raise SystemExit()
+
+        # Either SIGTERM or SIGINT will terminate the worker
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        worker = None
+        try:
+            worker = WorkerProc(*args, **kwargs)
+
+            # Ensure message queues are ready. Will deadlock if re-ordered.
+            # Must be kept consistent with the Executor
+            worker.rpc_broadcast_mq.wait_until_ready()
+            worker.worker_response_mq.wait_until_ready()
+
+            worker.worker_busy_loop()
+
+        except SystemExit:
+            logger.debug("Worker interrupted.")
+
+        except Exception:
+            # worker_busy_loop sends exceptions to the Executor for shutdown,
+            # but if there is an error in startup or an error with IPC
+            # itself, we need to alert the parent.
+            psutil.Process().parent().send_signal(signal.SIGUSR1)
+            raise
+
+        finally:
+            # Clean up once worker exits busy loop
+            if worker is not None:
+                worker.shutdown()
+                worker = None
+
+    @staticmethod
+    def wait_for_startup(
+        proc: BaseProcess,
+        ready_path: str,
+    ) -> Optional[Handle]:
+        """Wait until the Worker is ready."""
+        with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
+
+            # Wait for Worker to send READY.
+            while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
+                logger.debug("Waiting for WorkerProc to startup.")
+
+                if not proc.is_alive():
+                    raise RuntimeError("WorkerProc failed to start.")
+
+            message = socket.recv_string()
+            assert message == WorkerProc.READY_STR
+            handle_frame = socket.recv(copy=False)
+            handle = pickle.loads(handle_frame.buffer)
+            return handle
+
+    class ResponseStatus(Enum):
+        SUCCESS = auto()
+        FAILURE = auto()
+
+    def worker_busy_loop(self):
+        """Main busy loop for Multiprocessing Workers"""
+        while True:
+            method, args, kwargs = self.rpc_broadcast_mq.dequeue()
+
+            try:
+                if isinstance(method, str):
+                    func = getattr(self.worker, method)
+                elif isinstance(method, bytes):
+                    func = partial(cloudpickle.loads(method), self.worker)
+                output = func(*args, **kwargs)
+            except Exception as e:
+                self.worker_response_mq.enqueue(
+                    (WorkerProc.ResponseStatus.FAILURE, e))
+                logger.exception("WorkerProc hit an exception: %s", e)
+                continue
+
+            self.worker_response_mq.enqueue(
+                (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py b/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..eddfb5949ebe65c3dd5f8ae72a8aad06ee818703 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/kv_cache_interface.py @@ -0,0 +1,113 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv, get_dtype_size
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class KVCacheSpecBase:
+    """
+    A base class for specifying the KV cache format of one layer.
+    """
+
+    # number of tokens in a block
+    block_size: int
+
+    @property
+    def type_id(self) -> str:
+        """
+        The type identifier of this KV cache.
+        Return different strings for layers with different KV cache type (e.g.,
+        different number of tokens like full attention vs sliding window
+        attention, different KV cache size per token like layers with different
+        number of heads)
+
+        Returns:
+            The type identifier of this KV cache.
+        """
+        raise NotImplementedError
+
+    @property
+    def page_size_bytes(self) -> int:
+        """
+        The size of a page with `block_size` tokens in bytes.
+
+        Returns:
+            The page size
+        """
+        raise NotImplementedError
+
+    def bytes_for_tokens(self, num_tokens: int) -> int:
+        """
+        The KV cache size for `num_tokens` tokens in bytes. Returns the real
+        memory size after padding `num_tokens` to full blocks.
+
+        Returns:
+            The KV cache size
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class FullAttentionSpec(KVCacheSpecBase):
+    num_kv_heads: int
+    head_size: int
+    dtype: torch.dtype
+
+    @property
+    def type_id(self) -> str:
+        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
+
+    @property
+    def page_size_bytes(self) -> int:
+        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+            * get_dtype_size(self.dtype)
+
+    def bytes_for_tokens(self, num_tokens: int) -> int:
+        return cdiv(num_tokens, self.block_size) * self.page_size_bytes
+
+
+KVCacheSpec = Dict[str, KVCacheSpecBase]
+
+
+@dataclass
+class KVCacheTensor:
+    """
+    A dataclass for specifying how the workers should initialize the KV cache
+    for a layer. Only contains the size of KV cache for that layer for now. Will
+    be extended to support multiple layers sharing the same memory pool.
+ """ + size: int # The size of KV cache Tensor in bytes + + +@dataclass +class KVCacheConfig: + """ + The KV cache configuration of a model. + """ + """The number of KV cache blocks""" + num_blocks: int + """layer_name -> how to initialize KV cache for that layer""" + tensors: Dict[str, KVCacheTensor] + """ + A list of kv-cache groups. Each group includes a set of layers with + the same kv-cache spec, and the total page_size of layers inside a group + is same across all groups (as the KVCacheManager only supports allocating + pages of the same size). For example: + 1. A model only uses full attention: one group with all layers in the model. + 2. (not implemented yet) A model with the same number of full attention + layers and sliding window attention layers: two groups, one for full + attention layers and one for sliding window attention layers. + 3. (not implemented yet) A model with 2 full attention layers and 4 sliding + window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). + """ + groups: List[List[str]] + """the KVCacheSpec of the model""" + kv_cache_spec: KVCacheSpec diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py b/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..6e82bffd7e5c9dfff0a077ecb9c34a3cad4c9c53 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/outputs.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs: Optional[torch.Tensor] + + # TODO: Support prompt logprobs. + prompt_logprob_token_ids: Optional[torch.Tensor] + prompt_logprobs: Optional[torch.Tensor] + + +# ModelRunnerOutput is serialized and sent to the scheduler process. +# This is expensive for torch.Tensor so prefer to use List instead. 
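+# Only the optional logprob fields below remain tensors (kept on the CPU,
+# hence the *_cpu suffix); the sampled token ids cross the process boundary
+# as a plain Python list.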
+@dataclass +class ModelRunnerOutput: + + # [num_reqs] + req_ids: List[str] + # req_id -> index + req_id_to_index: Dict[str, int] + + # [num_reqs] + sampled_token_ids: List[int] + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids_cpu: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs_cpu: Optional[torch.Tensor] diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/request.py b/.venv/lib/python3.11/site-packages/vllm/v1/request.py new file mode 100644 index 0000000000000000000000000000000000000000..89b39ea615d20f6a1cfb3c3b91e2d211008a85ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/request.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from typing import TYPE_CHECKING, List, Optional, Union + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import RequestMetrics +from vllm.v1.engine import EngineCoreRequest, FinishReason +from vllm.v1.utils import ConstantList + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.kv_cache_utils import BlockHashType + + +class Request: + + def __init__( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: List[int], + multi_modal_inputs: Optional[List["MultiModalKwargs"]], + multi_modal_hashes: Optional[List[str]], + multi_modal_placeholders: Optional[List["PlaceholderRange"]], + sampling_params: SamplingParams, + eos_token_id: Optional[int], + arrival_time: float, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + # Because of LoRA, the eos token id can be different for each request. + self.eos_token_id = eos_token_id + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None) + self.lora_request = lora_request + + self.status = RequestStatus.WAITING + self.stop_reason: Union[int, str, None] = None + assert sampling_params.max_tokens is not None + self.max_tokens = sampling_params.max_tokens + + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.num_prompt_tokens = len(self.prompt_token_ids) + self._output_token_ids: List[int] = [] + self._all_token_ids: List[int] = self.prompt_token_ids.copy() + self.num_computed_tokens = 0 + + # Multi-modal related + self.mm_positions = multi_modal_placeholders or [] + self.mm_inputs = multi_modal_inputs or [] + self.mm_hashes: List[str] = multi_modal_hashes or [] + + # Sanity check + assert len(self.mm_inputs) == len(self.mm_positions) + if self.mm_hashes: + assert len(self.mm_inputs) == len(self.mm_hashes) + + # Cache the computed kv block hashes of the request to avoid + # recomputing. + self._kv_block_hashes: List[BlockHashType] = [] + self.kv_block_hashes = ConstantList(self._kv_block_hashes) + + # Read-only views + # Prevent directly appending to the these lists since + # they should also be updated simultaneously. 
+ self.output_token_ids = ConstantList(self._output_token_ids) + self.all_token_ids = ConstantList(self._all_token_ids) + + @classmethod + def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + return cls( + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, + multi_modal_placeholders=request.mm_placeholders, + sampling_params=request.sampling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + ) + + def append_output_token_ids( + self, + token_ids: Union[int, List[int]], + ) -> None: + if isinstance(token_ids, int): + token_ids = [token_ids] + self._output_token_ids.extend(token_ids) + self._all_token_ids.extend(token_ids) + + @property + def num_tokens(self) -> int: + return len(self._all_token_ids) + + @property + def num_output_tokens(self) -> int: + return len(self._output_token_ids) + + def is_finished(self) -> bool: + return RequestStatus.is_finished(self.status) + + def get_finished_reason(self) -> Union[FinishReason, None]: + return RequestStatus.get_finished_reason(self.status) + + def has_encoder_inputs(self) -> bool: + return len(self.mm_inputs) > 0 + + @property + def num_encoder_inputs(self) -> int: + return len(self.mm_positions) + + def get_num_encoder_tokens(self, input_id: int) -> int: + assert input_id < len(self.mm_positions) + num_tokens = self.mm_positions[input_id]["length"] + return num_tokens + + def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + self._kv_block_hashes = value + self.kv_block_hashes = ConstantList(self._kv_block_hashes) + + def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: + self._kv_block_hashes.append(block_hash) + + +class RequestStatus(enum.IntEnum): + """Status of a request.""" + WAITING = 0 + RUNNING = 1 + PREEMPTED = 2 + # Note: anything after PREEMPTED (2) will be considered + # as a finished status. + FINISHED_STOPPED = 3 + FINISHED_LENGTH_CAPPED = 4 + FINISHED_ABORTED = 5 + FINISHED_IGNORED = 6 + + @staticmethod + def is_finished(status: "RequestStatus") -> bool: + return status > RequestStatus.PREEMPTED + + @staticmethod + def get_finished_reason( + status: "RequestStatus") -> Union[FinishReason, None]: + return _FINISHED_REASON_MAP.get(status) + + +# Mapping of finished statuses to their finish reasons. +# NOTE: The ignored requests are the requests whose prompt lengths +# are longer than the model's length cap. Therefore, the stop +# reason should also be "length" as in OpenAI API. 
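+# Non-finished statuses (WAITING, RUNNING, PREEMPTED) are intentionally
+# absent, so RequestStatus.get_finished_reason() returns None for them via
+# dict.get().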
+_FINISHED_REASON_MAP = { + RequestStatus.FINISHED_STOPPED: FinishReason.STOP, + RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, + RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, + RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, +} diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e7683c151598a9f5317f12c93e204ecd21413b1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3658a6631dc04da622618de15c188352232940d4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/metadata.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60a763e4dff05ced5882b81c4f54e5407853afc0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/__pycache__/sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..8e54de34548ddfe8a631d78979b91c83dffc2e9f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/metadata.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Dict, List, Optional, Set + +import torch + + +@dataclass +class SamplingMetadata: + + temperature: torch.Tensor + all_greedy: bool + all_random: bool + + top_p: torch.Tensor + top_k: torch.Tensor + no_top_p: bool + no_top_k: bool + + generators: Dict[int, torch.Generator] + + max_num_logprobs: int + + no_penalties: bool + prompt_token_ids: Optional[torch.Tensor] + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + + output_token_ids: List[List[int]] + min_tokens: List[int] + stop_token_ids: List[Set[int]] diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99b6c0884ce0c82ad1f707dc07edccddef4aa63f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1aa42895152c40804c6dc507e1fa412ccb5f4c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/penalties.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c87dec7d51f69cb463730b59c4217506cbc7b08 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/__pycache__/topk_topp_sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py new file mode 100644 index 0000000000000000000000000000000000000000..ba368b44ab9cc02c8cb281049f84bfd66705081a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/penalties.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Set, Tuple + +import torch + +from vllm.model_executor.layers.utils import apply_penalties +from vllm.utils import is_pin_memory_available, make_tensor_with_pad + + +def apply_min_token_penalties(logits: torch.Tensor, + output_token_ids: List[List[int]], + stop_token_ids: List[Set[int]], + min_tokens: List[int]) -> None: + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. + """ + min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + for index, min_token in enumerate(min_tokens): + if len(output_token_ids[index]) < min_token: + for stop_token_id in stop_token_ids[index]: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def apply_all_penalties( + logits: torch.Tensor, + prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: List[List[int]], +) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) + + +def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: + """ + Convert the different list data structures to tensors. + """ + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. 
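+        # Padded positions therefore carry an id just past the real
+        # vocabulary, which lets the downstream penalty computation discard
+        # them without affecting the counts of real tokens.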
+ pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..27431001e3e7a2f0c78baa9d6c20900bb107fe3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/ops/topk_topp_sampler.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict + +import torch +import torch.nn as nn + +from vllm import envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +try: + import flashinfer.sampling + is_flashinfer_available = True +except ImportError: + is_flashinfer_available = False + + +class TopKTopPSampler(nn.Module): + + def __init__(self): + super().__init__() + if current_platform.is_cuda: + if is_flashinfer_available: + if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: + # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + # default it is unused). For backward compatibility, we set + # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and + # interpret it differently in V0 and V1 samplers: In V0, + # None means False, while in V1, None means True. This is + # why we use the condition + # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. + logger.info("Using FlashInfer for top-p & top-k sampling.") + self.forward = self.forward_cuda + else: + logger.warning( + "FlashInfer is available, but it is not enabled. " + "Falling back to the PyTorch-native implementation of " + "top-p & top-k sampling. For the best performance, " + "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + self.forward = self.forward_native + else: + logger.warning( + "FlashInfer is not available. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling. For the " + "best performance, please install FlashInfer.") + self.forward = self.forward_native + else: + self.forward = self.forward_native + + def forward_native( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation of top-k and top-p sampling.""" + logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + def forward_cuda( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + """More optimized implementation for top-k and top-p sampling.""" + probs = logits.softmax(dim=-1, dtype=torch.float32) + if no_top_k and no_top_p: + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + return random_sample(probs, generators) + return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + + +def apply_top_k_top_p( + logits: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. 
+ + This function sorts the logits tensor, which can be slow for large batches. + """ + if no_top_k and no_top_p: + return logits + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if not no_top_k: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if not no_top_p: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def random_sample( + probs: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. + """ + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + if len(generators) != probs.shape[0]: + q.exponential_() + if generators: + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def flashinfer_sample( + probs: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + """Sample from the probabilities using FlashInfer. + + Statistically, this function is equivalent to the `random_sample` function. + However, this function is faster because it avoids sorting the logits tensor + via rejection sampling. + + NOTE: The outputs of this function do not necessarily match the outputs of + the `random_sample` function. It only guarantees that the outputs are + statistically equivalent. + + NOTE: This function includes CPU-GPU synchronization, while `random_sample` + does not. Call this function at the end of the forward pass to minimize + the synchronization overhead. + """ + assert not (no_top_k and no_top_p) + max_top_k_round = 32 + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if len(generators) != batch_size: + uniform_samples.uniform_() + if generators: + for i, generator in generators.items(): + uniform_samples[:, i].uniform_(generator=generator) + + if no_top_k: + # Top-p only. + next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( + probs, uniform_samples, p, deterministic=True) + elif no_top_p: + # Top-k only. + next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( + probs, uniform_samples, k, deterministic=True) + else: + # Both top-k and top-p. + next_token_ids, success = ( + flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, uniform_samples, k, p, deterministic=True)) + + # NOTE: CPU-GPU synchronization happens here. 
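+    # If any row failed rejection sampling, fall back to renormalizing the
+    # probabilities under the top-k/top-p masks and sampling from them
+    # directly, as done below.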
+ if not success.all(): + if not no_top_k: + probs = flashinfer.sampling.top_k_renorm_prob(probs, k) + if not no_top_p: + probs = flashinfer.sampling.top_p_renorm_prob(probs, p) + next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0], deterministic=True) + return next_token_ids.view(-1) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py b/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3da7498e0dae5d671e81b9fcace11ed992a8478d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/sample/sampler.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A layer that samples the next tokens from the model's outputs.""" +from typing import Tuple + +import torch +import torch.nn as nn + +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.penalties import (apply_all_penalties, + apply_min_token_penalties) +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = TopKTopPSampler() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + needs_logprobs = sampling_metadata.max_num_logprobs > 0 + if needs_logprobs: + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # NOTE: We compute logprobs first because the below ops may + # modify the logits tensor in-place (and we don't want to clone + # the logits tensor for memory efficiency). + topk_logprobs, topk_indices = self.get_topk_logprobs( + logits, sampling_metadata) + else: + topk_logprobs = None + topk_indices = None + + # Use float32 for the logits. + logits = logits.to(torch.float32) + # Apply penalties (e.g., min_tokens, freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata) + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + # Sample the next token. + sampled = self.sample(logits, sampling_metadata) + # Use int32 to reduce the tensor size. + sampled = sampled.to(torch.int32) + + sampler_output = SamplerOutput( + sampled_token_ids=sampled, + logprob_token_ids=topk_indices, + logprobs=topk_logprobs, + prompt_logprob_token_ids=None, + prompt_logprobs=None, + ) + return sampler_output + + def apply_temperature( + self, + logits: torch.Tensor, + temp: torch.Tensor, + ) -> torch.Tensor: + # Avoid division by zero. + temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) + # Use in-place division to avoid creating a new tensor. 
+ logits.div_(temp.unsqueeze(dim=1)) + return logits + + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_greedy: + return self.greedy_sample(logits) + + random_sampled = self.topk_topp_sampler( + logits, + sampling_metadata.generators, + sampling_metadata.no_top_k, + sampling_metadata.top_k, + sampling_metadata.no_top_p, + sampling_metadata.top_p, + ) + if sampling_metadata.all_random: + return random_sampled + + greedy_sampled = self.greedy_sample(logits) + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + ) + return sampled + + def get_topk_logprobs( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) + # FIXME: Mask the sampled token_id, get topk logprobs, + # and concatenate the topk with the sampled token_id. + topk_logprobs, topk_indices = torch.topk( + logprobs, sampling_metadata.max_num_logprobs, dim=-1) + # Use int32 to reduce the tensor size. + topk_indices = topk_indices.to(torch.int32) + return topk_logprobs, topk_indices + + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + logits = apply_all_penalties( + logits, sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) + return logits diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1791dfa2b6325f2f41c34cd68fa86152aa9a7c06 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/serial_utils.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pickle + + +class PickleEncoder: + + def encode(self, obj): + return pickle.dumps(obj) + + def decode(self, data): + return pickle.loads(data) diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a652930fd79c2041e96cfd05393790bea4d0992 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69377f715885e9b6b39d56cf04dd5286c7b8509b Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/vllm/v1/stats/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py b/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py new file mode 100644 index 0000000000000000000000000000000000000000..09d382638bffd881c9dbe3ef5ec5a55c6fb17d7d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/stats/common.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from dataclasses import dataclass +from dataclasses import field as dataclass_field +from enum import IntEnum +from typing import ClassVar, Dict, List, Optional, Set + +import msgspec +from msgspec import field as msgspec_field + +from vllm.sampling_params import SamplingParams + + +class RequestStatsUpdate( + msgspec.Struct, # type: ignore + array_like=True, + omit_defaults=True, + gc=False): + """ + An update to the request stats. + + This represents a stats update at a specific timestamp with metadata + associated with the update. + + NOTE: since there might be multiple processes generating updates at + different parts of the engine (e.g. input processor, scheduler, engine core, + etc.), we use the monotonic timestamp to record the update to compute any + intervals, and explicit wall-clock timestamp should be used for timestamps. + + WARNING: This assumes stats are generated in a single machine. If there are + potentially multiple machines, one should always generate the stats updates + on one single machine or use something else. + """ + + class Type(IntEnum): + """See `RequestStats` for the lifecycle of a request.""" + + # Request arrived at the engine frontend. + ARRIVED = 0 + # Input processed by the input processor. + INPUT_PROCESSED = 1 + # Queued on the engine core. + QUEUED = 2 + # Scheduled running prefill by the scheduler. + # A request could be running a new prefill on the prompt tokens or + # a resumed prefill on the original prefill tokens + generated output + # tokens before preemption. + PREFILLING = 3 + # Preempted by the scheduler. + PREEMPTED = 4 + # Output token is generated by the engine core. + DECODING = 5 + # Token detokenized by the detokenizer. + # We will record the timestamp for each output token, as well as the + # finish reason. + DETOKENIZED = 6 + # Request finishes (or aborts). + FINISHED = 7 + + """ + Valid state updates: + ARRIVED + │ + ├──────► INPUT_PROCESSED ──────► QUEUED ──────► PREFILLING ◄────┐ + │ │ │ │ │ + │ │ │ ▼ │ + │ │ │ -──► DECODING │ + │ │ │ | │ │ + │ │ │ | ▼ │ + │ │ │ └─ DETOKENIZED │ + │ │ │ │ │ + │ │ │ ▼ │ + │ ▼ ▼ PREEMPTED ◄──────┘ + │ │ │ │ + └──────────────┴───────────────────┴──────────────┴ + │ + ▼ + FINISHED (All could go to FINISHED) + """ + _VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = { + Type.ARRIVED: { + Type.INPUT_PROCESSED, + Type.FINISHED, + }, + Type.INPUT_PROCESSED: { + Type.QUEUED, + Type.FINISHED, + }, + Type.QUEUED: { + Type.PREFILLING, + Type.FINISHED, + }, + Type.PREFILLING: { + Type.DECODING, + Type.PREEMPTED, + Type.FINISHED, + }, + Type.DECODING: { + Type.DETOKENIZED, + Type.FINISHED, + }, + Type.DETOKENIZED: { + Type.DECODING, + Type.PREEMPTED, + Type.FINISHED, + }, + Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED}, + Type.FINISHED: set(), + } + + request_id: str + + type: Type + + # Timestamp when the update is recorded. This is used to record time + # intervals between events rather than wall clock time. 
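+    # Intervals are then plain differences between update timestamps, e.g.
+    # a request's queueing delay is the PREFILLING update's monotonic_ts_s
+    # minus the QUEUED update's (see RequestStats.queue_duration_s).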
+ monotonic_ts_s: float = msgspec_field( + default_factory=lambda: time.monotonic()) + + ############################################################ + # Metadata associated with the update. + ############################################################ + # For input_processed. Metadata needed for stats logging. + num_prompt_tokens: Optional[int] = None + sampling_params: Optional[SamplingParams] = None + + # For running. + # Number of tokens computed when scheduled to run. + num_computed_tokens: Optional[int] = None + # Number of cached tokens when scheduled to run. + num_cached_tokens: Optional[int] = None + + # For decoded. + # The number of new output tokens generated. + num_new_tokens: Optional[int] = None + + # For both detokenized and decoded. + # Finished reason. + finish_reason: Optional[str] = None + + # Non-optional fields for each update type. + _REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = { + Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"], + Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"], + Type.DETOKENIZED: ["num_new_tokens"], + Type.FINISHED: ["finish_reason"], + } + + def __post_init__(self): + required_fields = self._REQUIRED_FIELDS.get(self.type, []) + for field in required_fields: + if getattr(self, field) is None: + raise ValueError( + f"Field {field} is required for update type {self.type}.") + + @staticmethod + def check_valid_update( + update: "RequestStatsUpdate", + last_update_type: Optional[Type], + last_updated_ts_s: Optional[float], + ): + if last_update_type is None: + assert update.type == RequestStatsUpdate.Type.ARRIVED + else: + valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[ + last_update_type] + assert update.type in valid_cur_update_types, ( + f"Invalid update type: {update.type} for last_update_type: " + f"{last_update_type}.") + + if last_updated_ts_s is not None: + assert update.monotonic_ts_s >= last_updated_ts_s, ( + "Update timestamp must be monotonically increasing, but " + f"last_updated_ts_s={last_updated_ts_s} and " + f"update.monotonic_ts_s={update.monotonic_ts_s}.") + + +@dataclass +class RequestStats: + """Stats associated with a request (`Request`).""" + + ############################################################ + # Metadata + ############################################################ + request_id: str + sampling_params: Optional[SamplingParams] = None + num_prompt_tokens: Optional[int] = None + + ############################################################ + # Metrics and Stats + ############################################################ + # Timestamp when the request was last updated. + last_updated_ts_s: Optional[float] = None + + # Last update stats type. + last_update_type: Optional[RequestStatsUpdate.Type] = None + + # Timestamp when the request arrived at the llm engine. + arrival_ts_s: Optional[float] = None + + # Number of tokens cached. When part of the request prefix is cached, + # this will be set. + num_cached_tokens: int = 0 + + # Number of tokens computed. + num_computed_tokens: int = 0 + + # The timestamp when the request become waiting in the queue. + queued_ts_s: Optional[float] = None + + # When the input processor is completed. + input_processor_end_ts_s: Optional[float] = None + + # A sorted list of timestamps when the request was scheduled to prefill. + # This could be when: + # 1. the request is newly scheduled, so it's a new prefill. + # 2. the request was preempted and resumed. 
It is equivalent to running + # a prefill of the original prefill tokens + generated output tokens + # before preemption. + prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list) + + # A list of timestamps when a token is decoded by the engine core. + decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list) + + # A sorted list of timestamps for each output token. + output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list) + + # First token's timestamp. + first_token_ts_s: Optional[float] = None + + # TODO(rickyx): we need model runner to surface these. + model_forward_duration_s: float = 0.0 + # Includes model forward, block/sync across workers, cpu-gpu sync time + # and sampling time. + model_execute_duration_s: float = 0.0 + + # A sorted list of timestamps when the request was preempted at the + # scheduler. + # TODO(rickyx): right now, we don't actually have a good high-level + # metric to measure the impact of preemption other than observation of + # large P99 TPOT. Ideally we could quantify the impact of preemption by + # measuring the number of tokens re-computed due to preemption. + preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list) + + # Timestamp when the request was finished at the engine core. + finished_ts_s: Optional[float] = None + + # Finish reason. + finish_reason: Optional[str] = None + + ############################################################ + # Derived properties. + ############################################################ + @property + def prefill_ts_s(self) -> Optional[float]: + """The timestamp when the request started prefilling. + Since a request could be preempted in decoding and later resumed + to prefill the decoded tokens, we use the first prefill start timestamp. + """ + return (self.prefill_start_ts_s_lst[0] + if self.prefill_start_ts_s_lst else None) + + @property + def e2e_latency_s(self) -> Optional[float]: + if self.finished_ts_s is None or self.arrival_ts_s is None: + return None + assert self.finished_ts_s >= self.arrival_ts_s + return self.finished_ts_s - self.arrival_ts_s + + @property + def queue_duration_s(self) -> Optional[float]: + """How long the request was waiting to run.""" + if self.queued_ts_s is None or self.prefill_ts_s is None: + # Either not queued or not running yet. 
+ return None + assert self.queued_ts_s <= self.prefill_ts_s + return self.prefill_ts_s - self.queued_ts_s + + @property + def inference_latency_s(self) -> Optional[float]: + """How long the request was running inference + (prefill and decode).""" + if self.finished_ts_s is None or self.prefill_ts_s is None: + return None + assert self.finished_ts_s >= self.prefill_ts_s + return self.finished_ts_s - self.prefill_ts_s + + @property + def first_token_latency_s(self) -> Optional[float]: + if self.first_token_ts_s is None or self.arrival_ts_s is None: + return None + assert self.first_token_ts_s >= self.arrival_ts_s + return self.first_token_ts_s - self.arrival_ts_s + + @property + def prefill_latency_s(self) -> Optional[float]: + if self.first_token_ts_s is None or self.prefill_ts_s is None: + return None + assert self.first_token_ts_s >= self.prefill_ts_s + return self.first_token_ts_s - self.prefill_ts_s + + @property + def decode_latency_s(self) -> Optional[float]: + if self.e2e_latency_s is None or self.first_token_latency_s is None: + return None + assert self.e2e_latency_s >= self.first_token_latency_s + return self.e2e_latency_s - self.first_token_latency_s + + @property + def output_token_latency_s_lst(self) -> List[float]: + if len(self.output_token_ts_s_lst) == 0: + return [] + latency_s_lst = [] + for i in range(1, len(self.output_token_ts_s_lst)): + assert (self.output_token_ts_s_lst[i] + >= self.output_token_ts_s_lst[i - 1]) + latency_s = (self.output_token_ts_s_lst[i] - + self.output_token_ts_s_lst[i - 1]) + latency_s_lst.append(latency_s) + return latency_s_lst + + @property + def num_output_tokens(self) -> int: + return len(self.output_token_ts_s_lst) + + @property + def is_finished(self) -> bool: + return self.finished_ts_s is not None + + def update_from(self, update: "RequestStatsUpdate"): + RequestStatsUpdate.check_valid_update(update, self.last_update_type, + self.last_updated_ts_s) + ts = update.monotonic_ts_s + self.last_updated_ts_s = ts + self.last_update_type = update.type + if update.type == RequestStatsUpdate.Type.ARRIVED: + self.arrival_ts_s = ts + elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED: + self.input_processor_end_ts_s = ts + self.sampling_params = update.sampling_params + self.num_prompt_tokens = update.num_prompt_tokens + elif update.type == RequestStatsUpdate.Type.QUEUED: + self.queued_ts_s = ts + elif update.type == RequestStatsUpdate.Type.PREFILLING: + self.prefill_start_ts_s_lst.append(ts) + self.num_cached_tokens = update.num_cached_tokens or 0 + self.num_computed_tokens = update.num_computed_tokens or 0 + elif update.type == RequestStatsUpdate.Type.PREEMPTED: + self._reset_for_preemption(ts) + elif update.type == RequestStatsUpdate.Type.DECODING: + self.decoding_ts_s_lst.append(ts) + elif update.type == RequestStatsUpdate.Type.DETOKENIZED: + self._record_detokenized_output( + ts, + update.num_new_tokens or 0, + ) + elif update.type == RequestStatsUpdate.Type.FINISHED: + self.finished_ts_s = ts + self.finish_reason = update.finish_reason + else: + raise ValueError(f"Unknown update type: {update.type}") + + def _record_detokenized_output( + self, + ts_s: float, + num_new_tokens: int, + ): + # Update if first output token is generated. + if len(self.output_token_ts_s_lst) == 0: + self.first_token_ts_s = ts_s + assert ( + self.prefill_ts_s is not None + ), "Request must be running before generating output tokens." + + # Some X new tokens were generated at the ts. 
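+        # The same timestamp is appended once per new token, so
+        # len(output_token_ts_s_lst) always equals num_output_tokens.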
+ self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens) + + def _reset_for_preemption(self, ts_s: float): + self.preempted_ts_s_lst.append(ts_s) + # Reset the computed tokens since it might restart the prefill. + self.num_computed_tokens = 0 + # Cached token count might also change when resumed. + self.num_cached_tokens = 0 + # These stats don't change since they happen before request running. + # - arrival_ts_s + # - input_processor_end_ts_s + # - sampling_params + # - num_prompt_tokens + # - first_token_ts_s + # + # These stats are accumulated over preemptions: + # - output_token_ts_s_lst + # - prefill_start_ts_s_lst (after preemption, it will prefill the + # original prefill tokens and any output tokens generated before + # preemption.) + + +@dataclass +class KVCacheStats: + # KV Cache Usage in % + gpu_cache_usage_sys: float = 0.0 + gpu_prefix_cache_hit_rate: float = 0.0 + + +@dataclass +class SchedulerStats: + """Stats associated with the scheduler.""" + + # Number of requests currently running. + num_running_reqs: int = 0 + # Number of requests currently waiting. + num_waiting_reqs: int = 0 + + kv_cache_stats: KVCacheStats = dataclass_field( + default_factory=KVCacheStats) + + +@dataclass +class EngineCoreProcessStats: + """Stats associated with the engine core process.""" + + # Number of requests currently in the input queue. None if the engine core + # is not running in multiprocess mode. + input_queue_size: Optional[int] = None + # Number of outputs currently in the output queue. None if the engine core + # is not running in multiprocess mode. + output_queue_size: Optional[int] = None + + +class EngineCoreStatsSnapshot( + msgspec.Struct, # type: ignore + array_like=True, + omit_defaults=True, + gc=False): + """ + A snapshot of the EngineCore's current stats over a period of time. + """ + + # Snapshot of the scheduler stats. + scheduler_stats: SchedulerStats = msgspec_field( + default_factory=SchedulerStats) + + # Per request stats updates. + requests_stats_updates: List[RequestStatsUpdate] = msgspec_field( + default_factory=list) + + # Engine core's queue stats. + engine_core_process_stats: EngineCoreProcessStats = msgspec_field( + default_factory=EngineCoreProcessStats) + + # TODO(rickyx): Add other components' stats, + # e.g. model runner/worker and etc. 
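For orientation, the following is a minimal sketch of how the request-stats machinery above is driven. It is editorial, not part of the diff: the `vllm.v1.stats.common` import path and the exact update ordering are assumptions (the `Type` enum and `_VALID_TRANSITIONS` table are defined in the earlier, unshown part of this file), and in the engine the timestamps come from real scheduling events rather than back-to-back construction.

from vllm.sampling_params import SamplingParams
# Assumed module path for the stats file shown above.
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate

Type = RequestStatsUpdate.Type
stats = RequestStats(request_id="req-0")

# Each update stamps time.monotonic() at construction, so building the list in
# order keeps the monotonicity assertion in check_valid_update() satisfied.
# The per-type keyword arguments mirror _REQUIRED_FIELDS; omitting one raises
# ValueError in __post_init__. The ordering below is assumed to be a legal
# transition sequence.
updates = [
    RequestStatsUpdate(type=Type.ARRIVED),
    RequestStatsUpdate(type=Type.INPUT_PROCESSED,
                       num_prompt_tokens=16,
                       sampling_params=SamplingParams()),
    RequestStatsUpdate(type=Type.QUEUED),
    RequestStatsUpdate(type=Type.PREFILLING,
                       num_computed_tokens=16,
                       num_cached_tokens=0),
    RequestStatsUpdate(type=Type.DECODING),
    RequestStatsUpdate(type=Type.DETOKENIZED, num_new_tokens=1),
    RequestStatsUpdate(type=Type.FINISHED, finish_reason="stop"),
]
for update in updates:
    stats.update_from(update)

# Derived latencies become available once their timestamps are recorded.
print(stats.num_output_tokens)   # 1
print(stats.queue_duration_s)    # prefill_ts_s - queued_ts_s
print(stats.e2e_latency_s)       # finished_ts_s - arrival_ts_s

These are the same update objects that `EngineCoreStatsSnapshot.requests_stats_updates` batches up when the engine core reports its stats.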
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/utils.py b/.venv/lib/python3.11/site-packages/vllm/v1/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5494542c181d7843db9cbdf9051a1ad55229ae9f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/v1/utils.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing +import os +import weakref +from collections import defaultdict +from collections.abc import Sequence +from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List, + Optional, TypeVar, Union, overload) + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index +from vllm.utils import get_mp_context, kill_process_tree + +if TYPE_CHECKING: + from vllm.attention.layer import Attention + +logger = init_logger(__name__) + +T = TypeVar("T") + + +class ConstantList(Generic[T], Sequence): + + def __init__(self, x: List[T]) -> None: + self._x = x + + def append(self, item): + raise Exception("Cannot append to a constant list") + + def extend(self, item): + raise Exception("Cannot extend a constant list") + + def insert(self, item): + raise Exception("Cannot insert into a constant list") + + def pop(self, item): + raise Exception("Cannot pop from a constant list") + + def remove(self, item): + raise Exception("Cannot remove from a constant list") + + def clear(self): + raise Exception("Cannot clear a constant list") + + def index(self, + item: T, + start: int = 0, + stop: Optional[int] = None) -> int: + return self._x.index(item, start, + stop if stop is not None else len(self._x)) + + @overload + def __getitem__(self, item: int) -> T: + ... + + @overload + def __getitem__(self, s: slice, /) -> List[T]: + ... + + def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]: + return self._x[item] + + @overload + def __setitem__(self, item: int, value: T): + ... + + @overload + def __setitem__(self, s: slice, value: T, /): + ... + + def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]): + raise Exception("Cannot set item in a constant list") + + def __delitem__(self, item): + raise Exception("Cannot delete item from a constant list") + + def __iter__(self): + return iter(self._x) + + def __contains__(self, item): + return item in self._x + + def __len__(self): + return len(self._x) + + +class BackgroundProcHandle: + """ + Utility class to handle creation, readiness, and shutdown + of background processes used by the AsyncLLM and LLMEngine. + """ + + def __init__( + self, + input_path: str, + output_path: str, + process_name: str, + target_fn: Callable, + process_kwargs: Dict[Any, Any], + ): + context = get_mp_context() + reader, writer = context.Pipe(duplex=False) + + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs + and "output_path" not in process_kwargs) + process_kwargs["ready_pipe"] = writer + process_kwargs["input_path"] = input_path + process_kwargs["output_path"] = output_path + + # Run busy loop in background process. + self.proc = context.Process(target=target_fn, kwargs=process_kwargs) + self._finalizer = weakref.finalize(self, shutdown, self.proc, + input_path, output_path) + self.proc.start() + + # Wait for startup. + if reader.recv()["status"] != "READY": + raise RuntimeError(f"{process_name} initialization failed. 
" + "See root cause above.") + + def shutdown(self): + self._finalizer() + + +# Note(rob): shutdown function cannot be a bound method, +# else the gc cannot collect the object. +def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): + # Shutdown the process. + if proc.is_alive(): + proc.terminate() + proc.join(5) + + if proc.is_alive(): + kill_process_tree(proc.pid) + + # Remove zmq ipc socket files. + ipc_sockets = [output_path, input_path] + for ipc_socket in ipc_sockets: + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file) + + +def bind_kv_cache( + kv_caches: Dict[str, torch.Tensor], + forward_context: Dict[str, "Attention"], + runner_kv_caches: List[torch.Tensor], +) -> None: + """ + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. + + Args: + kv_caches: The allocated kv_caches with layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. + """ + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + raise NotImplementedError + layer_name = layer_names[0] + runner_kv_caches.append(kv_caches[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache in kv_caches.items(): + # NOTE: Use list because of v0 PP virtual engine. 
+        forward_context[layer_name].kv_cache = [kv_cache]
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/__init__.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f520ee9586c5c909a303d9470c8f16627480e68b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/worker/block_table.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import numpy as np
+import torch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class BlockTable:
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_blocks_per_req: int,
+        pin_memory: bool,
+        device: torch.device,
+    ):
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.block_table = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device=self.device,
+            dtype=torch.int32,
+        )
+        self.block_table_cpu = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=pin_memory,
+        )
+        self.block_table_np = self.block_table_cpu.numpy()
+        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
+
+    def append_row(
+        self,
+        row_idx: int,
+        start: int,
+        block_ids: List[int],
+    ) -> None:
+        if not block_ids:
+            return
+        num_blocks = len(block_ids)
+        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+        self.num_blocks_per_row[row_idx] = start + num_blocks
+
+    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
+        self.append_row(row_idx, 0, block_ids)
+
+    def move_row(self, src: int, tgt: int) -> None:
+        num_blocks = self.num_blocks_per_row[src]
+        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
+            src, :num_blocks]
+        self.num_blocks_per_row[tgt] = num_blocks
+
+    def commit(self, num_reqs: int) -> None:
+        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
+                                          non_blocking=True)
+
+    def clear(self) -> None:
+        self.block_table.fill_(0)
+        self.block_table_cpu.fill_(0)
+
+    def get_device_tensor(self) -> torch.Tensor:
+        """Returns the device tensor of the block table."""
+        return self.block_table
+
+    def get_cpu_tensor(self) -> torch.Tensor:
+        """Returns the CPU tensor of the block table."""
+        return self.block_table_cpu
+
+    def get_numpy_array(self) -> np.ndarray:
+        """Returns the numpy array of the block table."""
+        return self.block_table_np
diff --git a/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py b/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..39708f833fd58340a160eef20150c977ce17506f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_input_batch.py
@@ -0,0 +1,440 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Data structures defining an input batch
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
+import numpy as np
+import torch
+
+from vllm.multimodal import MultiModalKwargs
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.v1.sample.metadata import SamplingMetadata
+from
vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + self.vocab_size = vocab_size + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + # TODO(woosuk): This buffer could be too large if max_model_len is big. + # Find a way to reduce the CPU memory usage. + # This buffer is not directly transferred to the GPU, so it does not + # need to be pinned. + self.token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Block table. + self.block_table = BlockTable( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_blocks_per_req=max_num_blocks_per_req, + pin_memory=pin_memory, + device=device, + ) + + # Sampling-related. 
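+        # Each sampling parameter below is kept as a pair of buffers: a CPU
+        # staging tensor (pinned when pin_memory=True) that is cheap to write
+        # per request, and a device tensor that is refreshed from it with a
+        # non-blocking copy in make_sampling_metadata().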
+ self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: Set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = \ + self.presence_penalties_cpu_tensor.numpy() + self.presence_penalties_reqs: Set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: Set[str] = set() + + self.min_tokens: List[int] = [0] * max_num_reqs + self.stop_token_ids: List[Set[int]] = [ + set() for _ in range(max_num_reqs) + ] + self.prompt_token_ids: Optional[torch.Tensor] = None + + # req_index -> generator + # NOTE(woosuk): The indices of the requests that do not have their own + # generator should not be included in the dictionary. + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. 
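+        # Row layout of token_ids_cpu is [prompt tokens | output tokens |
+        # unused], so each request's full sequence is contiguous in its row.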
+ num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_tokens[req_index] = num_prompt_tokens + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + self.num_tokens[req_index] = request.num_tokens + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + self.block_table.add_row(req_index, request.block_ids) + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + self.frequency_penalties_cpu[req_index] = \ + sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[req_index] = \ + sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[req_index] = \ + sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + self.min_tokens[req_index] = sampling_params.min_tokens + self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.frequency_penalties_reqs.clear() + self.presence_penalties_reqs.clear() + self.repetition_penalties_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. 
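+            # (empty_req_indices is sorted in descending order, so pop()
+            #  returns the smallest hole; once it reaches or passes
+            #  last_req_index the batch is already dense and the loop stops.)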
+ empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + assert req_id is not None + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] + self.num_tokens[empty_index] = num_tokens + self.num_prompt_tokens[empty_index] = \ + self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table.move_row(last_req_index, empty_index) + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = \ + self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[empty_index] = \ + self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[empty_index] = \ + self.repetition_penalties_cpu[last_req_index] + self.min_tokens[empty_index] = self.min_tokens[last_req_index] + self.stop_token_ids[empty_index] = \ + self.stop_token_ids[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + req_id_output_token_ids: Dict[str, List[int]], + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + self.frequency_penalties[:self.num_reqs].copy_( + self.frequency_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.presence_penalties[:self.num_reqs].copy_( + self.presence_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.repetition_penalties[:self.num_reqs].copy_( + self.repetition_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + self.prompt_token_ids = self._make_prompt_token_ids_tensor() + + output_token_ids: List[List[int]] = [] + + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + # Currently we create a tensor for output_token_ids from scratch + # at each step. However, for the penalties computation what we + # need is stats about the token ids present in the output. This + # stats can be maintained incrementally instead of computing it + # from scratch at each step. + # TODO - Replace this with incremental update to output token + # statistics. 
+ output_token_ids.append(req_id_output_token_ids[req_id]) + + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=self.prompt_token_ids, + frequency_penalties=self.frequency_penalties[:self.num_reqs], + presence_penalties=self.presence_penalties[:self.num_reqs], + repetition_penalties=self.repetition_penalties[:self.num_reqs], + output_token_ids=output_token_ids, + min_tokens=self.min_tokens[:self.num_reqs], + stop_token_ids=self.stop_token_ids[:self.num_reqs], + no_penalties=self.no_penalties, + ) + + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = ( + self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + for i in range(self.num_reqs): + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def no_penalties(self) -> bool: + return (len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0) + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0
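To make the batching flow above concrete, here is a small CPU-only sketch of the typical call sequence. It is editorial, not part of the diff; the sizes and the dummy request are illustrative values, and using a CPU device with pin_memory=False keeps the sketch runnable without a GPU (in the runner the device tensors live on the GPU and the copies are asynchronous).

import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

batch = InputBatch(max_num_reqs=8,
                   max_model_len=64,
                   max_num_blocks_per_req=4,
                   device=torch.device("cpu"),
                   pin_memory=False,
                   vocab_size=32_000)

# A dummy request with no multimodal inputs and mostly default sampling
# settings.
req = CachedRequestState(req_id="req-0",
                         prompt_token_ids=[1, 2, 3, 4],
                         prompt=None,
                         mm_inputs=[],
                         mm_positions=[],
                         sampling_params=SamplingParams(temperature=0.7),
                         generator=None,
                         block_ids=[0],
                         num_computed_tokens=0,
                         output_token_ids=[])
batch.add_request(req)

# Publish the staged CPU block-table rows to the device-side table.
batch.block_table.commit(num_reqs=batch.num_reqs)

# Build per-step sampling metadata; output token ids are passed in per request
# because they are needed for penalty computation.
metadata = batch.make_sampling_metadata(
    req_id_output_token_ids={"req-0": req.output_token_ids})
print(batch.num_reqs, metadata.all_greedy, metadata.temperature.shape)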