Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py +18 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +162 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +302 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +260 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +324 -0
- .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py +123 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py +335 -0
- .venv/lib/python3.11/site-packages/vllm/lora/layers.py +1206 -0
- .venv/lib/python3.11/site-packages/vllm/lora/lora.py +198 -0
- .venv/lib/python3.11/site-packages/vllm/lora/models.py +763 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py +15 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py +115 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py +15 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -108,3 +108,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 108 |
.venv/lib/python3.11/site-packages/pillow.libs/liblcms2-e69eef39.so.2.0.16 filter=lfs diff=lfs merge=lfs -text
|
| 109 |
.venv/lib/python3.11/site-packages/pillow.libs/libopenjp2-05423b53.so filter=lfs diff=lfs merge=lfs -text
|
| 110 |
.venv/lib/python3.11/site-packages/pillow.libs/libxcb-b8a56d01.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 108 |
.venv/lib/python3.11/site-packages/pillow.libs/liblcms2-e69eef39.so.2.0.16 filter=lfs diff=lfs merge=lfs -text
|
| 109 |
.venv/lib/python3.11/site-packages/pillow.libs/libopenjp2-05423b53.so filter=lfs diff=lfs merge=lfs -text
|
| 110 |
.venv/lib/python3.11/site-packages/pillow.libs/libxcb-b8a56d01.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
|
| 111 |
+
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc
ADDED
|
Binary file (46.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc
ADDED
|
Binary file (6.12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc
ADDED
|
Binary file (65.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from .abstract_tool_parser import ToolParser, ToolParserManager
|
| 4 |
+
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
| 5 |
+
from .granite_tool_parser import GraniteToolParser
|
| 6 |
+
from .hermes_tool_parser import Hermes2ProToolParser
|
| 7 |
+
from .internlm2_tool_parser import Internlm2ToolParser
|
| 8 |
+
from .jamba_tool_parser import JambaToolParser
|
| 9 |
+
from .llama_tool_parser import Llama3JsonToolParser
|
| 10 |
+
from .mistral_tool_parser import MistralToolParser
|
| 11 |
+
from .pythonic_tool_parser import PythonicToolParser
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
|
| 15 |
+
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
| 16 |
+
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
|
| 17 |
+
"PythonicToolParser"
|
| 18 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc
ADDED
|
Binary file (8.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc
ADDED
|
Binary file (9.44 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from functools import cached_property
|
| 5 |
+
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
| 6 |
+
|
| 7 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 8 |
+
DeltaMessage,
|
| 9 |
+
ExtractedToolCallInformation)
|
| 10 |
+
from vllm.logger import init_logger
|
| 11 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 12 |
+
from vllm.utils import import_from_path, is_list_of
|
| 13 |
+
|
| 14 |
+
logger = init_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ToolParser:
|
| 18 |
+
"""
|
| 19 |
+
Abstract ToolParser class that should not be used directly. Provided
|
| 20 |
+
properties and methods should be used in
|
| 21 |
+
derived classes.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, tokenizer: AnyTokenizer):
|
| 25 |
+
self.prev_tool_call_arr: List[Dict] = []
|
| 26 |
+
# the index of the tool call that is currently being parsed
|
| 27 |
+
self.current_tool_id: int = -1
|
| 28 |
+
self.current_tool_name_sent: bool = False
|
| 29 |
+
self.streamed_args_for_tool: List[str] = []
|
| 30 |
+
|
| 31 |
+
self.model_tokenizer = tokenizer
|
| 32 |
+
|
| 33 |
+
@cached_property
|
| 34 |
+
def vocab(self) -> Dict[str, int]:
|
| 35 |
+
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
| 36 |
+
# whereas all tokenizers have .get_vocab()
|
| 37 |
+
return self.model_tokenizer.get_vocab()
|
| 38 |
+
|
| 39 |
+
def adjust_request(
|
| 40 |
+
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
| 41 |
+
"""
|
| 42 |
+
Static method that used to adjust the request parameters.
|
| 43 |
+
"""
|
| 44 |
+
return request
|
| 45 |
+
|
| 46 |
+
def extract_tool_calls(
|
| 47 |
+
self, model_output: str,
|
| 48 |
+
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
| 49 |
+
"""
|
| 50 |
+
Static method that should be implemented for extracting tool calls from
|
| 51 |
+
a complete model-generated string.
|
| 52 |
+
Used for non-streaming responses where we have the entire model response
|
| 53 |
+
available before sending to the client.
|
| 54 |
+
Static because it's stateless.
|
| 55 |
+
"""
|
| 56 |
+
raise NotImplementedError(
|
| 57 |
+
"AbstractToolParser.extract_tool_calls has not been implemented!")
|
| 58 |
+
|
| 59 |
+
def extract_tool_calls_streaming(
|
| 60 |
+
self,
|
| 61 |
+
previous_text: str,
|
| 62 |
+
current_text: str,
|
| 63 |
+
delta_text: str,
|
| 64 |
+
previous_token_ids: Sequence[int],
|
| 65 |
+
current_token_ids: Sequence[int],
|
| 66 |
+
delta_token_ids: Sequence[int],
|
| 67 |
+
request: ChatCompletionRequest,
|
| 68 |
+
) -> Union[DeltaMessage, None]:
|
| 69 |
+
"""
|
| 70 |
+
Instance method that should be implemented for extracting tool calls
|
| 71 |
+
from an incomplete response; for use when handling tool calls and
|
| 72 |
+
streaming. Has to be an instance method because it requires state -
|
| 73 |
+
the current tokens/diffs, but also the information about what has
|
| 74 |
+
previously been parsed and extracted (see constructor)
|
| 75 |
+
"""
|
| 76 |
+
raise NotImplementedError(
|
| 77 |
+
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
| 78 |
+
"implemented!")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ToolParserManager:
|
| 82 |
+
tool_parsers: Dict[str, Type] = {}
|
| 83 |
+
|
| 84 |
+
@classmethod
|
| 85 |
+
def get_tool_parser(cls, name) -> Type:
|
| 86 |
+
"""
|
| 87 |
+
Get tool parser by name which is registered by `register_module`.
|
| 88 |
+
|
| 89 |
+
Raise a KeyError exception if the name is not registered.
|
| 90 |
+
"""
|
| 91 |
+
if name in cls.tool_parsers:
|
| 92 |
+
return cls.tool_parsers[name]
|
| 93 |
+
|
| 94 |
+
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
| 95 |
+
|
| 96 |
+
@classmethod
|
| 97 |
+
def _register_module(cls,
|
| 98 |
+
module: Type,
|
| 99 |
+
module_name: Optional[Union[str, List[str]]] = None,
|
| 100 |
+
force: bool = True) -> None:
|
| 101 |
+
if not issubclass(module, ToolParser):
|
| 102 |
+
raise TypeError(
|
| 103 |
+
f'module must be subclass of ToolParser, but got {type(module)}'
|
| 104 |
+
)
|
| 105 |
+
if module_name is None:
|
| 106 |
+
module_name = module.__name__
|
| 107 |
+
if isinstance(module_name, str):
|
| 108 |
+
module_name = [module_name]
|
| 109 |
+
for name in module_name:
|
| 110 |
+
if not force and name in cls.tool_parsers:
|
| 111 |
+
existed_module = cls.tool_parsers[name]
|
| 112 |
+
raise KeyError(f'{name} is already registered '
|
| 113 |
+
f'at {existed_module.__module__}')
|
| 114 |
+
cls.tool_parsers[name] = module
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def register_module(
|
| 118 |
+
cls,
|
| 119 |
+
name: Optional[Union[str, List[str]]] = None,
|
| 120 |
+
force: bool = True,
|
| 121 |
+
module: Union[Type, None] = None) -> Union[type, Callable]:
|
| 122 |
+
"""
|
| 123 |
+
Register module with the given name or name list. it can be used as a
|
| 124 |
+
decoder(with module as None) or normal function(with module as not
|
| 125 |
+
None).
|
| 126 |
+
"""
|
| 127 |
+
if not isinstance(force, bool):
|
| 128 |
+
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
| 129 |
+
|
| 130 |
+
# raise the error ahead of time
|
| 131 |
+
if not (name is None or isinstance(name, str)
|
| 132 |
+
or is_list_of(name, str)):
|
| 133 |
+
raise TypeError(
|
| 134 |
+
'name must be None, an instance of str, or a sequence of str, '
|
| 135 |
+
f'but got {type(name)}')
|
| 136 |
+
|
| 137 |
+
# use it as a normal method: x.register_module(module=SomeClass)
|
| 138 |
+
if module is not None:
|
| 139 |
+
cls._register_module(module=module, module_name=name, force=force)
|
| 140 |
+
return module
|
| 141 |
+
|
| 142 |
+
# use it as a decorator: @x.register_module()
|
| 143 |
+
def _register(module):
|
| 144 |
+
cls._register_module(module=module, module_name=name, force=force)
|
| 145 |
+
return module
|
| 146 |
+
|
| 147 |
+
return _register
|
| 148 |
+
|
| 149 |
+
@classmethod
|
| 150 |
+
def import_tool_parser(cls, plugin_path: str) -> None:
|
| 151 |
+
"""
|
| 152 |
+
Import a user-defined tool parser by the path of the tool parser define
|
| 153 |
+
file.
|
| 154 |
+
"""
|
| 155 |
+
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
import_from_path(module_name, plugin_path)
|
| 159 |
+
except Exception:
|
| 160 |
+
logger.exception("Failed to load module '%s' from %s.",
|
| 161 |
+
module_name, plugin_path)
|
| 162 |
+
return
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, List, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import partial_json_parser
|
| 8 |
+
from partial_json_parser.core.options import Allow
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaFunctionCall, DeltaMessage,
|
| 12 |
+
DeltaToolCall,
|
| 13 |
+
ExtractedToolCallInformation,
|
| 14 |
+
FunctionCall, ToolCall)
|
| 15 |
+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
| 16 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 17 |
+
extract_intermediate_diff)
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 20 |
+
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
| 21 |
+
from vllm.utils import random_uuid
|
| 22 |
+
|
| 23 |
+
logger = init_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@ToolParserManager.register_module("jamba")
|
| 27 |
+
class JambaToolParser(ToolParser):
|
| 28 |
+
|
| 29 |
+
def __init__(self, tokenizer: AnyTokenizer):
|
| 30 |
+
super().__init__(tokenizer)
|
| 31 |
+
|
| 32 |
+
if isinstance(self.model_tokenizer, MistralTokenizer):
|
| 33 |
+
raise ValueError(
|
| 34 |
+
"Detected a MistralTokenizer tokenizer when using a Jamba model"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.current_tool_name_sent: bool = False
|
| 38 |
+
self.prev_tool_call_arr: List[Dict] = []
|
| 39 |
+
self.current_tool_id: int = -1
|
| 40 |
+
self.streamed_args_for_tool: List[str] = [
|
| 41 |
+
] # map what has been streamed for each tool so far to a list
|
| 42 |
+
|
| 43 |
+
self.tool_calls_start_token: str = "<tool_calls>"
|
| 44 |
+
self.tool_calls_end_token: str = "</tool_calls>"
|
| 45 |
+
|
| 46 |
+
self.tool_calls_regex = re.compile(
|
| 47 |
+
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
|
| 48 |
+
re.DOTALL)
|
| 49 |
+
|
| 50 |
+
if not self.model_tokenizer:
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"The model tokenizer must be passed to the ToolParser "
|
| 53 |
+
"constructor during construction.")
|
| 54 |
+
self.tool_calls_start_token_id = self.vocab.get(
|
| 55 |
+
self.tool_calls_start_token)
|
| 56 |
+
self.tool_calls_end_token_id = self.vocab.get(
|
| 57 |
+
self.tool_calls_end_token)
|
| 58 |
+
if (self.tool_calls_start_token_id is None
|
| 59 |
+
or self.tool_calls_end_token_id is None):
|
| 60 |
+
raise RuntimeError(
|
| 61 |
+
"Jamba Tool parser could not locate tool calls start/end "
|
| 62 |
+
"tokens in the tokenizer!")
|
| 63 |
+
|
| 64 |
+
def adjust_request(
|
| 65 |
+
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
| 66 |
+
if request.tools and request.tool_choice != 'none':
|
| 67 |
+
# do not skip special tokens because jamba use the special
|
| 68 |
+
# tokens to indicate the start and end of the tool calls
|
| 69 |
+
# information.
|
| 70 |
+
request.skip_special_tokens = False
|
| 71 |
+
return request
|
| 72 |
+
|
| 73 |
+
def extract_tool_calls(
|
| 74 |
+
self, model_output: str,
|
| 75 |
+
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
| 76 |
+
|
| 77 |
+
# sanity check; avoid unnecessary processing
|
| 78 |
+
if self.tool_calls_start_token not in model_output:
|
| 79 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 80 |
+
tool_calls=[],
|
| 81 |
+
content=model_output)
|
| 82 |
+
|
| 83 |
+
else:
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
# use a regex to find the tool call between the tags
|
| 87 |
+
function_calls = self.tool_calls_regex.findall(model_output)[0]
|
| 88 |
+
|
| 89 |
+
# load the JSON, and then use it to build the Function and
|
| 90 |
+
# Tool Call
|
| 91 |
+
raw_function_calls = json.loads(function_calls)
|
| 92 |
+
tool_calls = [
|
| 93 |
+
ToolCall(
|
| 94 |
+
type="function",
|
| 95 |
+
function=FunctionCall(
|
| 96 |
+
name=function_call["name"],
|
| 97 |
+
# function call args are JSON but as a string
|
| 98 |
+
arguments=json.dumps(function_call["arguments"])))
|
| 99 |
+
for function_call in raw_function_calls
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
content = model_output[:model_output.
|
| 103 |
+
find(self.tool_calls_start_token)]
|
| 104 |
+
return ExtractedToolCallInformation(
|
| 105 |
+
tools_called=True,
|
| 106 |
+
tool_calls=tool_calls,
|
| 107 |
+
content=content if
|
| 108 |
+
(len(content) > 0 and content != " ") else None)
|
| 109 |
+
|
| 110 |
+
except Exception:
|
| 111 |
+
logger.exception(
|
| 112 |
+
"Error in extracting tool call from response.")
|
| 113 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 114 |
+
tool_calls=[],
|
| 115 |
+
content=model_output)
|
| 116 |
+
|
| 117 |
+
def extract_tool_calls_streaming(
|
| 118 |
+
self,
|
| 119 |
+
previous_text: str,
|
| 120 |
+
current_text: str,
|
| 121 |
+
delta_text: str,
|
| 122 |
+
previous_token_ids: Sequence[int],
|
| 123 |
+
current_token_ids: Sequence[int],
|
| 124 |
+
delta_token_ids: Sequence[int],
|
| 125 |
+
request: ChatCompletionRequest,
|
| 126 |
+
) -> Union[DeltaMessage, None]:
|
| 127 |
+
|
| 128 |
+
# if the tool call token is not in the tokens generated so far, append
|
| 129 |
+
# output to contents since it's not a tool
|
| 130 |
+
if self.tool_calls_start_token not in current_text:
|
| 131 |
+
return DeltaMessage(content=delta_text)
|
| 132 |
+
|
| 133 |
+
# if the tool call token ID IS in the tokens generated so far, that
|
| 134 |
+
# means we're parsing as tool calls now
|
| 135 |
+
|
| 136 |
+
# handle if we detected the start of tool calls token which means
|
| 137 |
+
# the start of tool calling
|
| 138 |
+
if (self.tool_calls_start_token_id in delta_token_ids
|
| 139 |
+
and len(delta_token_ids) == 1):
|
| 140 |
+
# if it's the only token, return None, so we don't send a chat
|
| 141 |
+
# completion and don't send a control token
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 145 |
+
# sent yet, don't allow sending
|
| 146 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 147 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 148 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 149 |
+
else Allow.ALL & ~Allow.STR
|
| 150 |
+
try:
|
| 151 |
+
|
| 152 |
+
# Extract the tool calls between the special tool call tokens
|
| 153 |
+
parsable_arr = current_text.split(
|
| 154 |
+
self.tool_calls_start_token)[-1].split(
|
| 155 |
+
self.tool_calls_end_token)[0]
|
| 156 |
+
|
| 157 |
+
# tool calls are generated in an array, so do partial JSON
|
| 158 |
+
# parsing on the entire array
|
| 159 |
+
try:
|
| 160 |
+
tool_call_arr: List[Dict] = partial_json_parser.loads(
|
| 161 |
+
parsable_arr, flags)
|
| 162 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 163 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 164 |
+
return None
|
| 165 |
+
|
| 166 |
+
# select as the current tool call the one we're on the state at
|
| 167 |
+
|
| 168 |
+
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
| 169 |
+
if len(tool_call_arr) > 0 else {}
|
| 170 |
+
|
| 171 |
+
# case -- if no tokens have been streamed for the tool, e.g.
|
| 172 |
+
# only the array brackets, stream nothing
|
| 173 |
+
if len(tool_call_arr) == 0:
|
| 174 |
+
return None
|
| 175 |
+
|
| 176 |
+
# case: we are starting a new tool in the array
|
| 177 |
+
# -> array has > 0 length AND length has moved past cursor
|
| 178 |
+
elif (len(tool_call_arr) > 0
|
| 179 |
+
and len(tool_call_arr) > self.current_tool_id + 1):
|
| 180 |
+
|
| 181 |
+
# if we're moving on to a new call, first make sure we
|
| 182 |
+
# haven't missed anything in the previous one that was
|
| 183 |
+
# auto-generated due to JSON completions, but wasn't
|
| 184 |
+
# streamed to the client yet.
|
| 185 |
+
if self.current_tool_id >= 0:
|
| 186 |
+
diff: Union[str, None] = current_tool_call.get("arguments")
|
| 187 |
+
|
| 188 |
+
if diff:
|
| 189 |
+
diff = json.dumps(diff).replace(
|
| 190 |
+
self.streamed_args_for_tool[self.current_tool_id],
|
| 191 |
+
"")
|
| 192 |
+
delta = DeltaMessage(tool_calls=[
|
| 193 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 194 |
+
function=DeltaFunctionCall(
|
| 195 |
+
arguments=diff).model_dump(
|
| 196 |
+
exclude_none=True))
|
| 197 |
+
])
|
| 198 |
+
self.streamed_args_for_tool[
|
| 199 |
+
self.current_tool_id] += diff
|
| 200 |
+
else:
|
| 201 |
+
delta = None
|
| 202 |
+
else:
|
| 203 |
+
delta = None
|
| 204 |
+
# re-set stuff pertaining to progress in the current tool
|
| 205 |
+
self.current_tool_id = len(tool_call_arr) - 1
|
| 206 |
+
self.current_tool_name_sent = False
|
| 207 |
+
self.streamed_args_for_tool.append("")
|
| 208 |
+
logger.debug("starting on new tool %d", self.current_tool_id)
|
| 209 |
+
return delta
|
| 210 |
+
|
| 211 |
+
# case: update an existing tool - this is handled below
|
| 212 |
+
|
| 213 |
+
# if the current tool name hasn't been sent, send if available
|
| 214 |
+
# - otherwise send nothing
|
| 215 |
+
if not self.current_tool_name_sent:
|
| 216 |
+
function_name = current_tool_call.get("name")
|
| 217 |
+
if function_name:
|
| 218 |
+
|
| 219 |
+
delta = DeltaMessage(tool_calls=[
|
| 220 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 221 |
+
type="function",
|
| 222 |
+
id=f"chatcmpl-tool-{random_uuid()}",
|
| 223 |
+
function=DeltaFunctionCall(
|
| 224 |
+
name=function_name).model_dump(
|
| 225 |
+
exclude_none=True))
|
| 226 |
+
])
|
| 227 |
+
self.current_tool_name_sent = True
|
| 228 |
+
else:
|
| 229 |
+
delta = None
|
| 230 |
+
|
| 231 |
+
# now we know we're on the same tool call and we're streaming
|
| 232 |
+
# arguments
|
| 233 |
+
else:
|
| 234 |
+
|
| 235 |
+
prev_arguments = self.prev_tool_call_arr[
|
| 236 |
+
self.current_tool_id].get("arguments")
|
| 237 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 238 |
+
|
| 239 |
+
new_text = delta_text.replace("\'", "\"")
|
| 240 |
+
|
| 241 |
+
if not cur_arguments and not prev_arguments:
|
| 242 |
+
|
| 243 |
+
delta = None
|
| 244 |
+
elif not cur_arguments and prev_arguments:
|
| 245 |
+
logger.error(
|
| 246 |
+
"INVARIANT - impossible to have arguments reset "
|
| 247 |
+
"mid-arguments")
|
| 248 |
+
delta = None
|
| 249 |
+
elif cur_arguments and not prev_arguments:
|
| 250 |
+
cur_arguments_json = json.dumps(cur_arguments)
|
| 251 |
+
logger.debug("finding %s in %s", new_text,
|
| 252 |
+
cur_arguments_json)
|
| 253 |
+
|
| 254 |
+
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
| 255 |
+
index(new_text) +
|
| 256 |
+
len(new_text)]
|
| 257 |
+
logger.debug("First tokens in arguments received: %s",
|
| 258 |
+
arguments_delta)
|
| 259 |
+
delta = DeltaMessage(tool_calls=[
|
| 260 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 261 |
+
function=DeltaFunctionCall(
|
| 262 |
+
arguments=arguments_delta).
|
| 263 |
+
model_dump(exclude_none=True))
|
| 264 |
+
])
|
| 265 |
+
self.streamed_args_for_tool[
|
| 266 |
+
self.current_tool_id] += arguments_delta
|
| 267 |
+
|
| 268 |
+
elif cur_arguments and prev_arguments:
|
| 269 |
+
cur_args_json = json.dumps(cur_arguments)
|
| 270 |
+
prev_args_json = json.dumps(prev_arguments)
|
| 271 |
+
logger.debug("Searching for diff between \n%s\n%s",
|
| 272 |
+
cur_args_json, prev_args_json)
|
| 273 |
+
|
| 274 |
+
argument_diff = extract_intermediate_diff(
|
| 275 |
+
cur_args_json, prev_args_json)
|
| 276 |
+
logger.debug("got arguments diff: %s", argument_diff)
|
| 277 |
+
delta = DeltaMessage(tool_calls=[
|
| 278 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 279 |
+
function=DeltaFunctionCall(
|
| 280 |
+
arguments=argument_diff).model_dump(
|
| 281 |
+
exclude_none=True))
|
| 282 |
+
])
|
| 283 |
+
self.streamed_args_for_tool[
|
| 284 |
+
self.current_tool_id] += argument_diff
|
| 285 |
+
else:
|
| 286 |
+
# try parsing it with regular JSON - if it works we're
|
| 287 |
+
# at the end, and we need to send the difference between
|
| 288 |
+
# tokens streamed so far and the valid JSON
|
| 289 |
+
delta = None
|
| 290 |
+
|
| 291 |
+
# check to see if the name is defined and has been sent. if so,
|
| 292 |
+
# stream the name - otherwise keep waiting
|
| 293 |
+
# finish by setting old and returning None as base case
|
| 294 |
+
self.prev_tool_call_arr = tool_call_arr
|
| 295 |
+
return delta
|
| 296 |
+
|
| 297 |
+
except Exception:
|
| 298 |
+
logger.exception("Error trying to handle streaming tool call.")
|
| 299 |
+
logger.debug(
|
| 300 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 301 |
+
"error")
|
| 302 |
+
return None
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from json import JSONDecoder
|
| 6 |
+
from typing import Dict, List, Sequence, Union
|
| 7 |
+
|
| 8 |
+
import partial_json_parser
|
| 9 |
+
from partial_json_parser.core.options import Allow
|
| 10 |
+
from transformers import PreTrainedTokenizerBase
|
| 11 |
+
|
| 12 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 13 |
+
DeltaFunctionCall, DeltaMessage,
|
| 14 |
+
DeltaToolCall,
|
| 15 |
+
ExtractedToolCallInformation,
|
| 16 |
+
FunctionCall, ToolCall)
|
| 17 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 18 |
+
ToolParser, ToolParserManager)
|
| 19 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
| 20 |
+
is_complete_json,
|
| 21 |
+
partial_json_loads)
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
from vllm.utils import random_uuid
|
| 24 |
+
|
| 25 |
+
logger = init_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.1 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json
    are all set
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Set up streaming-parse state and resolve the bot token id."""
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        # token Llama emits before a tool-call payload
        self.bot_token = "<|python_tag|>"
        # first token id of the encoded bot token; assumes it encodes to a
        # single token -- TODO(review): confirm for all Llama 3.x tokenizers
        self.bot_token_id = tokenizer.encode(self.bot_token,
                                             add_special_tokens=False)[0]
        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Returns an ExtractedToolCallInformation; on any parse failure the
        raw model output is returned as plain content instead.
        """
        # case -- if a tool call token is not present, return a text response
        if not (model_output.startswith(self.bot_token)
                or model_output.startswith('{')):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # load the JSON, and then use it to build the Function and
            # Tool Call
            dec = JSONDecoder()
            function_call_arr = []

            # depending on the prompt format the Llama model may or may not
            # prefix the output with the <|python_tag|> token
            start_idx = len(self.bot_token) if model_output.startswith(
                self.bot_token) else 0
            # multiple tool calls are separated by '; ' -- decode one JSON
            # object at a time and skip past the separator
            while start_idx < len(model_output):
                (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
                start_idx += end_idx + len('; ')
                function_call_arr.append(obj)

            tool_calls: List[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"] \
                                if "arguments" in raw_function_call \
                                else raw_function_call["parameters"])))
                for raw_function_call in function_call_arr
            ]

            # get any content before the tool call
            # NOTE(review): despite the comment above, content is always None
            # here, so text preceding the tool call is dropped -- confirm
            # this is intended
            ret = ExtractedToolCallInformation(tools_called=True,
                                               tool_calls=tool_calls,
                                               content=None)
            return ret

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally extract tool-call deltas while the model streams.

        Re-parses the accumulated ``current_text`` with partial-JSON parsing
        on every call and diffs against state stored on ``self``
        (``prev_tool_call_arr``, ``streamed_args_for_tool``) to decide what,
        if anything, to emit. Returns None when there is nothing to stream
        for this chunk.
        """

        if not (current_text.startswith(self.bot_token)
                or current_text.startswith('{')):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = len(self.bot_token) if current_text.startswith(
                    self.bot_token) else 0
                while start_idx < len(current_text):
                    (obj,
                     end_idx) = partial_json_loads(current_text[start_idx:],
                                                   flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx:start_idx +
                                                      end_idx]))
                    start_idx += end_idx + len('; ')
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, \
                            "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        # flush whatever suffix of the arguments JSON has not
                        # yet been streamed for the finished tool
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # arguments JSON is complete: stream the remainder
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:

                            # only the stable (shared-prefix) part beyond what
                            # was already sent is safe to stream
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from random import choices
|
| 6 |
+
from string import ascii_letters, digits
|
| 7 |
+
from typing import Dict, List, Sequence, Union
|
| 8 |
+
|
| 9 |
+
import partial_json_parser
|
| 10 |
+
from partial_json_parser.core.options import Allow
|
| 11 |
+
from pydantic import Field
|
| 12 |
+
|
| 13 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 14 |
+
DeltaFunctionCall, DeltaMessage,
|
| 15 |
+
DeltaToolCall,
|
| 16 |
+
ExtractedToolCallInformation,
|
| 17 |
+
FunctionCall, ToolCall)
|
| 18 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 19 |
+
ToolParser, ToolParserManager)
|
| 20 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 21 |
+
extract_intermediate_diff)
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
| 24 |
+
|
| 25 |
+
logger = init_logger(__name__)
|
| 26 |
+
|
| 27 |
+
ALPHANUMERIC = ascii_letters + digits
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MistralToolCall(ToolCall):
    """A ToolCall whose id conforms to Mistral's validator rules."""

    # Mistral requires tool call ids to be alphanumeric with a maximum
    # length of 9 characters:
    # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
    id: str = Field(
        default_factory=lambda: MistralToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return a fresh random 9-character alphanumeric id."""
        random_chars = choices(ALPHANUMERIC, k=9)
        return "".join(random_chars)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
    """
    Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
    examples/tool_chat_template_mistral.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up streaming-parse state and locate the [TOOL_CALLS] token."""
        super().__init__(tokenizer)

        if not isinstance(self.model_tokenizer, MistralTokenizer):
            logger.info("Non-Mistral tokenizer detected when using a Mistral "
                        "model...")

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        # token Mistral emits before a tool-call payload
        self.bot_token = "[TOOL_CALLS]"
        self.bot_token_id = self.vocab.get(self.bot_token)
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if self.bot_token_id is None:
            raise RuntimeError(
                "Mistral Tool Parser could not locate the tool call token in "
                "the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response. Requires
        find-and-replacing single quotes with double quotes for JSON parsing,
        make sure your tool call arguments don't ever include quotes!
        """

        # case -- if a tool call token is not present, return a text response
        if self.bot_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        # first remove the BOT token
        tool_content = model_output.replace(self.bot_token, "").strip()

        try:

            # we first try to directly load the json as parsing very nested
            # jsons is difficult
            try:
                function_call_arr = json.loads(tool_content)
            except json.JSONDecodeError:
                # use a regex to find the part corresponding to the tool call.
                # NOTE: This use case should not happen if the model is trained
                # correctly. It's a easy possible fix so it's included, but
                # can be brittle for very complex / highly nested tool calls
                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
                function_call_arr = json.loads(raw_tool_call)

            # Tool Call
            tool_calls: List[MistralToolCall] = [
                MistralToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"],
                                             ensure_ascii=False)))
                for raw_function_call in function_call_arr
            ]

            # get any content before the tool call
            content = model_output.split(self.bot_token)[0]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if len(content) > 0 else None)

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            # NOTE(review): this returns tool_content (BOT token stripped),
            # so any text before the token is lost on this path -- confirm
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=tool_content)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally extract tool-call deltas while the model streams.

        Re-parses everything after the [TOOL_CALLS] token with partial-JSON
        parsing on every call and diffs against state stored on ``self`` to
        decide what to emit. Returns None when nothing should be streamed
        for this chunk.
        """

        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.bot_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the BOT token which means the start of tool
        # calling
        if (self.bot_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            # if it's the only token, return None, so we don't send a chat
            # completion any don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:

            # replace BOT token with empty string, and convert single quotes
            # to double to allow parsing as JSON since mistral uses single
            # quotes instead of double for tool calls
            parsable_arr = current_text.split(self.bot_token)[-1]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: List[Dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at

            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    diff: Union[str, None] = current_tool_call.get("arguments")

                    if diff:
                        # subtract what was already streamed from the full
                        # arguments JSON to get the remaining tail
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=diff).model_dump(
                                                  exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=MistralToolCall.generate_random_id(),
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:

                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                new_text = delta_text.replace("\'", "\"")
                # drop a trailing close-quote/brace so it isn't streamed early
                if ('"}' in new_text):
                    new_text = new_text[:new_text.rindex('"}')]

                if not cur_arguments and not prev_arguments:

                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    # first chunk of arguments: stream everything up to and
                    # including the newly received text
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)[:-2]
                    logger.debug("finding %s in %s", new_text,
                                 cur_arguments_json)

                    if (new_text not in cur_arguments_json):
                        return None
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         rindex(new_text) +
                                                         len(new_text)]
                    logger.debug("First tokens in arguments received: %s",
                                 arguments_delta)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments,
                                               ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments,
                                                ensure_ascii=False)
                    logger.debug("Searching for diff between \n%s\n%s",
                                 cur_args_json, prev_args_json)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff
                else:
                    # try parsing it with regular JSON - if it works we're
                    # at the end, and we need to send the difference between
                    # tokens streamed so far and the valid JSON
                    delta = None

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from json import JSONDecodeError, JSONDecoder
|
| 5 |
+
from typing import Any, List, Tuple
|
| 6 |
+
|
| 7 |
+
import partial_json_parser
|
| 8 |
+
from partial_json_parser.core.options import Allow
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest prefix shared by *s1* and *s2* (argument order does
    not matter).

    Utility for working with partial_json_parser output during streaming: it
    helps ensure close-quotes, close-brackets and close-braces are not
    streamed to the client prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    limit = min(len(s1), len(s2))
    match_len = 0
    while match_len < limit and s1[match_len] == s2[match_len]:
        match_len += 1
    return s1[:match_len]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the shared suffix of *s1* and *s2*, if any (argument order does
    not matter).

    The walk stops as soon as the characters differ OR an alphanumeric
    character is reached, so only structural JSON characters (quotes,
    brackets, braces, punctuation) are ever part of the result.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    collected = []
    limit = min(len(s1), len(s2))
    for offset in range(1, limit + 1):
        ch = s1[-offset]
        if ch != s2[-offset] or ch.isalnum():
            break
        collected.append(ch)
    return ''.join(reversed(collected))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'

    """
    suffix = find_common_suffix(curr, old)

    # Strip the shared suffix off the old string (reverse + replace-once to
    # target the rightmost occurrence), then find the shared prefix against it.
    trimmed_old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
    prefix = find_common_prefix(curr, trimmed_old)

    diff = curr
    if suffix:
        # Remove the rightmost occurrence of the shared suffix.
        diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]

    if prefix:
        # replace the prefix only once in case it's mirrored
        diff = diff.replace(prefix, '', 1)

    return diff
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def find_all_indices(string: str, substring: str) -> List[int]:
    """
    Find all (starting) indices of a substring in a given string. Useful for
    tool call extraction
    """
    # Repeated str.find, restarting one position after each hit so that
    # overlapping occurrences are also reported.
    positions: List[int] = []
    pos = string.find(substring)
    while pos != -1:
        positions.append(pos)
        pos = string.find(substring, pos + 1)
    return positions
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse possibly-partial JSON, returning (object, consumed length)."""
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        # "Extra data" means a complete JSON value followed by trailing
        # content: fall back to raw_decode, which reports how much it read.
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def is_complete_json(input_str: str) -> bool:
    """Return True if *input_str* parses as a complete JSON document."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    return True
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def consume_space(i: int, s: str) -> int:
    """Return the first index >= *i* in *s* that is not whitespace."""
    length = len(s)
    while i < length and s[i].isspace():
        i += 1
    return i
|
.venv/lib/python3.11/site-packages/vllm/lora/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (182 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc
ADDED
|
Binary file (59.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc
ADDED
|
Binary file (9.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (39.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc
ADDED
|
Binary file (6.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# pylint: disable=unused-argument
|
| 4 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import PretrainedConfig
|
| 9 |
+
|
| 10 |
+
from vllm.config import LoRAConfig
|
| 11 |
+
from vllm.distributed.communication_op import (
|
| 12 |
+
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
|
| 13 |
+
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
| 14 |
+
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
| 15 |
+
MergedColumnParallelLinearWithLoRA,
|
| 16 |
+
MergedQKVParallelLinearWithLora,
|
| 17 |
+
QKVParallelLinearWithLora,
|
| 18 |
+
RowParallelLinearWithLoRA)
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _fully_sharded_can_replace(can_replace):
|
| 25 |
+
"""
|
| 26 |
+
decorator which adds the condition of fully sharded loras
|
| 27 |
+
intended to wrap can_replace_layer()
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def dec(*args, **kwargs):
|
| 31 |
+
return (can_replace(*args, **kwargs)
|
| 32 |
+
and kwargs["lora_config"].fully_sharded_loras)
|
| 33 |
+
|
| 34 |
+
return dec
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
    """
    For `ColumnParallelLinearWithLoRA` or classes that inherit from
    `ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
    """
    # One LoRA-A buffer, one LoRA-B buffer and one output slice per packed
    # sub-module.
    assert (layer.n_slices == len(layer.lora_a_stacked) == len(
        layer.lora_b_stacked) == len(layer.output_slices))
    if layer.lora_bias_stacked is not None:
        assert layer.n_slices == len(layer.lora_bias_stacked)

    # Base (non-LoRA) forward pass of the wrapped linear layer.
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    # Flatten input/output to 2D for the punica kernels; the original output
    # shape is restored at the end.
    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    buffers = torch.zeros(
        (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )

    # Shrink (x @ A^T) locally on each TP rank's rank-dim shard, gather the
    # shards across ranks, then expand (buffer @ B^T) into `output` in place.
    layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
    buffers = tensor_model_parallel_all_gather(buffers)
    layer.punica_wrapper.add_expand(output,
                                    buffers,
                                    layer.lora_b_stacked,
                                    layer.lora_bias_stacked,
                                    layer.output_slices,
                                    offset_start=0,
                                    add_input=True)

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# these layers are based on the tensor parallelism strategy given in
|
| 76 |
+
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
|
| 77 |
+
# https://arxiv.org/abs/2311.03285.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM, a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Keep this TP rank's contiguous shard of the LoRA-A rank dimension;
        # shard width comes from the pre-allocated stacked buffer.
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink / all-gather / expand path for column-parallel LoRA.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        #NOTE: lora_a contains 2 subloras, and each sublora could be None.
        # Each present sublora keeps this TP rank's contiguous shard of its
        # rank dimension; missing (None) subloras are passed through.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[0] is not None else None,
            lora_a[1][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[1] is not None else None,
        ]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink / all-gather / expand path for column-parallel LoRA.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Same rank-dim sharding as ColumnParallelLinearWithShardedLoRA:
        # keep this TP rank's contiguous shard of LoRA-A.
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink / all-gather / expand path for column-parallel LoRA.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        # Q, K and V may have different rank-dim shard widths; each present
        # sublora keeps this TP rank's contiguous shard of its own width.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] +
                      shard_size[0]] if lora_a[0] is not None else None,
            lora_a[1][:, start_idx[1]:start_idx[1] +
                      shard_size[1]] if lora_a[1] is not None else None,
            lora_a[2][:, start_idx[2]:start_idx[2] +
                      shard_size[2]] if lora_a[2] is not None else None,
        ]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink / all-gather / expand path for column-parallel LoRA.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Keep this TP rank's contiguous shard of LoRA-B's output dimension.
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        if bias is None:
            return bias
        # cast is for the type checker only; shard width matches lora_b's.
        self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                      self.lora_bias_stacked)
        shard_size = self.lora_bias_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Base (non-LoRA) forward; note the base layer is row-parallel, so
        # `output` is a partial sum until a later all_reduce.
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        # Communication buffer for the shrink result (see _mcp_apply for the
        # analogous column-parallel path).
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.lora_bias_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
.venv/lib/python3.11/site-packages/vllm/lora/layers.py
ADDED
|
@@ -0,0 +1,1206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# pylint: disable=unused-argument
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from transformers import PretrainedConfig
|
| 12 |
+
|
| 13 |
+
from vllm.adapter_commons.layers import AdapterMapping
|
| 14 |
+
from vllm.config import LoRAConfig
|
| 15 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 16 |
+
get_tensor_model_parallel_world_size,
|
| 17 |
+
split_tensor_along_last_dim,
|
| 18 |
+
tensor_model_parallel_all_gather,
|
| 19 |
+
tensor_model_parallel_all_reduce,
|
| 20 |
+
tensor_model_parallel_gather)
|
| 21 |
+
from vllm.distributed.utils import divide
|
| 22 |
+
# yapf: disable
|
| 23 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 24 |
+
LinearBase,
|
| 25 |
+
MergedColumnParallelLinear,
|
| 26 |
+
QKVParallelLinear,
|
| 27 |
+
ReplicatedLinear,
|
| 28 |
+
RowParallelLinear)
|
| 29 |
+
# yapf: enable
|
| 30 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 31 |
+
from vllm.model_executor.layers.rotary_embedding import (
|
| 32 |
+
LinearScalingRotaryEmbedding, RotaryEmbedding)
|
| 33 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 34 |
+
VocabParallelEmbedding)
|
| 35 |
+
from vllm.platforms import current_platform
|
| 36 |
+
|
| 37 |
+
if TYPE_CHECKING:
|
| 38 |
+
from vllm.lora.punica_wrapper import PunicaWrapperBase
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
| 42 |
+
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
| 43 |
+
"""Returns the device for where to place the LoRA tensors."""
|
| 44 |
+
# unquantizedLinear
|
| 45 |
+
if hasattr(base_layer, "weight"):
|
| 46 |
+
return base_layer.weight.device
|
| 47 |
+
# Compressed Tensor
|
| 48 |
+
elif hasattr(base_layer, "weight_packed"):
|
| 49 |
+
return base_layer.weight_packed.device
|
| 50 |
+
# GPTQ/AWQ
|
| 51 |
+
elif hasattr(base_layer, "qweight"):
|
| 52 |
+
return base_layer.qweight.device
|
| 53 |
+
# marlin
|
| 54 |
+
elif hasattr(base_layer, "B"):
|
| 55 |
+
return base_layer.B.device
|
| 56 |
+
# HQQ marlin
|
| 57 |
+
elif hasattr(base_layer, "W_q"):
|
| 58 |
+
return base_layer.W_q.device
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError(f"Unsupported base layer: {base_layer}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _not_fully_sharded_can_replace(can_replace):
|
| 64 |
+
"""
|
| 65 |
+
decorator which adds the condition of not using fully sharded loras
|
| 66 |
+
intended to wrap can_replace_layer()
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def dec(*args, **kwargs):
|
| 70 |
+
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
|
| 71 |
+
condition = (not kwargs["lora_config"].fully_sharded_loras
|
| 72 |
+
if decorate else True)
|
| 73 |
+
return can_replace(*args, **kwargs) and condition
|
| 74 |
+
|
| 75 |
+
return dec
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
class LoRAMapping(AdapterMapping):
    """Adapter mapping specialized for LoRA (see AdapterMapping)."""
    # Whether this mapping corresponds to a prefill step -- presumably
    # selects the prefill vs. decode punica kernel path; confirm at call site.
    is_prefill: bool = False
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class BaseLayerWithLoRA(nn.Module):
    """Abstract base for nn layers that can carry stacked LoRA adapters.

    Subclasses wrap a base layer and implement slicing (for tensor
    parallelism), weight creation/reset, and per-index adapter updates.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Attach the punica wrapper that runs the batched LoRA kernels."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for VocabParallelEmbedding, including support for the
    extra vocabulary rows that LoRA adapters may add."""

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Declared here for type checking; assigned in create_lora_weights.
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate stacked LoRA A/B and extra-vocab embedding buffers.

        Also zeroes the base layer's added-vocab rows (they are filled
        per-adapter by set_lora).
        """
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            # View into the base weight covering this partition's
            # added-vocab rows; set_lora writes adapter embeddings here.
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            # (start, end) of this shard's added vocab, expressed relative
            # to the start of the extra-vocab region.
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        # lora_a is indexed by token id (embedding lookup), so its rows
        # cover the full (org + extra) vocabulary.
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # Flattened (max_loras * vocab, rank) view used as the embedding
        # table in forward().
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero all adapter state stored at slot `index`."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Install an adapter's weights (and optional extra-vocab
        embeddings) into slot `index`.

        Note: lora_a is stored as-is while lora_b is stored transposed.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tokens >= org_vocab_size are LoRA-added vocabulary.
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        # Offset token ids into the flattened per-adapter lora_a table.
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        # NOTE: x.add_ modifies the caller's tensor in place.
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        # Flatten (batch, seq, dim) to 2-D for the punica kernel.
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared implementation for LoRA on top of LinearBase layers.

    Subclasses set tp_size / output_size / n_slices in __init__ and
    provide the tensor-parallel slicing policy.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Only allocated when lora_config.bias_enabled.
        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None

        # Declared here for type checking; assigned by subclasses.
        self.output_slices: Tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate per-slice stacked LoRA A/B (and optional bias) buffers."""
        self.lora_config = lora_config
        # With fully-sharded LoRAs, column-parallel layers shard lora_a's
        # output (rank) dim and row-parallel layers shard lora_b's output.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero all slices of slot `index`."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Install an adapter's (sliced) weights into slot `index`.

        Weights are stored transposed (copy of lora_a.T / lora_b.T).
        """
        # Except for QKVParallelLinearWithLora and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:
            # Make mypy happy.
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias.T, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer, then add the LoRA contribution in place."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        return output
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA on top of ReplicatedLinear.

    The base layer is fully copied on every rank, so no tensor-parallel
    slicing is needed.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        skip_bias_add = self.base_layer.skip_bias_add
        fused_bias = None if skip_bias_add else self.base_layer.bias

        # Matrix multiply (base layer plus the LoRA delta).
        result = self.apply(input_, fused_bias)

        # When bias addition is deferred, hand the bias back to the caller.
        deferred_bias = self.base_layer.bias if skip_bias_add else None
        return result, deferred_bias

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # lora_a is replicated: column-parallel slicing happens on lora_b.
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Take this rank's columns of lora_b."""
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            # The merged weight packs two halves (e.g. gate and up); take
            # this rank's shard from each half and re-concatenate.
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2

            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            lora_b = torch.cat([left_weight, right_weight], dim=1)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # TODO: Fix the slicing logic of bias.
        if bias is None:
            return bias
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: Union[MergedColumnParallelLinear,
                                QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        self.n_slices = len(self.output_slices)
        # Every slice uses this rank's shard.
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # lora_a is replicated; only lora_b (and bias) are sharded.
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Take this rank's columns from each per-slice lora_b (in place)."""
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (lora_b_i := lora_b[i]) is not None:
                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
                                     (shard_id + 1)]
        return lora_b

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Take this rank's rows from each per-slice bias (in place)."""
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (bias_i := bias[i]) is not None:
                bias[i] = bias_i[shard_size * shard_id:shard_size *
                                 (shard_id + 1)]
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Install per-slice adapter weights into slot `index`.

        lora_a / lora_b are lists with one (possibly None) tensor per
        slice; None entries leave that slice zeroed.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    self.lora_bias_stacked[i][index,
                                              0, :lora_bias_i.shape[0]].copy_(
                                                  lora_bias_i.T,
                                                  non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contains a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # Total (unsharded) and per-rank (sharded) projection widths,
        # used to locate this rank's q/k/v columns in the packed lora_b.
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Take this rank's q/k/v columns from the packed lora_b and
        re-concatenate them.

        Also records q_shard_id / kv_shard_id for later use by slice_bias.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        # KV heads may be replicated across ranks when tp_size exceeds the
        # number of KV heads.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Take this rank's q/k/v rows from the packed 1-D bias.

        NOTE(review): relies on slice_lora_b() having run first to set
        self.q_shard_id / self.kv_shard_id (set_lora calls them in that
        order).
        """
        bias_q = bias[self.q_proj_shard_size *
                      self.q_shard_id:self.q_proj_shard_size *
                      (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        bias_k = bias[k_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        bias_v = bias[v_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        # Bias is 1-D (indexed as bias[start:end] above, and stored into a
        # (max_loras, 1, out_size) buffer by set_lora), so the shards must
        # be concatenated along dim 0; dim=1 would raise for 1-D tensors.
        bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
        return bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # There are three LoRA layers (one per q/k/v slice).
        self.n_slices = len(self.base_layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # KV heads may be replicated across ranks.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Per-slice shard sizes and shard ids consumed by the inherited
        # slice_lora_b / slice_bias.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """
        The main reason for overloading this function is to handle inconsistent
        weight dimensions in qkv lora.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA on top of RowParallelLinear.

    The input dimension is sharded, so lora_a is sliced per rank while
    lora_b (and bias) stay replicated.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        self.tp_size = get_tensor_model_parallel_world_size()
        # reset input_size
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size

        self.tp_rank = get_tensor_model_parallel_rank()
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Take this rank's input rows of lora_a."""
        shard_size = self.input_size
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Output dim is not sharded for row-parallel layers.
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # Bias applies to the full (unsharded) output.
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        # Quantized base layers expose qweight instead of weight.
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is RowParallelLinear
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward to the wrapped LogitsProcessor so this
    # wrapper is a drop-in replacement for it.

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate stacked LoRA A/B buffers and the extra-vocab embedding
        buffer for up to ``max_loras`` adapters.

        Raises:
            ValueError: if the base vocab size is outside the supported
                [32000, 257024] range.
        """
        # TODO: Verify if this condition can be further relaxed
        # FIX: the original chained comparison
        # ``32000 < self.base_layer.vocab_size > 257024`` evaluates to
        # ``vocab_size > 32000 and vocab_size > 257024``, i.e. it only
        # rejected vocab sizes above 257024 and never enforced the lower
        # bound stated in the error message. Enforce the inclusive range.
        if not (32000 <= self.base_layer.vocab_size <= 257024):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 257024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Rows initialized to -inf mean "no adapter token here": they stay
        # -inf in the logits and can never be sampled.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Zero out adapter slot ``index`` so it no longer contributes."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Copy one adapter's weights (transposed) into slot ``index``."""
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including the adapters' extra-vocab rows and the
        LoRA delta. Returns None on non-gathering TP ranks."""
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            # Only the gather destination rank holds the full logits.
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # One extra all -inf slot so padded sampler indices select a
        # harmless row.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))

        # HPU needs special handling to prune out dummy samples.
        if current_platform.is_hpu():
            lora_logits = lora_logits[:logits.shape[0], :]

        # Splice the adapter extra-vocab logits right after the org vocab.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Invoke the base class's forward with ``self`` so it dispatches to
        # the overridden ``_get_logits`` above.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor: it is replaced
        # explicitly by the model manager, never via this hook.
        return False
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
    which can handle multi lora adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with the union of scaling factors.

        Merges the base layer's own scaling factor (1.0 if the base layer
        is not a LinearScalingRotaryEmbedding) with any long-context LoRA
        factors from the config, deduplicated and sorted.
        """
        scaling_factors = (list(lora_config.long_lora_scaling_factors)
                           if lora_config.long_lora_scaling_factors else [])
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        # Replace the wrapped layer in place; subsequent forwards go
        # through the multi-factor embedding.
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        # Intentionally a no-op: there are no per-adapter weights to reset.
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        # Intentionally a no-op: adapter selection happens via offsets in
        # forward(), not via stored weights.
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE, offsetting into the sin/cos cache per adapter."""
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()
|
.venv/lib/python3.11/site-packages/vllm/lora/lora.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from typing import Sequence as GenericSequence
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.types
|
| 8 |
+
|
| 9 |
+
from vllm.lora.peft_helper import PEFTHelper
|
| 10 |
+
from vllm.utils import is_pin_memory_available
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LoRALayerWeights:
    """Weights of a single LoRA-adapted layer: the two low-rank matrices
    plus an optional bias and optional extra-vocab embeddings."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.bias = bias
        self.embeddings_tensor = embeddings_tensor
        # Default scaling is the standard LoRA formula alpha / r.
        self.scaling = (self.lora_alpha /
                        self.rank) if scaling is None else scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling != 1:
            self.lora_b *= self.scaling
            self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        # Number of input features = rows of lora_a.
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        # Number of output features = columns of lora_b.
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        # Rows of the embeddings tensor, or 0 when there is none.
        if self.embeddings_tensor is None:
            return 0
        return self.embeddings_tensor.shape[0]

    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: PEFTHelper,
        embeddings_tensor: Optional[torch.Tensor] = None,
    ) -> "LoRALayerWeights":
        """Build a shell (lora_a/lora_b filled in later) from a PEFT config."""
        return cls(
            module_name,
            peft_helper.r,
            peft_helper.lora_alpha,
            None,
            None,
            None,
            embeddings_tensor,
            peft_helper.vllm_lora_scaling_factor,
        )

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None,
            bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
        """Allocate zero-filled weights of the right shapes (for profiling)."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()

        def _zeros(shape):
            # All dummy weight buffers share dtype/device/pinning.
            return torch.zeros(shape,
                               dtype=dtype,
                               device=device,
                               pin_memory=pin_memory)

        lora_a = _zeros([input_dim, rank])
        lora_b = _zeros([rank, output_dim])
        bias = _zeros([output_dim]) if bias_enabled else None
        if embeddings_tensor_dim:
            embeddings_tensor = torch.rand(10,
                                           embeddings_tensor_dim,
                                           dtype=dtype,
                                           device=device,
                                           pin_memory=pin_memory)
        else:
            embeddings_tensor = None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            embeddings_tensor=embeddings_tensor,
        )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[Optional[int]],
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
        bias: Optional[List[Optional[torch.Tensor]]] = None,
        scaling: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Per-sublayer alpha / r, mirroring the base-class default.
            self.scaling = [  # type: ignore
                alpha / self.rank  # type: ignore # noqa
                for alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        # Fold each sublayer's scaling in before packing, so the packed
        # object can use scaling == 1 everywhere a LoRA is present.
        for lora in loras:
            if lora is not None:
                lora.optimize()

        def _collect(get):
            return [get(lora) if lora is not None else None for lora in loras]

        return cls(
            first_lora.module_name,
            first_lora.rank,
            _collect(lambda x: x.lora_alpha),
            _collect(lambda x: x.lora_a),
            _collect(lambda x: x.lora_b),
            _collect(lambda x: x.bias),
            scaling=_collect(lambda x: 1),
        )

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i, b in enumerate(self.lora_b):
            if b is None or self.scaling[i] == 1:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        # Ill-defined for a packed layer: the sublayers differ.
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        # Ill-defined for a packed layer: the sublayers differ.
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True
|
.venv/lib/python3.11/site-packages/vllm/lora/models.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
| 9 |
+
|
| 10 |
+
import safetensors.torch
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
| 15 |
+
AdapterModelManager)
|
| 16 |
+
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
| 17 |
+
get_adapter, list_adapters,
|
| 18 |
+
remove_adapter, set_adapter_mapping)
|
| 19 |
+
from vllm.config import LoRAConfig
|
| 20 |
+
from vllm.logger import init_logger
|
| 21 |
+
from vllm.lora.layers import (BaseLayerWithLoRA,
|
| 22 |
+
LinearScalingRotaryEmbeddingWithLora,
|
| 23 |
+
LoRAMapping)
|
| 24 |
+
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
| 25 |
+
from vllm.lora.peft_helper import PEFTHelper
|
| 26 |
+
from vllm.lora.punica_wrapper import get_punica_wrapper
|
| 27 |
+
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
| 28 |
+
is_regex_target_modules,
|
| 29 |
+
parse_fine_tuned_lora_name, replace_submodule)
|
| 30 |
+
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
| 31 |
+
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
| 32 |
+
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
| 33 |
+
from vllm.utils import is_pin_memory_available
|
| 34 |
+
|
| 35 |
+
logger = init_logger(__name__)
|
| 36 |
+
|
| 37 |
+
_GLOBAL_LORA_ID = 0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context."""
    # The scaling factors to support long context lora fine tuned models.
    scaling_factors: List[float]
    # Dimension to apply rotary embedding to.
    rot_dim: int
    # Offsets into the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified at runtime.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_lora_id():
    """Return a new process-wide unique LoRA id (first call returns 1).

    NOTE(review): increments an unsynchronized module-level counter —
    presumably adapters are registered from a single thread; confirm.
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class LoRAModel(AdapterModel):
|
| 59 |
+
"""A LoRA fine-tuned model."""
|
| 60 |
+
|
| 61 |
+
    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        # NOTE: ``assert`` validation is stripped under ``python -O``.
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        # Module name -> LoRA weights for that module.
        self.loras: Dict[str, LoRALayerWeights] = loras
|
| 85 |
+
|
| 86 |
+
    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors.

        NOTE(review): ``scaling_factor`` is not propagated to the clone, so
        a long-context adapter's copy defaults to ``None`` — confirm this
        is intentional.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )
|
| 95 |
+
|
| 96 |
+
    @property
    def extra_vocab_size(self) -> int:
        """Largest extra-vocab size across all layer LoRAs (0 when empty)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0
|
| 100 |
+
|
| 101 |
+
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
|
| 102 |
+
"""Get LoRA for a given module by name"""
|
| 103 |
+
return self.loras.get(module_name, None)
|
| 104 |
+
|
| 105 |
+
# (yard1): TODO see if we can derive target_embedding_padding automatically
|
| 106 |
+
@classmethod
|
| 107 |
+
def from_lora_tensors(
|
| 108 |
+
cls,
|
| 109 |
+
lora_model_id: int,
|
| 110 |
+
tensors: Dict[str, torch.Tensor],
|
| 111 |
+
peft_helper: PEFTHelper,
|
| 112 |
+
device: str = "cuda",
|
| 113 |
+
dtype: Optional[torch.dtype] = None,
|
| 114 |
+
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
| 115 |
+
target_embedding_padding: Optional[int] = None,
|
| 116 |
+
embedding_modules: Optional[Dict[str, str]] = None,
|
| 117 |
+
embedding_padding_modules: Optional[List[str]] = None,
|
| 118 |
+
weights_mapper: Optional[WeightsMapper] = None,
|
| 119 |
+
) -> "LoRAModel":
|
| 120 |
+
"""Create a LoRAModel from a dictionary of tensors."""
|
| 121 |
+
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
| 122 |
+
loras: Dict[str, LoRALayerWeights] = {}
|
| 123 |
+
for tensor_name, tensor in tensors.items():
|
| 124 |
+
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
|
| 125 |
+
tensor_name, weights_mapper)
|
| 126 |
+
if module_name not in loras:
|
| 127 |
+
lora_embeddings_tensor = None
|
| 128 |
+
if embeddings:
|
| 129 |
+
assert embedding_modules is not None
|
| 130 |
+
embeddings_module = next(
|
| 131 |
+
(k for k in embedding_modules if k in module_name),
|
| 132 |
+
None)
|
| 133 |
+
if embeddings_module:
|
| 134 |
+
lora_embeddings_tensor = embeddings[
|
| 135 |
+
embedding_modules[embeddings_module]].to(
|
| 136 |
+
device=device, dtype=dtype)
|
| 137 |
+
if pin_memory:
|
| 138 |
+
lora_embeddings_tensor = (
|
| 139 |
+
lora_embeddings_tensor.pin_memory())
|
| 140 |
+
loras[module_name] = LoRALayerWeights.from_config(
|
| 141 |
+
module_name, peft_helper, lora_embeddings_tensor)
|
| 142 |
+
|
| 143 |
+
if is_bias:
|
| 144 |
+
loras[module_name].bias = tensor.to(device=device,
|
| 145 |
+
dtype=dtype).t()
|
| 146 |
+
bias = tensor.to(device=device, dtype=dtype).t()
|
| 147 |
+
if pin_memory:
|
| 148 |
+
bias = bias.pin_memory()
|
| 149 |
+
loras[module_name].bias = bias
|
| 150 |
+
elif is_lora_a:
|
| 151 |
+
loras[module_name].lora_a = tensor.to(device=device,
|
| 152 |
+
dtype=dtype).t()
|
| 153 |
+
if pin_memory:
|
| 154 |
+
loras[module_name].lora_a = loras[
|
| 155 |
+
module_name].lora_a.pin_memory()
|
| 156 |
+
else:
|
| 157 |
+
loras[module_name].lora_b = tensor.to(device=device,
|
| 158 |
+
dtype=dtype).t()
|
| 159 |
+
assert embedding_padding_modules is not None
|
| 160 |
+
if any(name in module_name
|
| 161 |
+
for name in embedding_padding_modules
|
| 162 |
+
) and target_embedding_padding is not None:
|
| 163 |
+
lora_b = loras[module_name].lora_b
|
| 164 |
+
assert target_embedding_padding >= lora_b.shape[1]
|
| 165 |
+
addition = target_embedding_padding - lora_b.shape[1]
|
| 166 |
+
loras[module_name].lora_b = torch.nn.functional.pad(
|
| 167 |
+
lora_b, (0, addition))
|
| 168 |
+
if pin_memory:
|
| 169 |
+
loras[module_name].lora_b = loras[
|
| 170 |
+
module_name].lora_b.pin_memory()
|
| 171 |
+
|
| 172 |
+
for lora in loras.values():
|
| 173 |
+
lora.optimize()
|
| 174 |
+
|
| 175 |
+
return cls(lora_model_id,
|
| 176 |
+
peft_helper.r,
|
| 177 |
+
loras,
|
| 178 |
+
scaling_factor=peft_helper.vllm_long_context_scaling_factor)
|
| 179 |
+
|
| 180 |
+
@classmethod
|
| 181 |
+
def from_local_checkpoint(
|
| 182 |
+
cls,
|
| 183 |
+
lora_dir: str,
|
| 184 |
+
expected_lora_modules: List[str],
|
| 185 |
+
peft_helper: PEFTHelper,
|
| 186 |
+
*,
|
| 187 |
+
lora_model_id: Optional[int] = None,
|
| 188 |
+
device: str = "cuda",
|
| 189 |
+
dtype: Optional[torch.dtype] = None,
|
| 190 |
+
target_embedding_padding: Optional[int] = None,
|
| 191 |
+
embedding_modules: Optional[Dict[str, str]] = None,
|
| 192 |
+
embedding_padding_modules: Optional[List[str]] = None,
|
| 193 |
+
weights_mapper: Optional[WeightsMapper] = None,
|
| 194 |
+
) -> "LoRAModel":
|
| 195 |
+
"""Create a LoRAModel from a local checkpoint.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
lora_dir: The local path that has lora data.
|
| 199 |
+
expected_lora_modules: Name of modules that are expected to be
|
| 200 |
+
replaced by lora.
|
| 201 |
+
peft_helper: Loaded lora configuration information.
|
| 202 |
+
lora_model_id: Lora model id. If not given, automatically set by
|
| 203 |
+
a global counter.
|
| 204 |
+
device: Device where the lora model is loaded.
|
| 205 |
+
dtype: dtype of the lora model weights.
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
Loaded LoRA Model.
|
| 209 |
+
"""
|
| 210 |
+
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
| 211 |
+
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
| 212 |
+
new_embeddings_tensor_path = os.path.join(
|
| 213 |
+
lora_dir, "new_embeddings.safetensors")
|
| 214 |
+
new_embeddings_bin_file_path = os.path.join(lora_dir,
|
| 215 |
+
"new_embeddings.bin")
|
| 216 |
+
|
| 217 |
+
unexpected_modules: List[Union[list[str], str]]
|
| 218 |
+
if os.path.isfile(lora_tensor_path):
|
| 219 |
+
tensors: Dict[str, torch.Tensor] = {}
|
| 220 |
+
# Find unexpected modules.
|
| 221 |
+
# Use safetensor key as a source of truth to find expected modules.
|
| 222 |
+
# in peft if you have target_modules A, B, C and C does not exist
|
| 223 |
+
# in the model it won’t error and model will be trained with A, B
|
| 224 |
+
# loraified. C won’t exist in the safetensor but it will exist in
|
| 225 |
+
# the target_modules of the adapter_config.json.
|
| 226 |
+
unexpected_modules = []
|
| 227 |
+
with safetensors.safe_open(lora_tensor_path,
|
| 228 |
+
framework="pt") as f: # type: ignore
|
| 229 |
+
for lora_module in f.keys(): # noqa
|
| 230 |
+
module_name, _, _ = parse_fine_tuned_lora_name(
|
| 231 |
+
lora_module, weights_mapper)
|
| 232 |
+
part_name = module_name.split(".")[-1]
|
| 233 |
+
if part_name not in expected_lora_modules:
|
| 234 |
+
unexpected_modules.append(module_name)
|
| 235 |
+
if unexpected_modules:
|
| 236 |
+
raise ValueError(
|
| 237 |
+
f"While loading {lora_dir}, expected"
|
| 238 |
+
f" target modules in {expected_lora_modules}"
|
| 239 |
+
f" but received {unexpected_modules}."
|
| 240 |
+
f" Please verify that the loaded LoRA module is correct"
|
| 241 |
+
)
|
| 242 |
+
# Load tensors if there are only expected modules.
|
| 243 |
+
for module in f.keys(): # noqa
|
| 244 |
+
tensors[module] = f.get_tensor(module)
|
| 245 |
+
elif os.path.isfile(lora_bin_file_path):
|
| 246 |
+
# When a bin file is provided, we rely on config to find unexpected
|
| 247 |
+
# modules.
|
| 248 |
+
unexpected_modules = []
|
| 249 |
+
target_modules = peft_helper.target_modules
|
| 250 |
+
if not isinstance(target_modules, list):
|
| 251 |
+
target_modules = [target_modules]
|
| 252 |
+
for module in target_modules:
|
| 253 |
+
# Compatible with more modules,
|
| 254 |
+
# such as:layers.11.self_attn.k_proj
|
| 255 |
+
part_name = module.split(".")[-1]
|
| 256 |
+
if part_name not in expected_lora_modules:
|
| 257 |
+
unexpected_modules.append(module)
|
| 258 |
+
# loaded lora's target modules must be a subset of
|
| 259 |
+
# expected_lora_modules. It is not reliable. See
|
| 260 |
+
# https://github.com/vllm-project/vllm/pull/5909. But there's no
|
| 261 |
+
# other better mechanism.
|
| 262 |
+
if unexpected_modules and not is_regex_target_modules(
|
| 263 |
+
peft_helper.target_modules, expected_lora_modules):
|
| 264 |
+
raise ValueError(
|
| 265 |
+
f"While loading {lora_dir}, expected"
|
| 266 |
+
f" target modules in {expected_lora_modules}"
|
| 267 |
+
f" but received {unexpected_modules}."
|
| 268 |
+
f" Please verify that the loaded LoRA module is correct")
|
| 269 |
+
tensors = torch.load(lora_bin_file_path, map_location=device)
|
| 270 |
+
else:
|
| 271 |
+
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
| 272 |
+
|
| 273 |
+
embeddings = None
|
| 274 |
+
if os.path.isfile(new_embeddings_tensor_path):
|
| 275 |
+
embeddings = safetensors.torch.load_file(
|
| 276 |
+
new_embeddings_tensor_path)
|
| 277 |
+
elif os.path.isfile(new_embeddings_bin_file_path):
|
| 278 |
+
embeddings = torch.load(new_embeddings_bin_file_path,
|
| 279 |
+
map_location=device,
|
| 280 |
+
weights_only=True)
|
| 281 |
+
|
| 282 |
+
return cls.from_lora_tensors(
|
| 283 |
+
lora_model_id=get_lora_id()
|
| 284 |
+
if lora_model_id is None else lora_model_id,
|
| 285 |
+
tensors=tensors,
|
| 286 |
+
peft_helper=peft_helper,
|
| 287 |
+
device=device,
|
| 288 |
+
dtype=dtype,
|
| 289 |
+
embeddings=embeddings,
|
| 290 |
+
target_embedding_padding=target_embedding_padding,
|
| 291 |
+
embedding_modules=embedding_modules,
|
| 292 |
+
embedding_padding_modules=embedding_padding_modules,
|
| 293 |
+
weights_mapper=weights_mapper)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device on which punica metadata buffers live.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        # There must be at least as many CPU-resident slots as GPU slots.
        assert self.capacity >= self.lora_slots
        # Round up to a multiple of 8 (alignment for the punica kernels).
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Maps GPU slot index -> lora id currently occupying it (None = free).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
                                                 max_batches=self.max_num_seqs,
                                                 device=self.device)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
        self.packed_modules_mapping = copy.deepcopy(
            self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        # Total number of adapters that may be registered (CPU-side).
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        # Number of adapters that may be active (GPU-side) at once.
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns False if the adapter is already active; raises ValueError
        if no free slot is available or the bias flag is missing.
        """
        if lora_id in self._active_adapters:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                # No weights for this module in the adapter: clear the slot.
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        # Free the GPU slot; a ValueError means the id was never active.
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _set_long_lora_context(self, lora: LoRAModel):
        # Record the sin/cos-cache offset for this adapter's scaling factor
        # (long-context LoRA only; no-op otherwise).
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager."
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        # Replace every supported submodule of the model with its
        # LoRA-capable counterpart.
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))

            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            bias_enabled = self.lora_config.bias_enabled
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    # Embedding-like module: dims come from the base layer's
                    # vocab/embedding sizes rather than the stacked buffers.
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                # Packed module (e.g. fused qkv): build one dummy sub-lora per
                # replacement and pack them together.
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        # Match either a trailing ".<target>" suffix or an exact name.
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                [module_name.startswith(prefix) for prefix in prefix_lst])
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        # For every packed module, merge the per-submodule loras the adapter
        # actually supplies into one PackedLoRALayerWeights entry.
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of LoRAModels that calls ``deactivate_lora_fn`` on eviction."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity, deactivate_lora_fn)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Replace the plain dicts from the base class with LRU caches whose
        # eviction hooks deactivate the evicted adapter.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        # Evict the least recently used active adapter if every GPU slot is
        # taken, then defer to the base-class activation.
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        # Returns True if an adapter was evicted, False when the cache is empty.
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    Args:
        model: model to wrap; must expose ``supported_lora_modules``.
        lora_manager_cls: manager class to instantiate (e.g. the LRU variant).
        **kwargs: forwarded verbatim to ``lora_manager_cls``.

    Raises:
        ValueError: if the model does not declare LoRA support.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)
    return lora_manager
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (186 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
|
| 4 |
+
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
|
| 5 |
+
sgmv_expand, sgmv_expand_slice,
|
| 6 |
+
sgmv_shrink)
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"bgmv_expand",
|
| 10 |
+
"bgmv_expand_slice",
|
| 11 |
+
"bgmv_shrink",
|
| 12 |
+
"sgmv_expand",
|
| 13 |
+
"sgmv_expand_slice",
|
| 14 |
+
"sgmv_shrink",
|
| 15 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (546 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc
ADDED
|
Binary file (5.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    """Reference SGMV expand: broadcast each sequence's LoRA index over its
    tokens, then delegate to the per-token BGMV expand kernel."""
    # One LoRA index per token: repeat each sequence's index seq_len times.
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand(inputs, lora_b_weights, output_tensor, per_token_indices,
                add_inputs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """Per-token LoRA-B expansion: output[b] (+)= B[idx[b]] @ inputs[b].

    When ``add_inputs`` is True the projection is accumulated into
    ``output_tensor``; otherwise it overwrites the leading columns.
    """
    # Gather one (out, in) B matrix per token and match the output dtype.
    per_token_b = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    # Drop an optional singleton dim: (n, 1, out, in) -> (n, out, in).
    if per_token_b.dim() == 4:
        per_token_b = per_token_b.squeeze(dim=1)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    projected = torch.einsum("bi, boi -> bo", cast_inputs, per_token_b)

    # A single projected row may be broadcast-assigned, so only slice to one
    # row in that case; otherwise cover the full output batch.
    if projected.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]

    cols = projected.shape[1]
    if add_inputs:
        output_tensor[:, :cols] += projected[:row_limit, :]
    else:
        output_tensor[:, :cols] = projected[:row_limit, :]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """Reference SGMV shrink: broadcast each sequence's LoRA index over its
    tokens, then delegate to the per-token BGMV shrink kernel."""
    # One LoRA index per token: repeat each sequence's index seq_len times.
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_shrink(inputs, lora_a_weights, output_tensor, per_token_indices,
                scaling)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """Per-token LoRA shrink: output[b] = scaling * W[idx[b]] @ inputs[b].

    NOTE(review): despite the parameter name, ``sgmv_shrink`` passes LoRA-A
    weights here — the math (gather + batched mat-vec) is identical.
    """
    # Gather one weight matrix per token and match the output dtype.
    per_token_w = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    # Drop an optional singleton dim: (n, 1, r, in) -> (n, r, in).
    if per_token_w.dim() == 4:
        per_token_w = per_token_w.squeeze(dim=1)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    projected = torch.einsum("bi, boi -> bo", cast_inputs, per_token_w)

    # Overwrite the leading rank columns, applying the LoRA scaling.
    output_tensor[:, :projected.shape[1]] = scaling * projected[:]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    """Reference SGMV expand into a column slice: broadcast each sequence's
    LoRA index over its tokens, then delegate to the BGMV slice kernel."""
    # One LoRA index per token: repeat each sequence's index seq_len times.
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand_slice(inputs, lora_b_weights, output_tensor,
                      per_token_indices, slice_offset, slice_size, add_inputs)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Per-token LoRA-B expansion written into columns
    ``[slice_offset, slice_offset + slice_size)`` of ``output_tensor``."""
    # Gather one (out, in) B matrix per token and match the output dtype.
    per_token_b = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    # Drop an optional singleton dim: (n, 1, out, in) -> (n, out, in).
    if per_token_b.dim() == 4:
        per_token_b = per_token_b.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo", cast_inputs, per_token_b)

    target = slice(slice_offset, slice_offset + slice_size)
    if add_inputs:
        output_tensor[:, target] += projected[:]
    else:
        output_tensor[:, target] = projected[:]
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
| 4 |
+
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
| 5 |
+
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
| 6 |
+
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
| 7 |
+
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"bgmv_expand",
|
| 11 |
+
"bgmv_expand_slice",
|
| 12 |
+
"bgmv_shrink",
|
| 13 |
+
"sgmv_expand",
|
| 14 |
+
"sgmv_shrink",
|
| 15 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (708 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc
ADDED
|
Binary file (7.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc
ADDED
|
Binary file (7.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|