diff --git a/.gitattributes b/.gitattributes
index 9c58da6a148b620692548959adc7c0b2e901d6c0..923f17182cdf6ec245890aee0aef60c32958bf7b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -108,3 +108,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
 .venv/lib/python3.11/site-packages/pillow.libs/liblcms2-e69eef39.so.2.0.16 filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/pillow.libs/libopenjp2-05423b53.so filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/pillow.libs/libxcb-b8a56d01.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b979b731e9d6f71187d5fb3226db1ccec4e92837
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd468bdfc07c7557d22c23e06c689556a0973fbc
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e42d2ce3a2f7be780dd93e62f5558bc01ca1c32f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e5325942610da35ddce14175d46e9458a81ed2e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e99a634db908ad3776ff4a982c94ca7608d8565
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f909164fcb55f74cf8190b9b66d2993468e18e5e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc
new file mode 100644
index
0000000000000000000000000000000000000000..a2ea1f73e319ba09a55ae6bffba50235bb5cd3ed Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c3afa64b96cc7216ecd9c8eed5f41336006a39 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .abstract_tool_parser import ToolParser, ToolParserManager +from .granite_20b_fc_tool_parser import Granite20bFCToolParser +from .granite_tool_parser import GraniteToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser +from .jamba_tool_parser import JambaToolParser +from .llama_tool_parser import Llama3JsonToolParser +from .mistral_tool_parser import MistralToolParser +from .pythonic_tool_parser import PythonicToolParser + +__all__ = [ + "ToolParser", "ToolParserManager", "Granite20bFCToolParser", + "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", + "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", + "PythonicToolParser" +] diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f0b9791fc2f97a1ce6f2dfeb54c111c4a4f0a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfcac8774f1b263654f7660517d2c0e82d148887 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dcc6cd79a1a43bc09dd441c7264f2b4f5a0f8e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f511323e4cbbcfd918de9d83814b38b4ebbcba7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8851377443fd3563dd59887fda4a0903be27a086 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e3f715d8ebdad0e539306ad0b0e0cfd01c74ae5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d5860ecdebd7fa89a594f1ffd6e3442dcfc8cbe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c5e2579d3e6e91bf2e97f12e2f96a719dc51489 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c38ad3f201165870a6adc6f0fe9d59aa33eeefbf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c66c2448736d2f2c2fbfc33bd605169678b9f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc6a7b9b9d1ff968428046e950112d2f9635aca7 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdd6d4c4f2ba69d1616caa4b1e4cbd174a08f39 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from functools import cached_property +from typing import Callable, Dict, List, Optional, Sequence, Type, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> Dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") + + +class ToolParserManager: + tool_parsers: Dict[str, Type] = {} + + @classmethod + def get_tool_parser(cls, name) -> Type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. 
+ """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined tool parser by the path of the tool parser define + file. 
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + import_from_path(module_name, plugin_path) + except Exception: + logger.exception("Failed to load module '%s' from %s.", + module_name, plugin_path) + return diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4d63e18865376339d023ca5d91abb0d648b8ce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("jamba") +class JambaToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + raise ValueError( + "Detected a MistralTokenizer tokenizer when using a Jamba model" + ) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "" + self.tool_calls_end_token: str = "" + + self.tool_calls_regex = re.compile( + rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", + re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "Jamba Tool parser could not locate tool calls start/end " + "tokens in the tokenizer!") + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because jamba use the special + # tokens to indicate the start and end of the tool calls + # information. 
+ request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # use a regex to find the tool call between the tags + function_calls = self.tool_calls_regex.findall(model_output)[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = json.loads(function_calls) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if + (len(content) > 0 and content != " ") else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.tool_calls_start_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the start of tool calls token which means + # the start of tool calling + if (self.tool_calls_start_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion and don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # Extract the tool calls between the special tool call tokens + parsable_arr = current_text.split( + self.tool_calls_start_token)[-1].split( + self.tool_calls_end_token)[0] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7b113623e6515d3a31307f1ec42c3451a5fec3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from json import JSONDecoder +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("llama3_json") +class Llama3JsonToolParser(ToolParser): + """ + Tool call parser for Llama 3.1 models intended for use with the + examples/tool_chat_template_llama.jinja template. 
+ + Used when --enable-auto-tool-choice --tool-call-parser llama3_json + are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "<|python_tag|>" + self.bot_token_id = tokenizer.encode(self.bot_token, + add_special_tokens=False)[0] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + # case -- if a tool call token is not present, return a text response + if not (model_output.startswith(self.bot_token) + or model_output.startswith('{')): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # load the JSON, and then use it to build the Function and + # Tool Call + dec = JSONDecoder() + function_call_arr = [] + + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if model_output.startswith( + self.bot_token) else 0 + while start_idx < len(model_output): + (obj, end_idx) = dec.raw_decode(model_output[start_idx:]) + start_idx += end_idx + len('; ') + function_call_arr.append(obj) + + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"] \ + if "arguments" in raw_function_call \ + else raw_function_call["parameters"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not (current_text.startswith(self.bot_token) + or current_text.startswith('{')): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
+ flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if current_text.startswith( + self.bot_token) else 0 + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + len('; ') + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert "arguments" not in obj, \ + "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..51354f7c9562355c9791cee5bc695097646f9984 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from random import choices +from string import ascii_letters, digits +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import Field + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: 
MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + +@ToolParserManager.register_module("mistral") +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not isinstance(self.model_tokenizer, MistralTokenizer): + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.vocab.get(self.bot_token) + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + if self.bot_token_id is None: + raise RuntimeError( + "Mistral Tool Parser could not locate the tool call token in " + "the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # first remove the BOT token + tool_content = model_output.replace(self.bot_token, "").strip() + + try: + + # we first try to directly load the json as parsing very nested + # jsons is difficult + try: + function_call_arr = json.loads(tool_content) + except json.JSONDecodeError: + # use a regex to find the part corresponding to the tool call. + # NOTE: This use case should not happen if the model is trained + # correctly. 
It's a easy possible fix so it's included, but + # can be brittle for very complex / highly nested tool calls + raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + function_call_arr = json.loads(raw_tool_call) + + # Tool Call + tool_calls: List[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=tool_content) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[-1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. 
+ if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff, ensure_ascii=False).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + if ('"}' in new_text): + new_text = new_text[:new_text.rindex('"}')] + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False)[:-2] + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + if (new_text not in cur_arguments_json): + return None + arguments_delta = cur_arguments_json[:cur_arguments_json. + rindex(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. 
if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..945cbd6835028b3d73297b20cdd32b69f9584112 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from json import JSONDecodeError, JSONDecoder +from typing import Any, List, Tuple + +import partial_json_parser +from partial_json_parser.core.options import Allow + + +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string: str, substring: str) -> List[int]: + """ + Find all (starting) indices of a substring in a given string. 
Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices + + +# partial_json_parser doesn't support extra data and +# JSONDecorder.raw_decode doesn't support partial JSON +def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +def consume_space(i: int, s: str) -> int: + while i < len(s) and s[i].isspace(): + i += 1 + return i diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d44ebcdaedd32e6a4902595b0899c65f1d81d31 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..122509e0e8f7e03b37fc9f67ad9d6cc55ce99a5e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e0824518e32e09f58ea1e63b7e731ba5a2651e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..990b50ad819cc0118dd78b8a4b2c904427d598a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87b346778b2cc39295cafdf39db07ce6017ddc2a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5be6c795c3b08bd7ce5269d12950eaf3091cc18 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cede30d12a4b9d21dbf4863e7cc8f0b7f3c95282 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a6b941bcf90d4a4b2c25e273a24ab7f312ac354 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2480db478494c50173aa78b244e86890c7a83b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py b/.venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6620817b4bb4f78d46d91e64aef231522fda4f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA) + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec + + +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. 
+ buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + buffers = tensor_model_parallel_all_gather(buffers) + layer.punica_wrapper.add_expand(output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
+ output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from MergedQKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
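+ # NOTE: the three explicit slices below could be read as this loop,
+ # shown only as a sketch (the shipped code unrolls it for q, k and v):
+ #
+ #     sliced = [
+ #         None if sub is None else
+ #         sub[:, self.tp_rank * s:(self.tp_rank + 1) * s]
+ #         for sub, s in zip(lora_a,
+ #                           (t.shape[2] for t in self.lora_a_stacked))
+ #     ]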
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
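+ # NOTE: a toy picture of the fused trick described above, assuming
+ # hidden = 8 and tp = 2 (illustrative numbers, not shipped code): each
+ # rank adds its column-partitioned LoRA output at its own offset, so the
+ # single remaining all_reduce finishes both the base and the LoRA sums.
+ #
+ #     import torch
+ #     tokens, hidden, tp = 3, 8, 2
+ #     shard = hidden // tp
+ #     partials = [torch.randn(tokens, hidden) for _ in range(tp)]
+ #     for rank in range(tp):
+ #         lora_out = torch.randn(tokens, shard)
+ #         partials[rank][:, rank * shard:(rank + 1) * shard] += lora_out
+ #     output = sum(partials)  # stands in for tensor_model_parallel_all_reduce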
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/layers.py b/.venv/lib/python3.11/site-packages/vllm/lora/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0297596ccbf236266f6e9daedfe1666dcf0d90 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/layers.py @@ -0,0 +1,1206 @@ +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return 
can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. 
+ base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: 
Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLora and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. 
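+ # Shape reminder for the copies below (sizes illustrative): checkpoint
+ # tensors arrive as lora_a (input_size, r) and lora_b (r, output_size),
+ # while the stacked buffers are laid out as (max_loras, 1, out_rows,
+ # in_cols), hence the transposes:
+ #
+ #     lora_a_stacked[0][index, 0, :r, :input_size]  <- lora_a.T
+ #     lora_b_stacked[0][index, 0, :output_size, :r] <- lora_b.T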
+ assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + return output + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. 
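+ # NOTE: a worked toy for the merged-column branch, assuming a
+ # gate_up_proj of total width 8 with tp_size = 2, so the per-partition
+ # output_size is 4 and each rank keeps a width-2 piece of both halves
+ # (illustrative numbers, not shipped code):
+ #
+ #     import torch
+ #     lora_b = torch.arange(8.).repeat(4, 1)  # (r=4, 8): cols 0-3 gate, 4-7 up
+ #     tp_rank, shard, offset = 1, 4 // 2, 8 // 2
+ #     left = lora_b[:, tp_rank * shard:(tp_rank + 1) * shard]  # cols 2..3
+ #     right = lora_b[:, offset + tp_rank * shard:
+ #                    offset + (tp_rank + 1) * shard]           # cols 6..7
+ #     sliced = torch.cat([left, right], dim=1)                 # (4, 4)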
+ if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def forward( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. 
+ """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] + return lora_b + + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). 
+ """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
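+ # NOTE: a tiny check of why slicing only lora_a (by input shard) is
+ # enough in the row-parallel case, with made-up sizes: the per-rank
+ # partial products sum to the full projection, so the all_reduce that
+ # already exists for the base layer also completes the LoRA math.
+ #
+ #     import torch
+ #     x, a = torch.randn(3, 8), torch.randn(8, 4)
+ #     full = x @ a
+ #     summed = x[:, :4] @ a[:4] + x[:, 4:] @ a[4:]  # what tp = 2 computes
+ #     assert torch.allclose(full, summed, atol=1e-5)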
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + else self.base_layer.qweight) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if 
self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.linear_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + + # HPU needs special handling to prune out dummy samples. + if current_platform.is_hpu(): + lora_logits = lora_logits[:logits.shape[0], :] + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. 
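+ # Returning False keeps this wrapper out of the generic replacement
+ # scan: the manager constructs it explicitly when it meets "lm_head"
+ # (see _create_lora_modules further below), roughly:
+ #
+ #     new_module = replace_submodule(
+ #         model, "logits_processor",
+ #         from_layer_logits_processor(logits_processor_module, lm_head,
+ #                                     lora_slots, lora_config, model_config))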
+ return False + + +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + ... + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/lora.py b/.venv/lib/python3.11/site-packages/vllm/lora/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..00299bf6c2a81446e5faa1ddf1ce2868f6b95046 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/lora.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional +from typing import Sequence as GenericSequence + +import torch +import torch.types + +from vllm.lora.peft_helper import PEFTHelper +from vllm.utils import is_pin_memory_available + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.bias = bias + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = 
self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return self + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def from_config( + cls, + module_name: str, + peft_helper: PEFTHelper, + embeddings_tensor: Optional[torch.Tensor] = None, + ) -> "LoRALayerWeights": + return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None, + None, None, embeddings_tensor, + peft_helper.vllm_lora_scaling_factor) + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None, + bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and is_pin_memory_available() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + if bias_enabled: + bias = torch.zeros([output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + else: + bias = None + + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], + bias: Optional[List[Optional[torch.Tensor]]] = None, + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + scaling=scaling, # type: ignore + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack( + cls, loras: GenericSequence[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. 
+ """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + [lora.bias if lora is not None else None for lora in loras], + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore + continue + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/models.py b/.venv/lib/python3.11/site-packages/vllm/lora/models.py new file mode 100644 index 0000000000000000000000000000000000000000..ef77fd4b74cecf943590a5d58d633e48a90420cf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/models.py @@ -0,0 +1,763 @@ +# SPDX-License-Identifier: Apache-2.0 + +import copy +import math +import os +import re +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union + +import safetensors.torch +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LoRAMapping) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + is_regex_target_modules, + parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models import SupportsLoRA, supports_multimodal +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + +_GLOBAL_LORA_ID = 0 + + +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. 
+ offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel(AdapterModel): + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ + self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor + assert ( + lora_model_id + > 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + tensors: Dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name, weights_mapper) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper, lora_embeddings_tensor) + + if is_bias: + loras[module_name].bias = tensor.to(device=device, + dtype=dtype).t() + bias = tensor.to(device=device, dtype=dtype).t() + if pin_memory: + bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + assert embedding_padding_modules is not None + if any(name in module_name + for name in embedding_padding_modules + ) and 
target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + + return cls(lora_model_id, + peft_helper.r, + loras, + scaling_factor=peft_helper.vllm_long_context_scaling_factor) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: List[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + + unexpected_modules: List[Union[list[str], str]] + if os.path.isfile(lora_tensor_path): + tensors: Dict[str, torch.Tensor] = {} + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, + framework="pt") as f: # type: ignore + for lora_module in f.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + # Load tensors if there are only expected modules. + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. 
It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, map_location=device) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device, + weights_only=True) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + weights_mapper=weights_mapper) + + +class LoRAModelManager(AdapterModelManager): + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: SupportsLoRA, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + """ + self.lora_config = lora_config + self.device = device + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None + self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device) + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} + super().__init__(model) + if hasattr(self.model, "supported_lora_modules"): + self.supported_lora_modules = copy.deepcopy( + self.model.supported_lora_modules) + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = ( + supports_multimodal(self.model) + # In case the model only supports LoRA for + # text modules (e.g. ChatGLM) + and hasattr(self.model, "get_mm_mapping")) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, BaseLayerWithLoRA] = {} + # Dict instead of a Set for compatibility with LRUCache. 
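+ # NOTE: typical lifecycle around this manager, as a hedged sketch only
+ # (the real call sites supply the elided arguments):
+ #
+ #     manager = LoRAModelManager(model, max_num_seqs,
+ #                                max_num_batched_tokens, vocab_size,
+ #                                lora_config, device)
+ #     manager.add_adapter(lora_model)          # register + merge packed loras
+ #     manager.activate_adapter(lora_model.id)  # copy weights into a free slot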
+ self._last_mapping: Optional[LoRAMapping] = None + self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + @property + def adapter_slots(self) -> int: + return self.lora_slots + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_adapters: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + # Bias is not explicitly enabled with the flag enable_lora_bias. + bias = module_lora.bias + if ((torch.is_tensor(bias) or + (isinstance(bias, Sequence) and any(b is not None + for b in bias))) + and not self.lora_config.bias_enabled): + module_lora.bias = None + raise ValueError( + f"Adapter bias cannot be used for {module_name}" + " without --enable-lora-bias.") + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor, + module_lora.bias) + else: + module.reset_lora(index) + return True + + def _deactivate_adapter(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + + def _add_adapter(self, lora: LoRAModel): + self._create_merged_loras_inplace(lora) + self._registered_adapters[lora.id] = lora + self._set_long_lora_context(lora) + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in LoRAModelManager." 
+ "Use LRUCacheLoRAModelManager for pinning") # type: ignore + + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + + def remove_all_adapters(self): + """Remove all LoRAModels from the manager.""" + self._registered_adapters.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_adapters.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if isinstance(module, PPMissingLayer): + continue + if not self._match_target_modules(module_name): + continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_mm_module(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language model, %s will be ignored.", + module_name, + ) + continue + parts = module_name.split(".")[-1] + packed_moduled_lst = self.packed_modules_mapping.get(parts, []) + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + packed_moduled_lst, self.model.config)) + + # LinearScalingRotaryEmbeddingWithLora is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset + # (yard1): TODO make this more robust + if "lm_head" in module_name: + logits_processor_module = self.model.get_submodule( + "logits_processor") + new_module = replace_submodule( + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + # All lora layers share the same punica_wrapper based on reference. 
+ new_module.set_mapping(self.punica_wrapper) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora( + self, + lora_id: int, + rank: int, + scaling_factor: Optional[float], + embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}, scaling_factor) + for module_name, module in self.model.named_modules(): + bias_enabled = self.lora_config.bias_enabled + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLora) + or self._filter_unsupported_mm_module(module_name)): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + assert embedding_modules is not None + if parts[-1] in embedding_modules: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked[0].dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim, + bias_enabled=bias_enabled) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked[0].shape[-1], + module.lora_b_stacked[0].shape[-2], + rank, + module.lora_a_stacked[0].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras: List[Optional[LoRALayerWeights]] = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.supported_lora_modules) + + def _filter_unsupported_mm_module(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. + """ + if self.supports_mm: + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + prefix_lst = module_mapping.connector + module_mapping.tower_model + return any( + [module_name.startswith(prefix) for prefix in prefix_lst]) + return False + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name, []) + # When replacements is less than or equal to 1, it indicates that this + # module is not a packed module. 
+ if len(replacements) <= 1: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras: List[Optional[LoRALayerWeights]] = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): + super().__init__(capacity, deactivate_lora_fn) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__(self, model: nn.Module, max_num_seqs: int, + max_num_batched_tokens: int, vocab_size: int, + lora_config: LoRAConfig, device: torch.device): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, device) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. 
Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(lora.id) + was_added = False + return was_added + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(lora_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + self._pin_lora_in_cpu_cache(lora_id) + self._pin_lora_in_gpu_cache(lora_id) + return True + + def _pin_lora_in_cpu_cache(self, lora_id: int): + try: + self._registered_adapters.pin(lora_id) + except ValueError as err: + raise ValueError("Pinning failed. " + f"LoRA {lora_id} is not registered.") from err + + def _pin_lora_in_gpu_cache(self, lora_id: int): + if lora_id not in self._active_adapters: + # move lora to gpu if not already active + self.activate_adapter(lora_id) + + self._active_adapters.pin(lora_id) + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not hasattr(model, "supported_lora_modules"): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + device=device, + **kwargs) + return lora_manager diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbb125612e91f88083a5394051bdfbd84ff49afd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85601d58c9d73d4680806d13d0414d238454da10 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git 
a/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..329b393f4dd045255754a188d62e9f4a7f270518 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd40a32f75f84191e409d857cb6a20eb7a34127 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..af79f98415cbc1d83bd0e3446f4596562f776315 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, 
+ output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc440f7327fa4316cf5bccd494c57c8c79602786 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401 + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_shrink", +] diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a2dac3352ea1eb3eebe9f0d60c87b3fa4b6b97f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1c008b18af68965ee03aa708883d3d31972a0e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d07a3c9992e5bd0673536b9ee1ef1f7b6b8c69b1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35da0a97f173c99e4f789358f65157080af82e1a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_expand.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_expand.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef4976a5730cb4a02b2fa9c8fc901c81562497e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_expand.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_shrink.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_shrink.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89e07371c8ea11ef1553fbc813e66e4dada35fcf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/sgmv_shrink.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e4c5cd7b118a34136eef8cebf0995ae118f01e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand.py new file mode 100644 index 0000000000000000000000000000000000000000..98510b39661a60c9777935fc39b56a248bc25ad1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.utils import direct_register_custom_op
+
+from .utils import get_lora_op_configs
+
+
+@triton.jit
+def _bgmv_expand_kernel(
+    input_ptr,
+    lora_ptr,
+    out_ptr,
+    N,
+    K,
+    lora_indices,
+    xm_stride,
+    xk_stride,
+    l0_stride,
+    lora_k_stride,
+    lora_n_stride,
+    cm_stride,
+    cn_stride,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    SPLIT_N: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    ADD_INPUTS: tl.constexpr,
+    CAST_TYPE: tl.constexpr,
+):
+    """
+    GroupGEMV; additionally, introducing SPLIT_N can improve performance for
+    large hidden_size.
+    """
+    pid_sn = tl.program_id(axis=0)
+    cur_batch = tl.program_id(axis=1)
+    lora_index = tl.load(lora_indices + cur_batch)
+    if lora_index == -1:
+        return
+    offset_k = tl.arange(0, BLOCK_K)
+    offset_n = tl.arange(0, BLOCK_N)
+    if EVEN_K:
+        tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
+                          offset_k * xk_stride, )  # [BLOCK_K]
+    else:
+        tiled_a = tl.load(
+            input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
+            mask=offset_k < K,
+            other=0,
+        )  # [BLOCK_K]
+    # N must be divisible by SPLIT_N
+    split_n_length = tl.cdiv(N, SPLIT_N)
+    if CAST_TYPE:
+        tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
+    # sliding to next row-block
+    b_ptr = (lora_ptr + l0_stride * lora_index +
+             pid_sn * split_n_length * lora_k_stride)
+    c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
+    for n in range(0, split_n_length, BLOCK_N):
+        current_n = n + offset_n
+        current_n_c = tl.max_contiguous(current_n, BLOCK_N)
+        b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
+                                                              < K)
+        c_mask = current_n < split_n_length
+        tiled_b = tl.load(
+            b_ptr + current_n_c[:, None] * lora_k_stride +
+            offset_k[None, :] * lora_n_stride,
+            mask=b_ptr_mask,
+            other=0.0,
+        )  # [BLOCK_N,BLOCK_K]
+        if ADD_INPUTS:
+            tiled_out = tl.load(c_ptr + current_n * cn_stride,
+                                mask=c_mask,
+                                other=0.0)
+            accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
+        else:
+            accumulator = tl.sum(tiled_a * tiled_b, 1)
+
+        tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
+
+
+@torch.inference_mode()
+def _bgmv_expand(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    add_inputs: bool = True,
+) -> None:
+    """
+    Args:
+        inputs (torch.Tensor): input tensor
+        lora_b_weights (torch.Tensor): lora_b's weight
+        output_tensor (torch.Tensor): output tensor
+        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
+            corresponding to each batch. An index of -1 means no lora should
+            be applied.
+        add_inputs (bool, optional): Defaults to True. If True, the LoRA
+            results are added to the existing values in output_tensor instead
+            of overwriting them.
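+
+    Shape sketch (illustrative; not part of the upstream docstring):
+        inputs:              (batch_size, rank)
+        lora_b_weights:      (num_loras, hidden_size, rank) or
+                             (num_loras, 1, hidden_size, rank)
+        output_tensor:       (batch_size, hidden_size)
+        lora_indices_tensor: (batch_size,), values in [-1, num_loras)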
+ """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return + + +def bgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="bgmv_expand", + op_func=_bgmv_expand, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_fake, + ) + bgmv_expand = torch.ops.vllm.bgmv_expand + +except AttributeError: + bgmv_expand = _bgmv_expand diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand_slice.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..48804123c1eae1d3a601bd644025759ebfc76c9f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_expand_slice.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.utils import direct_register_custom_op
+
+from .utils import get_lora_op_configs
+
+
+@triton.jit
+def _bgmv_expand_slice_kernel(
+    input_ptr,
+    lora_ptr,
+    out_ptr,
+    N,
+    K,
+    lora_indices,
+    xm_stride,
+    xk_stride,
+    l0_stride,
+    lora_k_stride,
+    lora_n_stride,
+    cm_stride,
+    cn_stride,
+    slice_offset,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    SPLIT_N: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    ADD_INPUTS: tl.constexpr,
+    CAST_TYPE: tl.constexpr,
+):
+    """
+    GroupGEMV; additionally, introducing SPLIT_N can improve performance for
+    large hidden_size.
+    """
+    pid_sn = tl.program_id(axis=0)
+    cur_batch = tl.program_id(axis=1)
+    lora_index = tl.load(lora_indices + cur_batch)
+    if lora_index == -1:
+        return
+    offset_k = tl.arange(0, BLOCK_K)
+    offset_n = tl.arange(0, BLOCK_N)
+    if EVEN_K:
+        tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
+                          offset_k * xk_stride, )  # [BLOCK_K]
+    else:
+        tiled_a = tl.load(
+            input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
+            mask=offset_k < K,
+            other=0,
+        )  # [BLOCK_K]
+    # N must be divisible by SPLIT_N
+    split_n_length = tl.cdiv(N, SPLIT_N)
+    if CAST_TYPE:
+        tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
+    # sliding to next row-block
+    b_ptr = (lora_ptr + l0_stride * lora_index +
+             pid_sn * split_n_length * lora_k_stride)
+    c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
+             slice_offset * cn_stride)
+
+    for n in range(0, split_n_length, BLOCK_N):
+        current_n = n + offset_n
+        b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
+                                                              < K)
+        c_mask = current_n < split_n_length
+        tiled_b = tl.load(
+            b_ptr + current_n[:, None] * lora_k_stride +
+            offset_k[None, :] * lora_n_stride,
+            mask=b_ptr_mask,
+            other=0.0,
+        )  # [BLOCK_N,BLOCK_K]
+
+        if ADD_INPUTS:
+            # explicitly pass in other=None to tell triton that masked values
+            # can be uninitialized. This is OK because the later tl.store
+            # operation uses the same mask, eliminating the risk of garbage
+            # values propagating
+            tiled_out = tl.load(c_ptr + current_n * cn_stride,
+                                mask=c_mask,
+                                other=None)
+            accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
+        else:
+            accumulator = tl.sum(tiled_a * tiled_b, 1)
+
+        tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
+
+
+@torch.inference_mode()
+def _bgmv_expand_slice(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+) -> None:
+    """
+    Args:
+        inputs (torch.Tensor): input tensor
+        lora_b_weights (torch.Tensor): lora_b's weight
+        output_tensor (torch.Tensor): output tensor
+        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
+            corresponding to each batch. An index of -1 means no lora should
+            be applied.
+        slice_offset (int): output_tensor's offset
+        slice_size (int): current output_tensor's size
+        add_inputs (bool, optional): Defaults to True. If True, the LoRA
+            results are added to the existing values in the output slice
+            instead of overwriting them.
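+
+    Slice sketch (illustrative; q_dim and kv_dim are hypothetical sizes, not
+    upstream names): for a packed QKV projection, three calls can share one
+    output buffer:
+        q: slice_offset = 0,                slice_size = q_dim
+        k: slice_offset = q_dim,            slice_size = kv_dim
+        v: slice_offset = q_dim + kv_dim,   slice_size = kv_dim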
+ """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return + + +def bgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="bgmv_expand_slice", + op_func=_bgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_slice_fake, + ) + bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + +except AttributeError: + bgmv_expand_slice = _bgmv_expand_slice diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_shrink.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_shrink.py new file mode 100644 index 0000000000000000000000000000000000000000..227a5765e56be6450ffbe0d4665634e3e30c322d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/bgmv_shrink.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.utils import direct_register_custom_op
+
+from .utils import get_lora_op_configs
+
+
+@triton.jit
+def _bgmv_shrink_kernel(
+    input_ptr,
+    lora_ptr,
+    out_ptr,
+    N,
+    K,
+    lora_indices,
+    scaling,
+    xm_stride,
+    xk_stride,
+    l0_stride,
+    lora_k_stride,
+    lora_n_stride,
+    cm_stride,
+    cn_stride,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    SPLIT_K: tl.constexpr,
+):
+    """
+    GroupGEMV; additionally, introducing SPLIT-K can improve performance for
+    large hidden_size.
+    """
+    pid_sk = tl.program_id(axis=0)
+    cur_batch = tl.program_id(axis=1)
+    lora_index = tl.load(lora_indices + cur_batch)
+    if lora_index == -1:
+        return
+
+    offset_n = tl.arange(0, BLOCK_N)
+    offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
+    a_ptr = input_ptr + cur_batch * xm_stride
+    b_ptr = lora_ptr + l0_stride * lora_index
+    accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
+    for k in range(0, K, BLOCK_K * SPLIT_K):
+        current_k = k + offset_k
+        current_k_c = tl.max_contiguous(current_k, BLOCK_K)
+        tiled_a = tl.load(
+            a_ptr + current_k_c,
+            mask=current_k < K,
+            other=0.0,
+        )  # [BLOCK_K]
+        b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
+
+        tiled_b = tl.load(
+            b_ptr + offset_n[:, None] * lora_k_stride +
+            current_k[None, :] * lora_n_stride,
+            mask=b_ptr_mask,
+            other=0.0,
+        )  # [BLOCK_N,BLOCK_K]
+
+        accumulator += tl.sum(tiled_a * tiled_b, 1)
+    accumulator *= scaling
+    offset_cn = tl.arange(0, BLOCK_N)
+    c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
+    c_mask = offset_cn < N
+    if SPLIT_K == 1:
+        tl.store(c_ptr, accumulator, mask=c_mask)
+    else:
+        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
+
+
+@torch.inference_mode()
+def _bgmv_shrink(
+    inputs: torch.Tensor,
+    lora_a_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+) -> None:
+    """
+    Args:
+        inputs (torch.Tensor): input tensor
+        lora_a_weights (torch.Tensor): lora_a's weight
+        output_tensor (torch.Tensor): output tensor
+        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
+            corresponding to each batch. An index of -1 means no lora should be
+            applied.
+        scaling (float): Scaling factor.
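+
+    Note (added for clarity): when the launch config selects SPLIT_K > 1, the
+    kernel combines partial sums with tl.atomic_add, so output_tensor must be
+    zero-initialized before the launch.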
+ """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return + + +def bgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="bgmv_shrink", + op_func=_bgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=bgmv_shrink_fake, + ) + bgmv_shrink = torch.ops.vllm.bgmv_shrink + +except AttributeError: + bgmv_shrink = _bgmv_shrink diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e71cacfe5a2e76b9ab1577b4082a977ed37d50 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.utils import direct_register_custom_op + +from .utils import _get_lora_b_ptr + + +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): + """ + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. 
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + slice_id = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + if pid_n * BLOCK_N > curr_N: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N), + BLOCK_N) + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * input_d2_stride + b_ptr += BLOCK_K * cur_lora_d2_stride + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & ( + offset_cn[None, :] < (cur_slice_start + curr_N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (List[torch.Tensor]): lora'b weight 
+ output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == token_nums + assert inputs.size(0) == len(lora_b_weights) + + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert output_tensor.is_contiguous() + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + b_seq_start_loc.device) + + # TODO tuning this config + K = lora_b_weights[0].shape[-1] # K= rank + + BLOCK_M = 64 + BLOCK_N = 128 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + batches, + len(lora_b_weights), + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + MAX_N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + len(lora_b_weights), + same_stride, + ) + return + + +def _sgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="sgmv_expand", + op_func=_sgmv_expand, + mutates_args=["output_tensor"], + fake_impl=_sgmv_expand_fake, + ) + sgmv_expand = torch.ops.vllm.sgmv_expand + +except AttributeError: + sgmv_expand = _sgmv_expand diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_shrink.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_shrink.py new file mode 100644 index 0000000000000000000000000000000000000000..8b26583c11c14eb4c649fc7706ffe4319a96df2f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/sgmv_shrink.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +""" 
+Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.utils import direct_register_custom_op + +from .utils import _get_lora_a_ptr + + +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, #1-3 + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + input_d0_stride, + input_d1_stride, # 1 + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, # 1 + output_d0_stride, + output_d1_stride, + output_d2_stride, # 1 + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_mix = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + if SLICE_NUM == 1: + slice_id: tl.constexpr = 0 + pid_sk = tl.program_id(axis=1) + else: + pid_mix = tl.program_id(axis=1) + slice_id = pid_mix // SPLIT_K + pid_sk = pid_mix % SPLIT_K + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + # input ptr + a_ptr = (input_ptr + cur_seq_start * input_d0_stride + + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < k_remaining, + other=0.0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < k_remaining, + other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * input_d1_stride + b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] + < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _sgmv_shrink( + inputs: 
torch.Tensor, + lora_a_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (List[torch.Tensor]): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. + """ + assert inputs.dtype == lora_a_weights[0].dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_a_weights[0].size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) + # TODO tuning this config + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K * len(lora_a_weights), + batches, + ) + _sgmv_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + len(lora_a_weights), + ) + return + + +def sgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="sgmv_shrink", + op_func=_sgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=sgmv_shrink_fake, + ) + sgmv_shrink = torch.ops.vllm.sgmv_shrink + +except AttributeError: + sgmv_shrink = _sgmv_shrink diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/utils.py b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78409b91a14e80177d710f0ab8d7d02a950b16fa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/utils.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +from typing import Dict, List, Tuple + +import torch + + 
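+# Illustrative usage of the helpers below (example added for clarity; not
+# part of the upstream module):
+#
+#     >>> get_lora_op_configs("expand", batch=8, hidden_size=4096)
+#     {'BLOCK_N': 256, 'SPLIT_N': 64, 'num_warps': 8}
+#
+# No tuned table exists yet (_get_op_configs returns None), so the call falls
+# through to _get_default_config; 4096 is divisible by 64, hence SPLIT_N=64.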
+@functools.lru_cache
+def _get_op_configs(op_type: str, batch: int, hidden_size: int):
+    # TODO: add optimal configurations
+    return None
+
+
+def _check_divisibility(hidden_size: int):
+    # The bgmv_expand kernel requires that the hidden_size be divisible by
+    # the number below.
+    divisibility = [2, 4, 8, 16, 32, 64]
+    divisibility.sort(reverse=True)
+    for div in divisibility:
+        if hidden_size % div == 0:
+            return div
+    # hidden_size is an odd number
+    return 1
+
+
+def _get_default_config(op_type: str, batch: int, hidden_size: int):
+    if op_type == "expand":
+        return {
+            "BLOCK_N": 256,
+            "SPLIT_N": _check_divisibility(hidden_size),
+            "num_warps": 8
+        }
+    else:
+        return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
+
+
+def get_lora_op_configs(op_type: str, batch: int,
+                        hidden_size: int) -> Dict[str, int]:
+    """Inspired by `fused_moe_kernel`.
+    The return value is a dictionary mapping an irregular grid of batch
+    sizes and hidden_size to configurations of the bgmv-related kernel.
+    NOTE: It currently only supports the default configuration. We plan to
+    generate optimal configurations for different hardware in the future using
+    scripts similar to `benchmark_moe.py`.
+    """
+    config = _get_op_configs(op_type, batch, hidden_size)
+    if not config:
+        config = _get_default_config(op_type, batch, hidden_size)
+    return config
+
+
+_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.Tensor, ...]] = {}
+_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.Tensor, ...]] = {}
+
+
+def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
+    """
+    `_LORA_A_PTR_DICT` collects the required information during `profile_run`.
+    After this, it remains constant, and subsequent usage goes through this
+    LUT. Refer to:
+    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
+    """
+    key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
+
+    if values := _LORA_A_PTR_DICT.get(key):
+        return values
+
+    lora_strides_d0 = []
+    lora_strides_d1 = []
+    lora_strides_d2 = []
+    tensor_ptrs = []
+    for lora_a_weight in lora_a_weights:
+        if lora_a_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
+            assert lora_a_weight.size(1) == 1
+            lora_a_weight = lora_a_weight.squeeze(dim=1)
+        else:
+            assert lora_a_weight.ndim == 3  # shape:(lora_num,size,rank)
+        assert lora_a_weight.is_contiguous()
+        tensor_ptrs.append(lora_a_weight.data_ptr())
+        lora_strides_d0.append(lora_a_weight.stride(0))
+        lora_strides_d1.append(lora_a_weight.stride(1))
+        lora_strides_d2.append(lora_a_weight.stride(2))
+    if len(lora_a_weights) > 1:
+        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+    else:
+        lora_ptr_tensor = lora_a_weights[0]
+
+    if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1
+            or len(set(lora_strides_d2)) > 1):
+        raise ValueError("All LoRA weights must have the same stride.")
+
+    _LORA_A_PTR_DICT[key] = (
+        lora_ptr_tensor,
+        lora_strides_d0[0],
+        lora_strides_d1[0],
+        lora_strides_d2[0],
+    )
+    return _LORA_A_PTR_DICT.get(key)
+
+
+def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
+                    device: str):
+    """
+    `_LORA_B_PTR_DICT` collects the required information during `profile_run`.
+    After this, it remains constant, and subsequent usage goes through this
+    LUT.
+ Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + + """ + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if values := _LORA_B_PTR_DICT.get(key): + return values + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + hidden_sizes = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + hidden_sizes.append(lora_b_weight.size(1)) + + if len(lora_weights) > 1: + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_b_weight[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and + len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + hidden_sizes_tensor = hidden_sizes[0] + same_stride = True + + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) + same_stride = False + # MAX_N is the maximum hidden size among all the lora_b weights + MAX_N = max(hidden_sizes) + _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) + return _LORA_B_PTR_DICT.get(key) diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/peft_helper.py b/.venv/lib/python3.11/site-packages/vllm/lora/peft_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..9496ab5a75c0710b2a0957ea68cb0b76684b37a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/peft_helper.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py + +import json +import math +import os +from dataclasses import MISSING, dataclass, field, fields +from typing import List, Literal, Optional, Union + +from vllm.config import LoRAConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class PEFTHelper: + """ + A helper class for PEFT configurations, specifically designed for LoRA. + This class handles configuration validation, compatibility checks for + various LoRA implementations. 
+ """ + + # Required fields + r: int + lora_alpha: int + target_modules: Union[list[str], str] + + bias: Literal["none", "all", "lora_only"] = field(default="none") + modules_to_save: Optional[list[str]] = field(default=None) + # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) + use_rslora: bool = field(default=False) + # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353) + use_dora: bool = field(default=False) + # long context lora field + context_length: int = field(default=0) + # Extra vllm field, start with 'vllm_' to avoid conflict + vllm_lora_scaling_factor: float = field(default=1.0) + vllm_max_position_embeddings: Optional[int] = field(default=False) + vllm_long_context_scaling_factor: Optional[float] = field(default=None) + + def _validate_features(self) -> List[str]: + """ + Check if there are any unsupported Lora features. + """ + error_msg = [] + if self.modules_to_save: + error_msg.append("vLLM only supports modules_to_save being None.") + if self.use_dora: + error_msg.append("vLLM does not yet support DoRA.") + return error_msg + + def __post_init__(self): + if self.use_rslora: + logger.info_once("Loading LoRA weights trained with rsLoRA.") + self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) + else: + self.vllm_lora_scaling_factor = self.lora_alpha / self.r + if self.context_length: + if self.vllm_max_position_embeddings is None: + self.vllm_max_position_embeddings = self.context_length + self.vllm_long_context_scaling_factor = float( + math.ceil(self.context_length / + self.vllm_max_position_embeddings)) + + @classmethod + def from_dict(cls, config_dict: dict) -> "PEFTHelper": + # Get all field information from the class + class_fields = {f.name: f for f in fields(cls)} + # Check for required fields + required_fields = { + name + for name, f in class_fields.items() + if f.default is MISSING and f.default_factory is MISSING + } + + # Identify any missing required fields + missing_fields = required_fields - set(config_dict.keys()) + if missing_fields: + raise ValueError( + f"Missing required configuration fields: {missing_fields}") + + # Filter out fields that aren't defined in the class + filtered_dict = { + k: v + for k, v in config_dict.items() if k in class_fields + } + return cls(**filtered_dict) + + @classmethod + def from_local_dir(cls, lora_path: str, + max_position_embeddings: Optional[int]) -> "PEFTHelper": + lora_config_path = os.path.join(lora_path, "adapter_config.json") + + with open(lora_config_path) as f: + config = json.load(f) + config["vllm_max_position_embeddings"] = max_position_embeddings + return cls.from_dict(config) + + def validate_legal(self, lora_config: LoRAConfig) -> None: + """ + Validates the LoRA configuration settings against application + constraints and requirements. 
+ """ + error_msg = self._validate_features() + if self.r > lora_config.max_lora_rank: + error_msg.append( + f"LoRA rank {self.r} is greater than max_lora_rank" + f" {lora_config.max_lora_rank}.") + if self.bias != "none" and not lora_config.bias_enabled: + error_msg.append( + "Adapter bias cannot be used without bias_enabled.") + if error_msg: + raise ValueError(f"{' '.join(error_msg)}") diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..765d00c7f09ae25d17afd57948338450279284db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_gpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_gpu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c334c3bf6de40f6b891fea96112b11c9c17611 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/__pycache__/punica_gpu.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccd9c36a073ecd5b0ff1f3f815a797977d0e445 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Optional, Tuple, Union, final + +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton_ops import bgmv_expand + from vllm.lora.ops.triton_ops import bgmv_expand_slice + from vllm.lora.ops.triton_ops import bgmv_shrink + from vllm.lora.ops.triton_ops import sgmv_expand + from vllm.lora.ops.triton_ops import sgmv_shrink + +from .punica_base import PunicaWrapperBase + + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. 
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _apply_shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _apply_shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _apply_expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + offset_start: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + offset_start=offset_start, + add_inputs=add_inputs, + ) + + def _apply_expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + + if self.is_prefill: + # NOTE fused kernel + self._apply_shrink_prefill(y, x, lora_a_stacked, scale) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink_decode(y[slice_idx], x, + lora_a_stacked[slice_idx], scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. 
+ """ + y_org = y + y = y.view(-1, y.shape[-1]) + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + if self.is_prefill: + # NOTE fused kernel + self._apply_expand_prefill(y, + x, + lora_b_stacked, + offset_start, + add_inputs=True) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand_decode( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_start, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_start += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + if self.is_prefill: + sgmv_expand( + x.unsqueeze(dim=0), + [lora_b_stacked], + y, + *self.prefill_metadata, + offset_start=0, + add_inputs=add_inputs, + ) + else: + bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, + add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros( + (len(output_slices), x.size(0), r), + dtype=torch.float32, + device=x.device, + ) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. 
+ lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/request.py b/.venv/lib/python3.11/site-packages/vllm/lora/request.py new file mode 100644 index 0000000000000000000000000000000000000000..badfaa41937741bb1fdaa618c12a322710dfcbe1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/request.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from typing import Optional + +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """ + Request for a LoRA adapter. + + Note that this class should be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + __metaclass__ = AdapterRequest + + lora_name: str + lora_int_id: int + lora_path: str = "" + lora_local_path: Optional[str] = msgspec.field(default=None) + long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) + + def __post_init__(self): + if self.lora_local_path: + warnings.warn( + "The 'lora_local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'lora_path' instead.", + DeprecationWarning, + stacklevel=2) + if not self.lora_path: + self.lora_path = self.lora_local_path or "" + + # Ensure lora_path is not empty + assert self.lora_path, "lora_path cannot be empty" + + @property + def adapter_id(self): + return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def path(self): + return self.lora_path + + @property + def local_path(self): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + return self.lora_path + + @local_path.setter + def local_path(self, value): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. 
This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/utils.py b/.venv/lib/python3.11/site-packages/vllm/lora/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f47b0af1552262c3007daa042caff49cf30c4cbc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/utils.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import re +from typing import List, Optional, Set, Tuple, Type, Union + +import huggingface_hub +from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, + HFValidationError, RepositoryNotFoundError) +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLora, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device, + lm_head.get_sharded_to_full_mapping()) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model 
with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name( + name: str, + weights_mapper: Optional[WeightsMapper] = None +) -> Tuple[str, bool, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + weights_mapper: maps the name of weight, e.g. + `model.` -> `language_model.model.`, + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + is_bias whether the tensor is lora bias. + """ + + # LoRA weight qualified name always starts with `base_model.model.`, + # so we remove the prefix `base_model.model.` to make the following + # mapping correctly. + if "base_model.model." in name: + name = name.replace("base_model.model.", "") + name = weights_mapper._map_name(name) if weights_mapper else name + # recover the prefix `base_model.model.` + name = "base_model.model." + name + + parts = name.split(".") + if parts[-1] == "weight" and (parts[-2] == "lora_A" + or parts[-2] == "lora_B"): + new_name = ".".join(parts[2:-2]) + return new_name, parts[-2] == "lora_A", False + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + new_name = ".".join(parts[2:-1]) + return new_name, parts[-1] == "lora_embedding_A", False + + if parts[-1] == "bias": + new_name = ".".join(parts[2:-2]) + return new_name, False, True + + raise ValueError(f"{name} is unsupported LoRA weight") + + +def is_regex_target_modules(load_modules: Union[str, List[str]], + expected_lora_modules: List[str]) -> bool: + """ + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the + `expected_lora_modules`. + """ + + def is_valid_regex(pattern): + try: + re.compile(pattern) + return True + except re.error: + return False + + def is_subset(sub_list, full_list): + return set(sub_list).issubset(set(full_list)) + + # Similar to PEFT's processing logic, regex-related operations are only + # executed when the load_modules is a `str`. + if not isinstance(load_modules, str): + return False + + if is_valid_regex(load_modules): + match = re.search(r"\((.*?)\)\$?$", load_modules) + if match: + suffix = match.group(1).split("|") + return is_subset(suffix, expected_lora_modules) + return False + + +def get_adapter_absolute_path(lora_path: str) -> str: + """ + Resolves the given lora_path to an absolute local path. + + If the lora_path is identified as a Hugging Face model identifier, + it will download the model and return the local snapshot path. + Otherwise, it treats the lora_path as a local file path and + converts it to an absolute path. + + Parameters: + lora_path (str): The path to the lora model, which can be an absolute path, + a relative path, or a Hugging Face model identifier. + + Returns: + str: The resolved absolute local path to the lora model. + """ + + # Check if the path is an absolute path. Return it no matter exists or not. + if os.path.isabs(lora_path): + return lora_path + + # If the path starts with ~, expand the user home directory. + if lora_path.startswith('~'): + return os.path.expanduser(lora_path) + + # Check if the expanded relative path exists locally. 
+ if os.path.exists(lora_path): + return os.path.abspath(lora_path) + + # If the path does not exist locally, assume it's a Hugging Face repo. + try: + local_snapshot_path = huggingface_hub.snapshot_download( + repo_id=lora_path) + except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, + HFValidationError): + # Handle errors that may occur during the download + # Return original path instead instead of throwing error here + logger.exception("Error downloading the HuggingFace model") + return lora_path + + return local_snapshot_path diff --git a/.venv/lib/python3.11/site-packages/vllm/lora/worker_manager.py b/.venv/lib/python3.11/site-packages/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f33a7b88cc35ffc53eab9eebb6072cdc674194d2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/lora/worker_manager.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path + +logger = init_logger(__name__) + + +class WorkerLoRAManager(AbstractWorkerManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], + lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, + ): + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) + # Lazily initialized by create_lora_manager. 
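+        # The attribute annotated below is populated the first time
+        # create_lora_manager() runs; this base class uses LoRAModelManager,
+        # while LRUCacheWorkerLoRAManager overrides _manager_cls to get the
+        # LRU-cached variant.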
+ self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + lora_manager_cls=self._manager_cls, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + try: + model = self._adapter_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules: List[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + + expected_lora_modules = list(set(expected_lora_modules)) + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + hf_to_vllm_mapper = None + if (hasattr(model, "hf_to_vllm_mapper") + and model.hf_to_vllm_mapper is not None): + hf_to_vllm_mapper = model.hf_to_vllm_mapper + + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper) + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_path}") from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_adapters(): + return False + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._adapter_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._adapter_manager.add_adapter(dummy_lora) + + def pin_adapter(self, adapter_id: int) -> bool: + return 
self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._adapter_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._adapter_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_adapter(lora) + + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): + # Load the new adapter first to ensure it is actually valid, before + # evicting any existing adapters. + # This may cause the # of loaded lora adapters to very temporarily + # exceed `--max-cpu-loras`. 
+ lora = self._load_adapter(lora_request) + + # Loading succeeded, now check if we will exceed cache capacity and + # evict if the oldest adapter if so + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + # Then add the new adapter to the cache + loaded = self._adapter_manager.add_adapter(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + lora_request.lora_int_id) is not None + self._adapter_manager.activate_adapter(lora_request.lora_int_id) + return loaded diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eadc09b6fdd23f146508a43466a4b47bdf1834ee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/image.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/image.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc3ce5a2e4e36c7bf17e8dd825c40b45277e063 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/image.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/inputs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/inputs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eae3a618c520188744a8b2f04f346d8e4155ee2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/inputs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/profiling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/profiling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f20b783659d1cbbfdfcfcc197d77708487784c4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/profiling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c785b317278e352094b9053548d6a98701f0ba9b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/multimodal/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/plugins/__init__.py b/.venv/lib/python3.11/site-packages/vllm/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..389cb8728103189b501f8a936fdeb02b194a12a1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/plugins/__init__.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from typing import Callable, Dict + +import torch + +import vllm.envs as envs + +logger = logging.getLogger(__name__) + +# make sure one process only loads plugins once +plugins_loaded = False + + +def load_plugins_by_group(group: str) -> Dict[str, Callable]: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import 
entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group=group) + if len(discovered_plugins) == 0: + logger.debug("No plugins for group %s found.", group) + return {} + logger.info("Available plugins for group %s:", group) + for plugin in discovered_plugins: + logger.info("name=%s, value=%s", plugin.name, plugin.value) + if allowed_plugins is None: + logger.info("all available plugins for group %s will be loaded.", + group) + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + plugins = {} + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + plugins[plugin.name] = func + logger.info("plugin %s loaded.", plugin.name) + except Exception: + logger.exception("Failed to load plugin %s", plugin.name) + return plugins + + +def load_general_plugins(): + """WARNING: plugins can be loaded for multiple times in different + processes. They should be designed in a way that they can be loaded + multiple times without causing issues. + """ + global plugins_loaded + if plugins_loaded: + return + plugins_loaded = True + + # some platform-specific configurations + from vllm.platforms import current_platform + + if current_platform.is_xpu(): + # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 + torch._dynamo.config.disable = True + elif current_platform.is_hpu(): + # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) + # does not support torch.compile + # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for + # torch.compile support + is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' + if is_lazy: + torch._dynamo.config.disable = True + # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) + # requires enabling lazy collectives + # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 + os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' + + plugins = load_plugins_by_group(group='vllm.general_plugins') + # general plugins, we only need to execute the loaded functions + for func in plugins.values(): + func() diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/draft_model_runner.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/draft_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..3948298db40c210a6c390ed0ca35f341d80995dc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/draft_model_runner.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import torch + +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.sampler import SamplerOutput + +try: + try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata + except (ModuleNotFoundError, ImportError): + # vllm_flash_attn is not installed, try the ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) +except (ModuleNotFoundError, ImportError) as err: + raise RuntimeError( + "Draft model speculative decoding currently only supports" + "CUDA and ROCm flash attention backend.") from err + +from vllm.logger import init_logger +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.model_runner_base import 
(ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) + +logger = init_logger(__name__) + +# A flag to enable debug prints for the updated input tensors +# before each step. +debug_advance_input = False +# A flag to allow GPU advance step for draft model runner. +# Set to False for debugging. +allow_gpu_advance_step = True + + +class TP1DraftModelRunner(ModelRunnerWrapperBase): + """Specialized model runner for speculative decoding draft model. + Since the draft model always execute k forward passes consecutively to + generate k speculative tokens in a single speculative decoding step, + we could get rid of most CPU-GPU synchronization and data transfer + overheads by keeping model input and output tensors on GPU all the time. + + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect + any broadcasting inside execute_model). + """ + + def __init__(self, model_runner: ModelRunnerBase): + if hasattr( + model_runner, + "return_hidden_states") and model_runner.return_hidden_states: + raise ValueError( + "return_hidden_states is not supported for TP1DraftModelRunner." + ) + super().__init__(model_runner) + + self.indices_of_seq_with_bonus_tokens = None + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + + def _gpu_advance_step(self, model_input: ModelRunnerInputBase, + last_output: SamplerOutput) -> ModelRunnerInputBase: + # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + + attn_metadata.advance_step(model_input, sampled_token_ids, + self.block_size, num_seqs, num_queries) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens 
= %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. + Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config + """ + if not allow_gpu_advance_step: + return False + + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False + + # TODO: Add support for other attn backends + if self.attn_backend.get_name() != "FLASH_ATTN": + return False + + # TODO: Add support for LORA + if self.lora_config: + return False + + # TODO: Add soft-tuning prompt adapter support + return not self.prompt_adapter_config + + def set_indices_of_seq_with_bonus_tokens(self, + indices_of_seq_with_bonus_tokens): + self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelRunnerInputBase, + kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + """Executes num_steps forward passes with advacement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. + + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ + + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers. 
+ if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + self.attn_state.begin_forward(model_input) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None + else: + model_executable = self.model + hidden_states = previous_hidden_states + + outputs: List[SamplerOutput] = [] + for step in range(num_steps): + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + + # Run model + with set_forward_context(model_input.attn_metadata, + self.vllm_config): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Sample the next token. 
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + outputs.append(output) + + if model_input.attn_metadata.num_prefills == 0 \ + and self.indices_of_seq_with_bonus_tokens is not None: + assert output.sampled_token_ids is not None + # output.sampled_token_ids should be of shape (num_seqs, 1) + nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape + assert num_tokens_per_seq == 1 + count = 0 + for i in range(nums_seqs): + bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[ + count] + if i != bonus_seq_idx: + # The following might cause a cpu->gpu sync + # However, the performance impact is negligible as we + # benchmarked on H100. + output.sampled_token_ids[ + i, :] = model_input.input_tokens[bonus_seq_idx] + else: + count += 1 + + # Prepare inputs for the next step + if step != num_steps - 1: + model_input = self._gpu_advance_step(model_input, outputs[-1]) + + return outputs diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/metrics.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0e0a121cd55363d8bc902747b5e0f4f3f77f56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/metrics.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time +from typing import Callable, Optional, Union + +import msgspec +import torch + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) +from vllm.utils import is_pin_memory_available + + +class SpecDecodeWorkerMetrics( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """Dataclass holding metrics emitted from the spec decode worker. + """ + + # The empirical acceptance rate of the proposal method on a per-token basis. + # This is useful for evaluating how well the proposal method aligns with the + # scoring method. + draft_acceptance_rate: float + + # The empirical efficiency, measured as the number of tokens emitted by the + # system divided by the number of tokens that could be emitted by the system + # if the proposal method were perfect. + system_efficiency: float + + # The number of speculative tokens produced by the proposal method. + draft_tokens: int + + # The number of tokens emitted by the entire system. + emitted_tokens: int + + # The number of tokens accepted by the scoring model and verification + # routine, e.g. Llama2-70B and lossless rejection sampling. + # + # NOTE: Any token accepted by the verification routine is considered + # accepted (regardless of if the speculative prefix is also accepted). The + # user will usually see less accepted tokens. This metric is helpful when + # evaluating alignment of the proposal method with the scoring model. + accepted_tokens: int + + # The number of speculative tokens per sequence. + num_spec_tokens: int + + +Timer = Callable[[], float] + + +class AsyncMetricsCollector: + """Class which copies rejection/typical-acceptance sampler metrics + from the device to CPU on a non-default Torch stream. + """ + + def __init__(self, + spec_decode_sampler: SpecDecodeBaseSampler, + timer: Optional[Timer] = None, + collect_interval_s: float = 5.0): + self.spec_decode_sampler = spec_decode_sampler + self._timer = time.time if timer is None else timer + + self._rank: Optional[int] = None + + # We don't have a device set yet. 
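+        # The copy stream is therefore created lazily in init_gpu_tensors()
+        # or init_tensors(), once the rank and target device are known.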
+ self._copy_stream: Optional[torch.cuda.Stream] = None + + self._in_flight_copy: Optional[torch.cuda.Event] = None + + pin_memory = is_pin_memory_available() + self._aggregate_num_accepted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_emitted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_draft_tokens = 0 + + self._rejsample_metrics_collect_interval_s = collect_interval_s + self._last_metrics_collect_time = self._timer() + + def init_gpu_tensors(self, rank: int) -> None: + self._rank = rank + self._copy_stream = torch.cuda.Stream() + + def init_tensors(self, + rank: int, + device_type: Union[torch.device, str] = 'cuda') -> None: + self._rank = rank + if isinstance(device_type, torch.device): + device_type = device_type.type + if device_type == 'cuda': + self._copy_stream = torch.cuda.Stream() + + def maybe_collect_rejsample_metrics( + self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + # currently using cuda.Event, skip for any non_cuda_alike platform + from vllm.platforms import current_platform + if not current_platform.is_cuda_alike(): + return None + + # If a copy was initiated in the previous call, collect and return. + if self._in_flight_copy is not None: + ready_event = self._in_flight_copy + self._in_flight_copy = None + return self._collect_rejsample_metrics(k, ready_event) + + # Otherwise, check if we should start a new copy. + if self._should_collect_rejsample_metrics(self._timer()): + assert self._in_flight_copy is None + self._in_flight_copy = self._copy_rejsample_metrics_async() + + return None + + def _should_collect_rejsample_metrics(self, now: float) -> bool: + """Return whether or not this iteration should print sampling + metrics. + """ + if self._rank != 0: + return False + + return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 + + def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: + """Copy rejection/typical-acceptance sampling metrics + (number of accepted tokens, etc) to CPU asynchronously. + + Returns a CUDA event recording when the copy is complete. + """ + assert self._copy_stream is not None + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._copy_stream): + self._aggregate_num_accepted_tokens.copy_( + self.spec_decode_sampler.num_accepted_tokens, + non_blocking=True) + self._aggregate_num_emitted_tokens.copy_( + self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) + # Number of draft tokens is calculated on CPU, so no copy is + # required. + self._aggregate_num_draft_tokens = ( + self.spec_decode_sampler.num_draft_tokens) + + aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready.record(self._copy_stream) + + return aggregate_metrics_ready + + def _collect_rejsample_metrics( + self, k: int, + ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics: + """Create metrics object from statistics copied asynchronously. + + Args: + k: int. The number of speculative tokens; used to determine system + efficiency. + ready_event: torch.cuda.Event. The CUDA event recording when the + async GPU->CPU copy is complete. 
+ """ + + ready_event.synchronize() + + # update time of last collection + self._last_metrics_collect_time = self._timer() + + accepted_tokens = self._aggregate_num_accepted_tokens.item() + emitted_tokens = self._aggregate_num_emitted_tokens.item() + draft_tokens = self._aggregate_num_draft_tokens + + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) + + if draft_tokens > 0: + draft_acceptance_rate = accepted_tokens / draft_tokens + else: + draft_acceptance_rate = float("nan") + + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens + else: + system_efficiency = float("nan") + + return SpecDecodeWorkerMetrics( + num_spec_tokens=k, + draft_acceptance_rate=draft_acceptance_rate, + system_efficiency=system_efficiency, + accepted_tokens=accepted_tokens, + draft_tokens=draft_tokens, + emitted_tokens=emitted_tokens, + ) + + @staticmethod + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/target_model_runner.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/target_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..08e773c562bf83f7f9fd3928b0aa30881f25e526 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/target_model_runner.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from vllm.sequence import SequenceGroupMetadata +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) + + +class TargetModelRunner(ModelRunnerWrapperBase): + """Specialized model runner for speculative decoding target model. + In speculative decoding, the log probabilities selected finally may not + be the same ones as selected by the target model sampling. This means + that the time spent in the log probability calculation of the target model + is time wasted, since we calculate log probabilities after deciding which + tokens are accepted. For this reason disabling log probabilities in the + target model will make decode faster. The model runner sets the + SamplingMetadata parameters according to whether log probabilities are + requested or not. + """ + + def __init__(self, model_runner: ModelRunnerBase): + # An internal boolean member variable to indicate if token log + # probabilities are needed or not. 
+ super().__init__(model_runner) + self.disable_logprobs = True + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelRunnerInputBase: + model_input: ModelRunnerInputBase =\ + self.model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + # If token log probabilities is disabled then skip generating sampler + # CPU output. We directly serialize the GPU sampled_token_id tensors + # as needed. If log probabilities is enabled then synchronize all the + # sampling related tensors which includes the logprobs tensors. + model_input.sampling_metadata.skip_sampler_cpu_output = ( + self.disable_logprobs) + return model_input diff --git a/.venv/lib/python3.11/site-packages/vllm/spec_decode/top1_proposer.py b/.venv/lib/python3.11/site-packages/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000000000000000000000000000000000000..b538923c03e74a95e307302bfb9f2d2fe976d838 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.util import sampler_output_to_torch + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: ProposerWorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. 
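TargetModelRunner.prepare_model_input above only flips one flag on the object returned by the wrapped runner; everything else is delegated via ModelRunnerWrapperBase. A toy sketch of that delegate-and-override pattern, with stand-in classes (FakeRunner, FakeModelInput) that are assumptions rather than vLLM types:

```python
from dataclasses import dataclass, field


@dataclass
class FakeSamplingMetadata:
    skip_sampler_cpu_output: bool = False


@dataclass
class FakeModelInput:
    sampling_metadata: FakeSamplingMetadata = field(
        default_factory=FakeSamplingMetadata)


class FakeRunner:
    def prepare_model_input(self, seqs):
        return FakeModelInput()


class LogprobSkippingRunner:
    """Wrap a runner and disable sampler CPU output, mirroring the
    prepare_model_input override above."""

    def __init__(self, runner, disable_logprobs: bool = True):
        self._runner = runner
        self.disable_logprobs = disable_logprobs

    def __getattr__(self, name):
        # Delegate everything not overridden to the wrapped runner.
        return getattr(self._runner, name)

    def prepare_model_input(self, seqs):
        model_input = self._runner.prepare_model_input(seqs)
        model_input.sampling_metadata.skip_sampler_cpu_output = (
            self.disable_logprobs)
        return model_input


wrapped = LogprobSkippingRunner(FakeRunner())
print(wrapped.prepare_model_input([]).sampling_metadata.skip_sampler_cpu_output)
# True
```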
+ # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + hidden_states = execute_model_req.previous_hidden_states + if hidden_states is not None: + hidden_states.prune(nonzero_proposal_len_seqs) + nonzero_execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=nonzero_proposal_len_seqs, + num_lookahead_slots=proposal_len, + previous_hidden_states=hidden_states, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, + sample_len=proposal_len, + seq_ids_with_bonus_token_in_last_step=\ + seq_ids_with_bonus_token_in_last_step, + ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) + else: + # If no sequences can be speculated, set sampler output to None. + maybe_sampler_output = None + transposed = False + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + proposal_len=proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + sampler_transposed=transposed, + ) + + proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + no_proposals=maybe_sampler_output + is None) + return proposals + + def _split_by_proposal_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has either been disabled + # (e.g. due to high traffic) or this is a prompt request. + if (seq_group_metadata.is_prompt + or seq_group_metadata.num_speculative_tokens == 0): + proposal_lens.append(0) + continue + + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + # If max_proposal_len is defined, then we shall not exceed this + # quota for nonzero_proposal + new_k = 0 + if (self.max_proposal_len is None + or seq_len + proposal_len < self.max_proposal_len): + new_k = proposal_len + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k + + return ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) + + @staticmethod + def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices, transposed): + """Remove sequences from nonzero_proposal_len_indices and reset + their proposal_len to 0 the draft worker does not provide a proposal + (maybe_sampler_output=None). 
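_split_by_proposal_len above boils down to one rule: prompts (and requests with speculation disabled) get proposal length 0, and decode sequences get the global k only when seq_len + k stays below max_proposal_len. A list-based sketch of that rule, omitting the per-request disable flag for brevity (names are illustrative):

```python
from typing import Optional


def split_by_proposal_len(seq_lens: list[int],
                          is_prompt: list[bool],
                          k: int,
                          max_proposal_len: Optional[int] = None):
    """Return (per-seq proposal lens, indices of seqs that will be speculated)."""
    proposal_lens: list[int] = []
    nonzero_indices: list[int] = []
    for i, (seq_len, prompt) in enumerate(zip(seq_lens, is_prompt)):
        if prompt:
            proposal_lens.append(0)        # prompts are never speculated
            continue
        if max_proposal_len is None or seq_len + k < max_proposal_len:
            proposal_lens.append(k)        # speculate the full global k
            nonzero_indices.append(i)
        else:
            proposal_lens.append(0)        # would exceed the draft model's max len
    return proposal_lens, nonzero_indices


# Draft model limited to 2048 positions, k=5:
print(split_by_proposal_len([100, 2045, 500], [False, False, True], k=5,
                            max_proposal_len=2048))
# ([5, 0, 0], [0])  -- seq 1 would exceed 2048, seq 2 is a prompt
```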
This can avoid scoring overheads. + """ + + # If maybe_sampler_output is None, then the draft worker did not + # provide a proposal for any sequence and thus no action needed. + # Also we do not support transposed maybe_sampler_output for now + # because it seems not straightforward for draft workers outputting + # transposed sampler outputs to handle the case of no proposal. + if maybe_sampler_output is None or transposed: + return (proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices) + + new_proposal_lens: List[int] = [] + new_nonzero_proposal_len_indices: List[int] = [] + new_maybe_sampler_output: List[SamplerOutput] = [] + nonzero_proposal_len_idx_ptr = 0 + seq_idx = 0 + while seq_idx < len( + proposal_lens) and nonzero_proposal_len_idx_ptr < len( + nonzero_proposal_len_indices): + if seq_idx < nonzero_proposal_len_indices[ + nonzero_proposal_len_idx_ptr]: + # Sequence is not in the original nonzero_proposal_len_indices, + # meaning that it has a proposal length of 0 before sending to + # the draft worker. + assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[List[SamplerOutput]], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.tensor(-1, + dtype=torch.long, + device=self._device).expand( + batch_size, proposal_len) + proposal_probs = torch.tensor(0, + dtype=torch.float32, + device=self._device).expand( + batch_size, proposal_len, + self._vocab_size) + proposal_lens_tensor = torch.tensor(0, + dtype=torch.long, + device=self._device).expand( + len(proposal_lens)) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. 
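The filtering loop above is another two-pointer walk: speculated sequences whose draft output came back as None are demoted to proposal length 0, while the rest keep their length, batch index, and sampler output. A pure-Python sketch over plain values, with strings standing in for SamplerOutput objects:

```python
def drop_empty_proposals(proposal_lens: list[int],
                         outputs: list,            # one entry per nonzero-len seq, possibly None
                         nonzero_indices: list[int]):
    """Mirror the two-pointer filtering above on plain lists."""
    new_lens, new_indices, new_outputs = [], [], []
    ptr = 0
    for seq_idx, plen in enumerate(proposal_lens):
        if ptr < len(nonzero_indices) and seq_idx == nonzero_indices[ptr]:
            if outputs[ptr] is None:
                new_lens.append(0)   # draft worker produced nothing for this seq
            else:
                new_lens.append(plen)
                new_indices.append(seq_idx)
                new_outputs.append(outputs[ptr])
            ptr += 1
        else:
            new_lens.append(0)       # already a zero-length (skipped) sequence
    return new_lens, new_outputs, new_indices


print(drop_empty_proposals([3, 0, 3], ["out_a", None], [0, 2]))
# ([3, 0, 0], ['out_a'], [0])
```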
[-1, -1, -1] + + entire_proposal_tokens = proposal_tokens.new_full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = proposal_probs.new_zeros( + batch_size, + *proposal_probs.shape[1:], + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885d0c09a1e01983e730726a1ef41a3e32b105f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:481ac49179845f4e110495059bdad0c28e7c453a6511f2bc920716c63888c198 +size 101296
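The tail of _merge_outputs above scatters the per-speculated-sequence tensors back into batch-sized tensors, leaving -1 tokens, zero probabilities, and zero lengths for the skipped rows. A minimal torch sketch of that scatter, with invented shapes and values:

```python
import torch

batch_size, k, vocab = 4, 3, 8
nonzero_indices = [0, 2]                    # only seqs 0 and 2 were speculated

# Per-speculated-seq results, e.g. as returned by the draft worker.
proposal_tokens = torch.tensor([[5, 6, 7], [1, 2, 3]])
proposal_probs = torch.rand(2, k, vocab)

# Scatter into full-batch tensors; skipped rows stay at -1 / 0.
entire_tokens = proposal_tokens.new_full((batch_size, k), fill_value=-1)
entire_tokens[nonzero_indices] = proposal_tokens
entire_probs = proposal_probs.new_zeros(batch_size, k, vocab)
entire_probs[nonzero_indices] = proposal_probs

proposal_lens = torch.zeros(batch_size, dtype=torch.long)
proposal_lens[nonzero_indices] = k

print(entire_tokens)
# tensor([[ 5,  6,  7],
#         [-1, -1, -1],
#         [ 1,  2,  3],
#         [-1, -1, -1]])
print(proposal_lens)   # tensor([3, 0, 3, 0])
```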