koichi12 commited on
Commit
7695dda
·
verified ·
1 Parent(s): 00eb59b

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py +18 -0
  10. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +162 -0
  22. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +302 -0
  23. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +260 -0
  24. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +324 -0
  25. .venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py +123 -0
  26. .venv/lib/python3.11/site-packages/vllm/lora/__init__.py +0 -0
  27. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py +335 -0
  37. .venv/lib/python3.11/site-packages/vllm/lora/layers.py +1206 -0
  38. .venv/lib/python3.11/site-packages/vllm/lora/lora.py +198 -0
  39. .venv/lib/python3.11/site-packages/vllm/lora/models.py +763 -0
  40. .venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py +15 -0
  43. .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py +115 -0
  46. .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py +15 -0
  47. .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -108,3 +108,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
108
  .venv/lib/python3.11/site-packages/pillow.libs/liblcms2-e69eef39.so.2.0.16 filter=lfs diff=lfs merge=lfs -text
109
  .venv/lib/python3.11/site-packages/pillow.libs/libopenjp2-05423b53.so filter=lfs diff=lfs merge=lfs -text
110
  .venv/lib/python3.11/site-packages/pillow.libs/libxcb-b8a56d01.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
 
 
108
  .venv/lib/python3.11/site-packages/pillow.libs/liblcms2-e69eef39.so.2.0.16 filter=lfs diff=lfs merge=lfs -text
109
  .venv/lib/python3.11/site-packages/pillow.libs/libopenjp2-05423b53.so filter=lfs diff=lfs merge=lfs -text
110
  .venv/lib/python3.11/site-packages/pillow.libs/libxcb-b8a56d01.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
111
+ .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_model_runner.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/api_server.cpython-311.pyc ADDED
Binary file (8.59 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/chat_utils.cpython-311.pyc ADDED
Binary file (46.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/launcher.cpython-311.pyc ADDED
Binary file (6.12 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/llm.cpython-311.pyc ADDED
Binary file (65.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/logger.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.95 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# SPDX-License-Identifier: Apache-2.0
"""Public surface of the OpenAI tool-call parser package.

Re-exports the abstract base (`ToolParser`), the registry
(`ToolParserManager`), and every built-in model-specific parser so callers
can import them from this package directly. Importing the submodules also
triggers their `@ToolParserManager.register_module(...)` decorators, which
populates the registry as a side effect.
"""

from .abstract_tool_parser import ToolParser, ToolParserManager
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .pythonic_tool_parser import PythonicToolParser

__all__ = [
    "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
    "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
    "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
    "PythonicToolParser"
]
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.03 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-311.pyc ADDED
Binary file (8.45 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_20b_fc_tool_parser.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/granite_tool_parser.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-311.pyc ADDED
Binary file (9.44 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/jamba_tool_parser.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/pythonic_tool_parser.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.99 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ from functools import cached_property
5
+ from typing import Callable, Dict, List, Optional, Sequence, Type, Union
6
+
7
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
8
+ DeltaMessage,
9
+ ExtractedToolCallInformation)
10
+ from vllm.logger import init_logger
11
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
12
+ from vllm.utils import import_from_path, is_list_of
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
class ToolParser:
    """
    Abstract base class for tool-call parsers; not intended to be used
    directly. Subclasses override :meth:`extract_tool_calls` and
    :meth:`extract_tool_calls_streaming`, and may use the state attributes
    initialized here to track streaming progress.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        # Tool calls parsed on the previous streaming step; subclasses diff
        # against this to emit only new argument fragments.
        self.prev_tool_call_arr: List[Dict] = []
        # the index of the tool call that is currently being parsed
        self.current_tool_id: int = -1
        # whether the current tool's function name was already streamed
        self.current_tool_name_sent: bool = False
        # per-tool accumulation of argument text already sent to the client
        self.streamed_args_for_tool: List[str] = []

        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Hook to adjust request parameters before generation (e.g. disable
        skipping of special tokens). The default implementation returns the
        request unchanged.
        """
        return request

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model-generated string.
        Used for non-streaming responses where we have the entire model
        response available before sending to the client.

        Raises NotImplementedError: subclasses must override.
        """
        raise NotImplementedError(
            "AbstractToolParser.extract_tool_calls has not been implemented!")

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting tool calls
        from an incomplete response; for use when handling tool calls and
        streaming. Has to be an instance method because it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)

        Raises NotImplementedError: subclasses must override.
        """
        raise NotImplementedError(
            "AbstractToolParser.extract_tool_calls_streaming has not been "
            "implemented!")
79
+
80
+
81
class ToolParserManager:
    """Registry mapping parser names to :class:`ToolParser` subclasses.

    Parsers register themselves via the :meth:`register_module` decorator
    (or direct call); servers look them up by name with
    :meth:`get_tool_parser`. User-defined parsers can be loaded from a file
    path with :meth:`import_tool_parser`.
    """

    # name -> ToolParser subclass; shared, class-level registry
    tool_parsers: Dict[str, Type] = {}

    @classmethod
    def get_tool_parser(cls, name) -> Type:
        """
        Get tool parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        if name in cls.tool_parsers:
            return cls.tool_parsers[name]

        raise KeyError(f"tool helper: '{name}' not found in tool_parsers")

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """Insert *module* into the registry under one or more names.

        Raises TypeError if *module* is not a ToolParser subclass, and
        KeyError if a name is taken and ``force`` is False.
        """
        if not issubclass(module, ToolParser):
            # BUGFIX: report the offending class itself; type(module) would
            # print the metaclass (e.g. <class 'type'>), which is useless.
            raise TypeError(
                f'module must be subclass of ToolParser, but got {module}')
        # Default to the class name when no explicit name(s) given.
        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in cls.tool_parsers:
                existed_module = cls.tool_parsers[name]
                raise KeyError(f'{name} is already registered '
                               f'at {existed_module.__module__}')
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. it can be used as a
        decoder(with module as None) or normal function(with module as not
        None).
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined tool parser by the path of the tool parser define
        file. Failures are logged (with traceback) rather than raised, so a
        bad plugin does not take the server down.
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)
            return
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from typing import Dict, List, Sequence, Union
6
+
7
+ import partial_json_parser
8
+ from partial_json_parser.core.options import Allow
9
+
10
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
11
+ DeltaFunctionCall, DeltaMessage,
12
+ DeltaToolCall,
13
+ ExtractedToolCallInformation,
14
+ FunctionCall, ToolCall)
15
+ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
16
+ from vllm.entrypoints.openai.tool_parsers.utils import (
17
+ extract_intermediate_diff)
18
+ from vllm.logger import init_logger
19
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
20
+ from vllm.transformers_utils.tokenizers import MistralTokenizer
21
+ from vllm.utils import random_uuid
22
+
23
+ logger = init_logger(__name__)
24
+
25
+
26
@ToolParserManager.register_module("jamba")
class JambaToolParser(ToolParser):
    """Tool-call parser for Jamba models.

    Expects tool calls to appear as a JSON array wrapped between the
    special tokens ``<tool_calls>`` and ``</tool_calls>``. Supports both
    one-shot extraction and incremental (streaming) extraction.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Jamba uses a HF-style tokenizer; a Mistral tokenizer indicates a
        # model/parser mismatch.
        if isinstance(self.model_tokenizer, MistralTokenizer):
            raise ValueError(
                "Detected a MistralTokenizer tokenizer when using a Jamba model"
            )

        # Streaming state (shadows the base-class init on purpose).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list

        self.tool_calls_start_token: str = "<tool_calls>"
        self.tool_calls_end_token: str = "</tool_calls>"

        # Non-greedy match of everything between the start/end markers.
        self.tool_calls_regex = re.compile(
            rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
            re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        # Resolve the marker token ids; both must exist in the vocab.
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token)
        if (self.tool_calls_start_token_id is None
                or self.tool_calls_end_token_id is None):
            raise RuntimeError(
                "Jamba Tool parser could not locate tool calls start/end "
                "tokens in the tokenizer!")

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called."""
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because jamba use the special
            # tokens to indicate the start and end of the tool calls
            # information.
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) response."""

        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:

            try:
                # use a regex to find the tool call between the tags
                function_calls = self.tool_calls_regex.findall(model_output)[0]

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = json.loads(function_calls)
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"])))
                    for function_call in raw_function_calls
                ]

                # Any text generated before the marker is normal content.
                content = model_output[:model_output.
                                       find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if
                    (len(content) > 0 and content != " ") else None)

            except Exception:
                # Malformed JSON between markers: fall back to plain content.
                logger.exception(
                    "Error in extracting tool call from response.")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally extract tool-call deltas during streaming.

        Returns a DeltaMessage with either plain content or tool-call
        fragments, or None when there is nothing to emit for this step.
        """

        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.tool_calls_start_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the start of tool calls token which means
        # the start of tool calling
        if (self.tool_calls_start_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            # if it's the only token, return None, so we don't send a chat
            # completion and don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:

            # Extract the tool calls between the special tool call tokens
            parsable_arr = current_text.split(
                self.tool_calls_start_token)[-1].split(
                    self.tool_calls_end_token)[0]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: List[Dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at

            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    diff: Union[str, None] = current_tool_call.get("arguments")

                    if diff:
                        # Flush only the not-yet-streamed suffix of the
                        # previous tool's arguments.
                        diff = json.dumps(diff).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=diff).model_dump(
                                                  exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:

                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                # NOTE(review): normalizes quotes so delta_text can be found
                # inside the json.dumps() form below — assumes the model never
                # emits a literal apostrophe inside argument values; confirm.
                new_text = delta_text.replace("\'", "\"")

                if not cur_arguments and not prev_arguments:

                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments)
                    logger.debug("finding %s in %s", new_text,
                                 cur_arguments_json)

                    # First argument fragment: everything up to and including
                    # the newly generated text.
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(new_text) +
                                                         len(new_text)]
                    logger.debug("First tokens in arguments received: %s",
                                 arguments_delta)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)
                    logger.debug("Searching for diff between \n%s\n%s",
                                 cur_args_json, prev_args_json)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff
                else:
                    # try parsing it with regular JSON - if it works we're
                    # at the end, and we need to send the difference between
                    # tokens streamed so far and the valid JSON
                    delta = None

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from json import JSONDecoder
6
+ from typing import Dict, List, Sequence, Union
7
+
8
+ import partial_json_parser
9
+ from partial_json_parser.core.options import Allow
10
+ from transformers import PreTrainedTokenizerBase
11
+
12
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
13
+ DeltaFunctionCall, DeltaMessage,
14
+ DeltaToolCall,
15
+ ExtractedToolCallInformation,
16
+ FunctionCall, ToolCall)
17
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
+ ToolParser, ToolParserManager)
19
+ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
20
+ is_complete_json,
21
+ partial_json_loads)
22
+ from vllm.logger import init_logger
23
+ from vllm.utils import random_uuid
24
+
25
+ logger = init_logger(__name__)
26
+
27
+
28
+ @ToolParserManager.register_module("llama3_json")
29
+ class Llama3JsonToolParser(ToolParser):
30
+ """
31
+ Tool call parser for Llama 3.1 models intended for use with the
32
+ examples/tool_chat_template_llama.jinja template.
33
+
34
+ Used when --enable-auto-tool-choice --tool-call-parser llama3_json
35
+ are all set
36
+ """
37
+
38
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Set up streaming state and Llama-specific markers.

        Raises IndexError if the tokenizer encodes `<|python_tag|>` to an
        empty sequence (token absent from the vocabulary).
        """
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        # Llama 3.1 emits this marker before tool-call JSON in some formats.
        self.bot_token = "<|python_tag|>"
        self.bot_token_id = tokenizer.encode(self.bot_token,
                                             add_special_tokens=False)[0]
        # Matches a JSON array of objects anywhere in the output.
        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
52
+
53
    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.
        """
        # case -- if a tool call token is not present, return a text response
        if not (model_output.startswith(self.bot_token)
                or model_output.startswith('{')):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # load the JSON, and then use it to build the Function and
            # Tool Call
            dec = JSONDecoder()
            function_call_arr = []

            # depending on the prompt format the Llama model may or may not
            # prefix the output with the <|python_tag|> token
            start_idx = len(self.bot_token) if model_output.startswith(
                self.bot_token) else 0
            while start_idx < len(model_output):
                # raw_decode tolerates trailing text after each JSON object;
                # calls are assumed to be separated by '; '.
                (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
                start_idx += end_idx + len('; ')
                function_call_arr.append(obj)

            tool_calls: List[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string;
                        # depending on the prompt Llama uses either
                        # "arguments" or "parameters" as the key
                        arguments=json.dumps(raw_function_call["arguments"] \
                                if "arguments" in raw_function_call \
                                else raw_function_call["parameters"])))
                for raw_function_call in function_call_arr
            ]

            # NOTE(review): content is always None on this path — any text
            # surrounding the tool-call JSON is dropped, not returned.
            ret = ExtractedToolCallInformation(tools_called=True,
                                               tool_calls=tool_calls,
                                               content=None)
            return ret

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
105
+
106
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool-call deltas while the model is still generating.

        Maintains parser state across invocations (``current_tool_id``,
        ``current_tool_name_sent``, ``streamed_args_for_tool``,
        ``prev_tool_call_arr``) and returns a ``DeltaMessage`` covering the
        newly arrived text, or ``None`` when nothing can be emitted yet.
        """

        # Not a tool call at all: pass the delta through as plain content.
        if not (current_text.startswith(self.bot_token)
                or current_text.startswith('{')):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = len(self.bot_token) if current_text.startswith(
                    self.bot_token) else 0
                while start_idx < len(current_text):
                    (obj,
                     end_idx) = partial_json_loads(current_text[start_idx:],
                                                   flags)
                    # Track, per tool call, whether its JSON is complete yet.
                    is_complete.append(
                        is_complete_json(current_text[start_idx:start_idx +
                                                      end_idx]))
                    start_idx += end_idx + len('; ')
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, \
                            "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        # flush whatever suffix has not been streamed yet
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # arguments JSON is final: stream everything left
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # only stream the stable common prefix so that
                            # auto-completed close-quotes/braces are not
                            # emitted prematurely
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from random import choices
6
+ from string import ascii_letters, digits
7
+ from typing import Dict, List, Sequence, Union
8
+
9
+ import partial_json_parser
10
+ from partial_json_parser.core.options import Allow
11
+ from pydantic import Field
12
+
13
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
14
+ DeltaFunctionCall, DeltaMessage,
15
+ DeltaToolCall,
16
+ ExtractedToolCallInformation,
17
+ FunctionCall, ToolCall)
18
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
19
+ ToolParser, ToolParserManager)
20
+ from vllm.entrypoints.openai.tool_parsers.utils import (
21
+ extract_intermediate_diff)
22
+ from vllm.logger import init_logger
23
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
24
+
25
+ logger = init_logger(__name__)
26
+
27
+ ALPHANUMERIC = ascii_letters + digits
28
+
29
+
30
class MistralToolCall(ToolCall):
    """A ToolCall whose id satisfies Mistral's 9-char alphanumeric format."""

    # Override the inherited id with a Mistral-compliant random default.
    id: str = Field(
        default_factory=lambda: MistralToolCall.generate_random_id())

    @staticmethod
    def generate_random_id() -> str:
        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
        # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
        return "".join(choices(ALPHANUMERIC, k=9))
39
+
40
+
41
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
    """
    Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
    examples/tool_chat_template_mistral.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up streaming state and locate the [TOOL_CALLS] token id."""
        super().__init__(tokenizer)

        if not isinstance(self.model_tokenizer, MistralTokenizer):
            logger.info("Non-Mistral tokenizer detected when using a Mistral "
                        "model...")

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        self.bot_token = "[TOOL_CALLS]"
        self.bot_token_id = self.vocab.get(self.bot_token)
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if self.bot_token_id is None:
            raise RuntimeError(
                "Mistral Tool Parser could not locate the tool call token in "
                "the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response. Requires
        find-and-replacing single quotes with double quotes for JSON parsing,
        make sure your tool call arguments don't ever include quotes!
        """

        # case -- if a tool call token is not present, return a text response
        if self.bot_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        # first remove the BOT token
        tool_content = model_output.replace(self.bot_token, "").strip()

        try:

            # we first try to directly load the json as parsing very nested
            # jsons is difficult
            try:
                function_call_arr = json.loads(tool_content)
            except json.JSONDecodeError:
                # use a regex to find the part corresponding to the tool call.
                # NOTE: This use case should not happen if the model is trained
                # correctly. It's a easy possible fix so it's included, but
                # can be brittle for very complex / highly nested tool calls
                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
                function_call_arr = json.loads(raw_tool_call)

            # Tool Call
            tool_calls: List[MistralToolCall] = [
                MistralToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"],
                                             ensure_ascii=False)))
                for raw_function_call in function_call_arr
            ]

            # get any content before the tool call
            content = model_output.split(self.bot_token)[0]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if len(content) > 0 else None)

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=tool_content)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool-call deltas during streaming generation.

        Keeps per-request state on the instance and returns a DeltaMessage
        for the newly arrived text, or None when nothing can be emitted yet.
        """

        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.bot_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the BOT token which means the start of tool
        # calling
        if (self.bot_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            # if it's the only token, return None, so we don't send a chat
            # completion any don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:

            # replace BOT token with empty string, and convert single quotes
            # to double to allow parsing as JSON since mistral uses single
            # quotes instead of double for tool calls
            parsable_arr = current_text.split(self.bot_token)[-1]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: List[Dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at

            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    diff: Union[str, None] = current_tool_call.get("arguments")

                    if diff:
                        # NOTE(review): str.replace removes the streamed
                        # prefix wherever it occurs, not only at the start --
                        # confirm this cannot clip repeated substrings.
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=diff).model_dump(
                                                  exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=MistralToolCall.generate_random_id(),
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:

                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                new_text = delta_text.replace("\'", "\"")
                if ('"}' in new_text):
                    # drop a trailing close-quote/brace so it is not streamed
                    # before the arguments are actually finished
                    new_text = new_text[:new_text.rindex('"}')]

                if not cur_arguments and not prev_arguments:

                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    # first chunk of arguments: strip the trailing '"}' pair
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)[:-2]
                    logger.debug("finding %s in %s", new_text,
                                 cur_arguments_json)

                    if (new_text not in cur_arguments_json):
                        return None
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         rindex(new_text) +
                                                         len(new_text)]
                    logger.debug("First tokens in arguments received: %s",
                                 arguments_delta)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments,
                                               ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments,
                                                ensure_ascii=False)
                    logger.debug("Searching for diff between \n%s\n%s",
                                 cur_args_json, prev_args_json)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff
                else:
                    # try parsing it with regular JSON - if it works we're
                    # at the end, and we need to send the difference between
                    # tokens streamed so far and the valid JSON
                    delta = None

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
.venv/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ from json import JSONDecodeError, JSONDecoder
5
+ from typing import Any, List, Tuple
6
+
7
+ import partial_json_parser
8
+ from partial_json_parser.core.options import Allow
9
+
10
+
11
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest prefix shared by ``s1`` and ``s2`` (possibly empty).
    Argument order does not matter.

    This is a UTILITY for working with partial_json_parser output during
    streaming: emitting only the stable common prefix prevents close-quotes,
    close-brackets and close-braces from being streamed prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    limit = min(len(s1), len(s2))
    idx = 0
    while idx < limit and s1[idx] == s2[idx]:
        idx += 1
    return s1[:idx]
32
+
33
+
34
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the common suffix of ``s1`` and ``s2``, if any. Argument order
    does not matter.

    The scan stops as soon as the characters differ OR an alphanumeric
    character is reached, so only structural characters (quotes, braces,
    brackets, punctuation) are ever part of the result.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    limit = min(len(s1), len(s2))
    matched = 0
    while matched < limit:
        ch = s1[-(matched + 1)]
        if ch != s2[-(matched + 1)] or ch.isalnum():
            break
        matched += 1
    return s1[len(s1) - matched:] if matched else ''
50
+
51
+
52
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings known to share a common prefix and/or suffix, extract
    the new material in the middle.

    This is a UTILITY for working with partial_json_parser output during
    streaming: it yields the tokens that should be sent to the client while
    keeping auto-completed close-quotes/brackets/braces out of the stream.
    Argument order IS important: ``curr`` must be the newer partially-parsed
    JSON and ``old`` the one from the previous generation step.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'
    """
    shared_suffix = find_common_suffix(curr, old)

    # Strip the last occurrence of the suffix from the old string before
    # computing the prefix (reverse, drop first occurrence, reverse back).
    trimmed_old = old[::-1].replace(shared_suffix[::-1], '', 1)[::-1]
    shared_prefix = find_common_prefix(curr, trimmed_old)

    diff = curr
    if len(shared_suffix):
        diff = diff[::-1].replace(shared_suffix[::-1], '', 1)[::-1]
    if len(shared_prefix):
        # replace the prefix only once in case it's mirrored
        diff = diff.replace(shared_prefix, '', 1)
    return diff
83
+
84
+
85
def find_all_indices(string: str, substring: str) -> List[int]:
    """
    Return every starting index of ``substring`` in ``string`` (overlapping
    matches included). Useful for tool call extraction.
    """
    positions: List[int] = []
    pos = string.find(substring)
    while pos != -1:
        positions.append(pos)
        pos = string.find(substring, pos + 1)
    return positions
98
+
99
+
100
# partial_json_parser doesn't support extra data, and
# JSONDecoder.raw_decode doesn't support partial JSON -- so combine the two.
def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """
    Parse ``input_str`` as (possibly partial) JSON.

    Returns ``(obj, consumed)`` where ``consumed`` is the number of
    characters that were used to produce ``obj``. When the string contains
    trailing data after a complete JSON value, fall back to
    ``JSONDecoder.raw_decode``, which parses only the first value.
    """
    try:
        obj = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            # complete first value followed by more text
            return JSONDecoder().raw_decode(input_str)
        raise
    return (obj, len(input_str))
110
+
111
+
112
def is_complete_json(input_str: str) -> bool:
    """Return True iff ``input_str`` parses as a complete JSON document."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    return True
118
+
119
+
120
def consume_space(i: int, s: str) -> int:
    """Return the index of the first non-whitespace char of ``s`` at or after ``i``."""
    idx = i
    end = len(s)
    while idx < end and s[idx].isspace():
        idx += 1
    return idx
.venv/lib/python3.11/site-packages/vllm/lora/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/fully_sharded_layers.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/layers.cpython-311.pyc ADDED
Binary file (59.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/lora.cpython-311.pyc ADDED
Binary file (9.57 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/models.cpython-311.pyc ADDED
Binary file (39.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/peft_helper.cpython-311.pyc ADDED
Binary file (6.91 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/request.cpython-311.pyc ADDED
Binary file (4.46 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.49 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/__pycache__/worker_manager.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/fully_sharded_layers.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # pylint: disable=unused-argument
4
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import PretrainedConfig
9
+
10
+ from vllm.config import LoRAConfig
11
+ from vllm.distributed.communication_op import (
12
+ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
13
+ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
14
+ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
15
+ MergedColumnParallelLinearWithLoRA,
16
+ MergedQKVParallelLinearWithLora,
17
+ QKVParallelLinearWithLora,
18
+ RowParallelLinearWithLoRA)
19
+
20
+ if TYPE_CHECKING:
21
+ pass
22
+
23
+
24
+ def _fully_sharded_can_replace(can_replace):
25
+ """
26
+ decorator which adds the condition of fully sharded loras
27
+ intended to wrap can_replace_layer()
28
+ """
29
+
30
+ def dec(*args, **kwargs):
31
+ return (can_replace(*args, **kwargs)
32
+ and kwargs["lora_config"].fully_sharded_loras)
33
+
34
+ return dec
35
+
36
+
37
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
    """
    For `ColumnParallelLinearWithLoRA` or classes that inherit from
    `ColumnParallelLinearWithLoRA`, they share the same `apply` logic.

    Applies the quantized base layer, then adds the fully-sharded LoRA
    contribution: shrink via lora_a, all-gather the partial buffers across
    tensor-parallel ranks, and expand via lora_b into the output.
    """
    # every slice must have matching lora_a / lora_b / output metadata
    assert (layer.n_slices == len(layer.lora_a_stacked) == len(
        layer.lora_b_stacked) == len(layer.output_slices))
    if layer.lora_bias_stacked is not None:
        assert layer.n_slices == len(layer.lora_bias_stacked)

    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    # flatten all leading dims so the punica kernels see 2-D tensors;
    # the original output shape is restored at the end
    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    buffers = torch.zeros(
        (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )

    layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
    buffers = tensor_model_parallel_all_gather(buffers)
    layer.punica_wrapper.add_expand(output,
                                    buffers,
                                    layer.lora_b_stacked,
                                    layer.lora_bias_stacked,
                                    layer.output_slices,
                                    offset_start=0,
                                    add_input=True)

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
73
+
74
+
75
+ # these layers are based on the tensor parallelism strategy given in
76
+ # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
77
+ # https://arxiv.org/abs/2311.03285.
78
+
79
+
80
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM , a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this TP rank's shard of lora_a along the rank dim."""
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # shared shrink/gather/expand path for column-parallel LoRA layers
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        # NOTE(review): decorate=False presumably skips the base class's own
        # decoration so only this decorator applies -- confirm in the base.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
121
+
122
+
123
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        #NOTE: lora_a contains 2 subloras, and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[0] is not None else None,
            lora_a[1][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[1] is not None else None,
        ]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # shared shrink/gather/expand path for column-parallel LoRA layers
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
168
+
169
+
170
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this TP rank's shard of lora_a along the rank dim."""
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # shared shrink/gather/expand path for column-parallel LoRA layers
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
203
+
204
+
205
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        # One shard size / start index per q, k and v sub-lora.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] +
                      shard_size[0]] if lora_a[0] is not None else None,
            lora_a[1][:, start_idx[1]:start_idx[1] +
                      shard_size[1]] if lora_a[1] is not None else None,
            lora_a[2][:, start_idx[2]:start_idx[2] +
                      shard_size[2]] if lora_a[2] is not None else None,
        ]
        return lora_a

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # shared shrink/gather/expand path for column-parallel LoRA layers
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
251
+
252
+
253
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Each rank keeps one contiguous chunk of the LoRA B output dim,
        # sized by the (already sharded) stacked-B buffer.
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # The LoRA bias is sliced exactly like LoRA B's output dim.
        if bias is None:
            return bias
        self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                      self.lora_bias_stacked)
        shard_size = self.lora_bias_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Base-layer matmul; with row parallelism each rank produces a
        # partial sum (no reduction is done here — see note below).
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        # Flatten leading dims for the punica kernels; remember the
        # original shape to restore at the end.
        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        # fp32 accumulator for the shrink (x @ A^T) result.
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.lora_bias_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
.venv/lib/python3.11/site-packages/vllm/lora/layers.py ADDED
@@ -0,0 +1,1206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # pylint: disable=unused-argument
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import PretrainedConfig
12
+
13
+ from vllm.adapter_commons.layers import AdapterMapping
14
+ from vllm.config import LoRAConfig
15
+ from vllm.distributed import (get_tensor_model_parallel_rank,
16
+ get_tensor_model_parallel_world_size,
17
+ split_tensor_along_last_dim,
18
+ tensor_model_parallel_all_gather,
19
+ tensor_model_parallel_all_reduce,
20
+ tensor_model_parallel_gather)
21
+ from vllm.distributed.utils import divide
22
+ # yapf: disable
23
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
24
+ LinearBase,
25
+ MergedColumnParallelLinear,
26
+ QKVParallelLinear,
27
+ ReplicatedLinear,
28
+ RowParallelLinear)
29
+ # yapf: enable
30
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
+ from vllm.model_executor.layers.rotary_embedding import (
32
+ LinearScalingRotaryEmbedding, RotaryEmbedding)
33
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
34
+ VocabParallelEmbedding)
35
+ from vllm.platforms import current_platform
36
+
37
+ if TYPE_CHECKING:
38
+ from vllm.lora.punica_wrapper import PunicaWrapperBase
39
+
40
+
41
+ def _get_lora_device(base_layer: nn.Module) -> torch.device:
42
+ # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
43
+ """Returns the device for where to place the LoRA tensors."""
44
+ # unquantizedLinear
45
+ if hasattr(base_layer, "weight"):
46
+ return base_layer.weight.device
47
+ # Compressed Tensor
48
+ elif hasattr(base_layer, "weight_packed"):
49
+ return base_layer.weight_packed.device
50
+ # GPTQ/AWQ
51
+ elif hasattr(base_layer, "qweight"):
52
+ return base_layer.qweight.device
53
+ # marlin
54
+ elif hasattr(base_layer, "B"):
55
+ return base_layer.B.device
56
+ # HQQ marlin
57
+ elif hasattr(base_layer, "W_q"):
58
+ return base_layer.W_q.device
59
+ else:
60
+ raise ValueError(f"Unsupported base layer: {base_layer}")
61
+
62
+
63
+ def _not_fully_sharded_can_replace(can_replace):
64
+ """
65
+ decorator which adds the condition of not using fully sharded loras
66
+ intended to wrap can_replace_layer()
67
+ """
68
+
69
+ def dec(*args, **kwargs):
70
+ decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
71
+ condition = (not kwargs["lora_config"].fully_sharded_loras
72
+ if decorate else True)
73
+ return can_replace(*args, **kwargs) and condition
74
+
75
+ return dec
76
+
77
+
78
@dataclass
class LoRAMapping(AdapterMapping):
    # Whether this mapping was built for a prefill batch.
    is_prefill: bool = False
81
+
82
+
83
class BaseLayerWithLoRA(nn.Module):
    """Interface shared by all layers that can carry LoRA adapters.

    Subclasses wrap a base layer and add stacked per-adapter weights;
    the methods below are stubs that concrete layers implement.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        # Stores the punica wrapper used to dispatch batched LoRA kernels.
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
137
+
138
+
139
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA on top of VocabParallelEmbedding, including support for the
    per-adapter extra-vocabulary embeddings appended after the original
    vocab rows."""

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # (start, end) range into the flattened per-adapter extra-embedding
        # table owned by this TP shard, or None when the shard holds none.
        self.embeddings_slice: Optional[Tuple[int, int]]
        # View into the base layer's added-vocab weight rows (or None).
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocates the stacked LoRA A/B and extra-embedding buffers."""
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            # Zero the added-vocab rows; adapter embeddings are copied
            # in later by set_lora().
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        # NOTE: unlike the linear LoRA layers, lora_a is stored
        # un-transposed as (vocab, rank) so it can be fed straight to
        # F.embedding in forward().
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # Flattened (max_loras * vocab, rank) view used as one big
        # embedding table in forward().
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zeroes all LoRA state stored in slot `index`."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Copies one adapter's weights into slot `index`."""
        self.reset_lora(index)
        # lora_a is copied as-is; lora_b is stored transposed.
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token ids beyond the original vocab belong to adapter-added
        # vocabulary.
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        # NOTE(review): x is modified in place here, so the caller's
        # input tensor is clobbered — presumably intentional; confirm.
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        # Flatten (batch, seq, dim) to (tokens, dim) for the punica
        # kernel, then restore the original shape on return.
        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding
277
+
278
+
279
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared implementation for LoRA on top of a LinearBase layer.

    Replicated / column-parallel / row-parallel variants subclass this
    and set tp_size, output_size and n_slices before weights are built.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Optional per-slice LoRA bias buffers; allocated only when
        # lora_config.bias_enabled is set in create_lora_weights().
        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None

        # Declared here, assigned by subclasses / create_lora_weights().
        self.output_slices: Tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocates the stacked LoRA A/B (and optional bias) buffers."""
        self.lora_config = lora_config
        # Buffer sizes depend on the base layer's parallelism style and
        # on whether fully-sharded LoRAs are enabled.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            # Fully sharded: LoRA A is split along the rank dim.
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            # Fully sharded: LoRA B is split along the output dim.
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zeroes all LoRA state in slot `index` for every slice."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copies one adapter's (sliced, transposed) weights into slot
        `index`."""
        # Except for QKVParallelLinearWithLora and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        # A and B are stored transposed relative to the incoming tensors.
        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:

            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias.T, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Base-layer forward plus the batched LoRA delta (in place on
        `output`)."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        return output
404
+
405
+
406
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA on top of ReplicatedLinear (weights are identical on every
    GPU, so no tensor-parallel slicing is needed)."""

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        skip_bias = self.base_layer.skip_bias_add
        bias = None if skip_bias else self.base_layer.bias

        # Matrix multiply (base layer + LoRA delta).
        output = self.apply(input_, bias)

        # When bias addition is skipped, hand the bias back to the caller.
        output_bias = self.base_layer.bias if skip_bias else None
        return output, output_bias

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear
448
+
449
+
450
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # LoRA A is replicated across ranks for this layer type.
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            # The merged weight holds two halves (e.g. gate and up);
            # each rank takes its shard from both halves and re-merges.
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2

            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            lora_b = torch.cat([left_weight, right_weight], dim=1)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # TODO: Fix the slicing logic of bias.
        if bias is None:
            return bias
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Also claims MergedColumnParallelLinear with a single packed
        # module (only one sublora present).
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
546
+
547
+
548
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
            self, base_layer: Union[MergedColumnParallelLinear,
                                    QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        self.n_slices = len(self.output_slices)
        # Shard id used per slice; every slice shares this rank's shard.
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # LoRA A is replicated across ranks for this layer type.
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: mutates the incoming list in place; missing (None)
        # subloras are left untouched.
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (lora_b_i := lora_b[i]) is not None:
                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
                                     (shard_id + 1)]
        return lora_b

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        # Same per-slice sharding as slice_lora_b, on 1-D bias vectors.
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (bias_i := bias[i]) is not None:
                bias[i] = bias_i[shard_size * shard_id:shard_size *
                                 (shard_id + 1)]
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copies each present sublora's weights into slot `index`."""
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    self.lora_bias_stacked[i][index,
                                              0, :lora_bias_i.shape[0]].copy_(
                                                  lora_bias_i.T,
                                                  non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
688
+
689
+
690
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contains a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # Total (unsharded) and per-rank output widths of the q and kv
        # projections, used to locate each slice inside the packed dim.
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Extract this rank's q/k/v columns from the packed LoRA B and
        re-pack them."""
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Extract this rank's q/k/v entries from the packed 1-D LoRA
        bias and re-pack them.

        NOTE: relies on self.q_shard_id / self.kv_shard_id having been
        set by slice_lora_b(), which set_lora() always calls first.
        """
        bias_q = bias[self.q_proj_shard_size *
                      self.q_shard_id:self.q_proj_shard_size *
                      (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        bias_k = bias[k_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        bias_v = bias[v_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        # BUGFIX: bias is 1-D (one value per output feature), so the
        # q/k/v parts must be concatenated along dim 0; the previous
        # dim=1 raised "Dimension out of range" on 1-D tensors.
        bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
        return bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
756
+
757
+
758
+ class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
759
+ """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
760
+ packed together in qkv proj fashion
761
+ (q_proj + k_proj + v_proj -> qkv_proj).
762
+
763
+ This means we have 3 LoRAs, each applied to one slice of the layer.
764
+
765
+ Q slice may have different shape than K and V slices (which both have
766
+ the same shape).
767
+ """
768
+
769
+ def __init__(self, base_layer: QKVParallelLinear) -> None:
770
+ super().__init__(base_layer)
771
+ # There are three LoRA layer.
772
+ self.n_slices = len(self.base_layer.output_sizes)
773
+ self.tp_size = get_tensor_model_parallel_world_size()
774
+ self.tp_rank = get_tensor_model_parallel_rank()
775
+
776
+ self.q_proj_shard_size = (self.base_layer.num_heads *
777
+ self.base_layer.head_size)
778
+ self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
779
+ self.base_layer.head_size)
780
+ self.q_shard_id = self.tp_rank
781
+ self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
782
+
783
+ self.output_slices = (
784
+ self.q_proj_shard_size,
785
+ self.kv_proj_shard_size,
786
+ self.kv_proj_shard_size,
787
+ )
788
+ self.output_ids = (
789
+ self.q_shard_id,
790
+ self.kv_shard_id,
791
+ self.kv_shard_id,
792
+ )
793
+
794
+ def create_lora_weights(
795
+ self,
796
+ max_loras: int,
797
+ lora_config: LoRAConfig,
798
+ model_config: Optional[PretrainedConfig] = None,
799
+ ) -> None:
800
+ """
801
+ The main reason for overloading this function is to handle inconsistent
802
+ weight dimensions in qkv lora.
803
+ """
804
+ super().create_lora_weights(max_loras, lora_config, model_config)
805
+
806
+ @classmethod
807
+ @_not_fully_sharded_can_replace
808
+ def can_replace_layer(
809
+ cls,
810
+ source_layer: nn.Module,
811
+ lora_config: LoRAConfig,
812
+ packed_modules_list: List,
813
+ model_config: Optional[PretrainedConfig],
814
+ ) -> bool:
815
+ return (type(source_layer) is QKVParallelLinear
816
+ and len(packed_modules_list) == 3)
817
+
818
+
819
+ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
820
+
821
+ def __init__(self, base_layer: RowParallelLinear) -> None:
822
+ super().__init__(base_layer)
823
+
824
+ self.tp_size = get_tensor_model_parallel_world_size()
825
+ # reset input_size
826
+ self.input_size = self.base_layer.input_size_per_partition
827
+ self.output_size = self.base_layer.output_size
828
+
829
+ self.tp_rank = get_tensor_model_parallel_rank()
830
+ # There is only one LoRA layer.
831
+ self.n_slices = 1
832
+
833
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
834
+
835
+ shard_size = self.input_size
836
+ start_idx = self.tp_rank * shard_size
837
+ end_idx = (self.tp_rank + 1) * shard_size
838
+ lora_a = lora_a[start_idx:end_idx, :]
839
+ return lora_a
840
+
841
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
842
+ return lora_b
843
+
844
+ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
845
+ return bias
846
+
847
+ def forward(
848
+ self, input_: torch.Tensor
849
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
850
+ """Forward of RowParallelLinear
851
+
852
+ Args:
853
+ input_: tensor whose last dimension is `input_size`. If
854
+ `input_is_parallel` is set, then the last dimension
855
+ is `input_size // tp_size`.
856
+
857
+ Returns:
858
+ - output
859
+ - bias
860
+ """
861
+ # Set up backprop all-reduce.
862
+ if self.base_layer.input_is_parallel:
863
+ input_parallel = input_
864
+ else:
865
+ # TODO: simplify code below
866
+ splitted_input = split_tensor_along_last_dim(
867
+ input_, num_partitions=self.base_layer.tp_size)
868
+ input_parallel = splitted_input[self.tp_rank].contiguous()
869
+
870
+ # Matrix multiply.
871
+ output_parallel = self.apply(input_parallel)
872
+ if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
873
+ output_ = tensor_model_parallel_all_reduce(output_parallel)
874
+ else:
875
+ output_ = output_parallel
876
+
877
+ if not self.base_layer.skip_bias_add:
878
+ output = (output_ + self.base_layer.bias
879
+ if self.base_layer.bias is not None else output_)
880
+ output_bias = None
881
+ else:
882
+ output = output_
883
+ output_bias = self.base_layer.bias
884
+ return output, output_bias
885
+
886
+ @property
887
+ def weight(self):
888
+ return (self.base_layer.weight if hasattr(self.base_layer, "weight")
889
+ else self.base_layer.qweight)
890
+
891
+ @classmethod
892
+ @_not_fully_sharded_can_replace
893
+ def can_replace_layer(
894
+ cls,
895
+ source_layer: nn.Module,
896
+ lora_config: LoRAConfig,
897
+ packed_modules_list: List,
898
+ model_config: Optional[PretrainedConfig],
899
+ ) -> bool:
900
+ return type(source_layer) is RowParallelLinear
901
+
902
+
903
+ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
904
+ """
905
+ LoRA wrapper for LogitsProcessor, with extra logic to handle the
906
+ application of the LoRA adapter and added LoRA vocabulary.
907
+
908
+ Args:
909
+ base_layer: LogitsProcessor layer
910
+ hidden_size: hidden size of the model
911
+ dtype: data type of the model
912
+ device: device of the model
913
+ sharded_to_full_mapping: index mapping from sharded vocab to full vocab
914
+ received from base_layer.get_sharded_to_full_mapping(). If None,
915
+ no reindexing will be done.
916
+ """
917
+
918
+ def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
919
+ dtype: torch.dtype, device: torch.device,
920
+ sharded_to_full_mapping: Optional[List[int]]) -> None:
921
+ super().__init__()
922
+ self.base_layer = base_layer
923
+ self.hidden_size = hidden_size
924
+ self.dtype = dtype
925
+ self.device = device
926
+ self.tp_size = get_tensor_model_parallel_world_size()
927
+ self.tp_rank = get_tensor_model_parallel_rank()
928
+ self.sharded_to_full_mapping = sharded_to_full_mapping
929
+
930
+ @property
931
+ def logits_as_input(self):
932
+ return self.base_layer.logits_as_input
933
+
934
+ @property
935
+ def vocab_size(self):
936
+ return self.base_layer.vocab_size
937
+
938
+ @property
939
+ def scale(self):
940
+ return self.base_layer.scale
941
+
942
+ @property
943
+ def soft_cap(self):
944
+ return self.base_layer.soft_cap
945
+
946
+ @property
947
+ def use_all_gather(self):
948
+ return self.base_layer.use_all_gather
949
+
950
+ @property
951
+ def org_vocab_size(self):
952
+ return self.base_layer.org_vocab_size
953
+
954
+ @property
955
+ def include_gpu_probs_tensor(self):
956
+ return self.base_layer.include_gpu_probs_tensor
957
+
958
+ @property
959
+ def should_modify_greedy_probs_inplace(self):
960
+ return self.base_layer.should_modify_greedy_probs_inplace
961
+
962
+ def create_lora_weights(
963
+ self,
964
+ max_loras: int,
965
+ lora_config: LoRAConfig,
966
+ model_config: Optional[PretrainedConfig] = None,
967
+ ) -> None:
968
+ # TODO: Verify if this condition can be further relaxed
969
+ if 32000 < self.base_layer.vocab_size > 257024:
970
+ raise ValueError("When using LoRA, vocab size must be "
971
+ "32000 >= vocab_size <= 257024")
972
+ self.lora_a_stacked = torch.zeros(
973
+ (
974
+ max_loras,
975
+ 1,
976
+ lora_config.max_lora_rank,
977
+ self.hidden_size,
978
+ ),
979
+ dtype=lora_config.lora_dtype,
980
+ device=self.device,
981
+ )
982
+ self.lora_b_stacked = torch.zeros(
983
+ (
984
+ max_loras,
985
+ 1,
986
+ # Pad for kernel compatibility
987
+ math.ceil(self.base_layer.vocab_size /
988
+ lora_config.lora_vocab_padding_size) *
989
+ lora_config.lora_vocab_padding_size,
990
+ lora_config.max_lora_rank,
991
+ ),
992
+ dtype=lora_config.lora_dtype,
993
+ device=self.device,
994
+ )
995
+ self.embeddings_tensors = torch.full(
996
+ (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
997
+ fill_value=float("-inf"),
998
+ dtype=self.dtype,
999
+ device=self.device,
1000
+ )
1001
+ if self.sharded_to_full_mapping is not None:
1002
+ self.sharded_to_full_mapping_gpu = torch.tensor(
1003
+ self.sharded_to_full_mapping,
1004
+ device=self.device,
1005
+ dtype=torch.long)
1006
+ else:
1007
+ self.sharded_to_full_mapping_gpu = None
1008
+
1009
+ def reset_lora(self, index: int):
1010
+ self.lora_a_stacked[index] = 0
1011
+ self.lora_b_stacked[index] = 0
1012
+ self.embeddings_tensors[index] = float("-inf")
1013
+
1014
+ def set_lora(
1015
+ self,
1016
+ index: int,
1017
+ lora_a: torch.Tensor,
1018
+ lora_b: torch.Tensor,
1019
+ embeddings_tensor: Optional[torch.Tensor],
1020
+ bias: Optional[torch.Tensor] = None,
1021
+ ):
1022
+ self.reset_lora(index)
1023
+ self.lora_a_stacked[index,
1024
+ 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
1025
+ lora_a.T, non_blocking=True)
1026
+ self.lora_b_stacked[index,
1027
+ 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
1028
+ lora_b.T, non_blocking=True)
1029
+ if embeddings_tensor is not None:
1030
+ self.embeddings_tensors[
1031
+ index,
1032
+ :embeddings_tensor.shape[0],
1033
+ :embeddings_tensor.shape[1],
1034
+ ] = embeddings_tensor
1035
+
1036
+ def _get_logits(
1037
+ self,
1038
+ hidden_states: torch.Tensor,
1039
+ lm_head: VocabParallelEmbedding,
1040
+ embedding_bias: Optional[torch.Tensor] = None,
1041
+ ) -> Optional[torch.Tensor]:
1042
+ # Get the logits for the next tokens.
1043
+ logits = lm_head.linear_method.apply(lm_head, hidden_states)
1044
+ if embedding_bias is not None:
1045
+ logits += embedding_bias
1046
+ logits = tensor_model_parallel_gather(logits)
1047
+ if logits is None:
1048
+ return None
1049
+
1050
+ if self.sharded_to_full_mapping_gpu is not None:
1051
+ # Reindex full logits tensor to ensure 1:1 mapping between
1052
+ # index and token_id
1053
+ # Example for:
1054
+ # org_vocab_size = 4
1055
+ # added_vocab_size = 2
1056
+ # pad_to_size = 8
1057
+ # tp_size = 2
1058
+
1059
+ # indices: [0, 1, 2, 3, 4, 5, 6, 7]
1060
+ # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
1061
+
1062
+ # Therefore, the mapping is expected to be:
1063
+ # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
1064
+ # we get:
1065
+ # indices: [0, 1, 2, 3, 4, 5, 6, 7]
1066
+ # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
1067
+ logits = logits[:, self.sharded_to_full_mapping_gpu]
1068
+
1069
+ lora_logits = torch.empty(
1070
+ self.embeddings_tensors.shape[0] + 1,
1071
+ self.embeddings_tensors.shape[1],
1072
+ hidden_states.shape[0],
1073
+ dtype=self.embeddings_tensors.dtype,
1074
+ device=self.embeddings_tensors.device,
1075
+ )
1076
+ torch.matmul(self.embeddings_tensors,
1077
+ hidden_states.T,
1078
+ out=lora_logits[:-1])
1079
+ lora_logits[-1] = float("-inf")
1080
+ lora_logits = lora_logits.mT
1081
+ indices_padded = self.punica_wrapper.sampler_indices_padded
1082
+ lora_logits = (lora_logits.reshape(
1083
+ lora_logits.shape[0] * lora_logits.shape[1],
1084
+ lora_logits.shape[2],
1085
+ ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
1086
+ posinf=float("inf"),
1087
+ neginf=float("-inf")))
1088
+
1089
+ # HPU needs special handling to prune out dummy samples.
1090
+ if current_platform.is_hpu():
1091
+ lora_logits = lora_logits[:logits.shape[0], :]
1092
+
1093
+ logits[:,
1094
+ self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
1095
+ lora_logits.shape[1]] = lora_logits
1096
+
1097
+ # LogitsProcessorWithLoRA always using bgmv
1098
+ self.punica_wrapper.add_lora_logits(logits, hidden_states,
1099
+ self.lora_a_stacked,
1100
+ self.lora_b_stacked, 1.0)
1101
+
1102
+ # Remove paddings in vocab (if any).
1103
+ logits = logits[:, :self.base_layer.vocab_size]
1104
+ return logits
1105
+
1106
+ def forward(self, *args, **kwargs):
1107
+ return type(self.base_layer).forward(self, *args, **kwargs)
1108
+
1109
+ @classmethod
1110
+ def can_replace_layer(
1111
+ cls,
1112
+ source_layer: nn.Module,
1113
+ lora_config: LoRAConfig,
1114
+ packed_modules_list: List,
1115
+ model_config: Optional[PretrainedConfig],
1116
+ ) -> bool:
1117
+ # Special handling for the LogitsProcessor.
1118
+ return False
1119
+
1120
+
1121
+ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
1122
+ """Implements RoPE-scaled embeddings with linear scaling for
1123
+ multiple LoRA adapters with a specialized kernel.
1124
+
1125
+ Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
1126
+ which can handle multi lora adapters in a specialied kernel.
1127
+ """
1128
+
1129
+ def __init__(self, base_layer: RotaryEmbedding) -> None:
1130
+ super().__init__()
1131
+ self.base_layer = base_layer
1132
+
1133
+ @property
1134
+ def scaling_factors(self):
1135
+ return self.base_layer.scaling_factors
1136
+
1137
+ @property
1138
+ def rotary_dim(self):
1139
+ return self.base_layer.rotary_dim
1140
+
1141
+ def create_lora_weights(
1142
+ self,
1143
+ max_loras: int,
1144
+ lora_config: LoRAConfig,
1145
+ model_config: Optional[PretrainedConfig] = None,
1146
+ ) -> None:
1147
+ scaling_factors = (list(lora_config.long_lora_scaling_factors)
1148
+ if lora_config.long_lora_scaling_factors else [])
1149
+ base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
1150
+ self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
1151
+ scaling_factors = sorted(
1152
+ list(set([base_scaling_factor] + scaling_factors)))
1153
+ self.base_layer = LinearScalingRotaryEmbedding(
1154
+ self.base_layer.head_size,
1155
+ self.base_layer.rotary_dim,
1156
+ self.base_layer.max_position_embeddings,
1157
+ self.base_layer.base,
1158
+ self.base_layer.is_neox_style,
1159
+ scaling_factors,
1160
+ self.base_layer.dtype,
1161
+ )
1162
+
1163
+ def reset_lora(self, index: int):
1164
+ ...
1165
+
1166
+ def set_lora(
1167
+ self,
1168
+ index: int,
1169
+ lora_a: torch.Tensor,
1170
+ lora_b: torch.Tensor,
1171
+ embeddings_tensor: Optional[torch.Tensor],
1172
+ bias: Optional[torch.Tensor] = None,
1173
+ ):
1174
+ ...
1175
+
1176
+ def forward(
1177
+ self,
1178
+ positions: torch.Tensor,
1179
+ query: torch.Tensor,
1180
+ key: torch.Tensor,
1181
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1182
+ return self.base_layer(
1183
+ positions,
1184
+ query,
1185
+ key,
1186
+ offsets=self.punica_wrapper.long_lora_indices,
1187
+ )
1188
+
1189
+ @property
1190
+ def scaling_factor_to_offset(self) -> Dict[float, int]:
1191
+ return self.base_layer.scaling_factor_to_offset
1192
+
1193
+ @classmethod
1194
+ def can_replace_layer(
1195
+ cls,
1196
+ source_layer: nn.Module,
1197
+ lora_config: LoRAConfig,
1198
+ packed_modules_list: List,
1199
+ model_config: Optional[PretrainedConfig],
1200
+ ) -> bool:
1201
+ """Returns True if the layer can be replaced by this LoRA layer."""
1202
+ return (type(source_layer) is LinearScalingRotaryEmbedding
1203
+ or type(source_layer) is RotaryEmbedding)
1204
+
1205
+ def extra_repr(self) -> str:
1206
+ return self.base_layer.extra_repr()
.venv/lib/python3.11/site-packages/vllm/lora/lora.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import List, Optional
4
+ from typing import Sequence as GenericSequence
5
+
6
+ import torch
7
+ import torch.types
8
+
9
+ from vllm.lora.peft_helper import PEFTHelper
10
+ from vllm.utils import is_pin_memory_available
11
+
12
+
13
+ class LoRALayerWeights:
14
+ """LoRA weights for a layer composed of two low rank matrixes."""
15
+
16
+ def __init__(
17
+ self,
18
+ module_name: str,
19
+ rank: int,
20
+ lora_alpha: int,
21
+ lora_a: torch.Tensor,
22
+ lora_b: torch.Tensor,
23
+ bias: Optional[torch.Tensor] = None,
24
+ embeddings_tensor: Optional[torch.Tensor] = None,
25
+ scaling: Optional[float] = None,
26
+ ) -> None:
27
+ self.module_name = module_name
28
+ self.rank = rank
29
+ self.lora_alpha = lora_alpha
30
+ self.lora_a = lora_a
31
+ self.lora_b = lora_b
32
+ self.bias = bias
33
+ self.embeddings_tensor = embeddings_tensor
34
+
35
+ if scaling is None:
36
+ self.scaling = self.lora_alpha / self.rank
37
+ else:
38
+ self.scaling = scaling
39
+
40
+ def optimize(self) -> "LoRALayerWeights":
41
+ """Optimize the LoRA by merging the scaling into lora_b."""
42
+ if self.scaling == 1:
43
+ return self
44
+ self.lora_b *= self.scaling
45
+ self.scaling = 1
46
+ return self
47
+
48
+ @property
49
+ def input_dim(self) -> int:
50
+ return self.lora_a.shape[0]
51
+
52
+ @property
53
+ def output_dim(self) -> int:
54
+ return self.lora_b.shape[1]
55
+
56
+ @property
57
+ def is_packed(self) -> bool:
58
+ return False
59
+
60
+ @property
61
+ def extra_vocab_size(self) -> int:
62
+ return self.embeddings_tensor.shape[
63
+ 0] if self.embeddings_tensor is not None else 0
64
+
65
+ @classmethod
66
+ def from_config(
67
+ cls,
68
+ module_name: str,
69
+ peft_helper: PEFTHelper,
70
+ embeddings_tensor: Optional[torch.Tensor] = None,
71
+ ) -> "LoRALayerWeights":
72
+ return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
73
+ None, None, embeddings_tensor,
74
+ peft_helper.vllm_lora_scaling_factor)
75
+
76
+ @classmethod
77
+ def create_dummy_lora_weights(
78
+ cls,
79
+ module_name: str,
80
+ input_dim: int,
81
+ output_dim: int,
82
+ rank: int,
83
+ dtype: torch.dtype,
84
+ device: torch.types.Device,
85
+ embeddings_tensor_dim: Optional[int] = None,
86
+ bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
87
+ pin_memory = str(device) == "cpu" and is_pin_memory_available()
88
+ lora_a = torch.zeros([input_dim, rank],
89
+ dtype=dtype,
90
+ device=device,
91
+ pin_memory=pin_memory)
92
+ lora_b = torch.zeros([rank, output_dim],
93
+ dtype=dtype,
94
+ device=device,
95
+ pin_memory=pin_memory)
96
+ if bias_enabled:
97
+ bias = torch.zeros([output_dim],
98
+ dtype=dtype,
99
+ device=device,
100
+ pin_memory=pin_memory)
101
+ else:
102
+ bias = None
103
+
104
+ embeddings_tensor = torch.rand(
105
+ 10,
106
+ embeddings_tensor_dim,
107
+ dtype=dtype,
108
+ device=device,
109
+ pin_memory=pin_memory) if embeddings_tensor_dim else None
110
+ return cls(
111
+ module_name,
112
+ rank=rank,
113
+ lora_alpha=1,
114
+ lora_a=lora_a,
115
+ lora_b=lora_b,
116
+ bias=bias,
117
+ embeddings_tensor=embeddings_tensor,
118
+ )
119
+
120
+
121
+ class PackedLoRALayerWeights(LoRALayerWeights):
122
+ """LoRA used for packed layers (eg. qkv_proj)."""
123
+
124
+ def __init__(
125
+ self,
126
+ module_name: str,
127
+ rank: int,
128
+ lora_alphas: List[Optional[int]],
129
+ lora_a: List[Optional[torch.Tensor]],
130
+ lora_b: List[Optional[torch.Tensor]],
131
+ bias: Optional[List[Optional[torch.Tensor]]] = None,
132
+ scaling: Optional[List[float]] = None,
133
+ ) -> None:
134
+ super().__init__(
135
+ module_name=module_name,
136
+ rank=rank,
137
+ lora_alpha=0,
138
+ lora_a=lora_a,
139
+ lora_b=lora_b,
140
+ bias=bias,
141
+ scaling=scaling, # type: ignore
142
+ embeddings_tensor=None,
143
+ )
144
+ self.lora_alphas = lora_alphas
145
+ if scaling is None:
146
+ self.scaling = [ # type: ignore
147
+ lora_alpha / self.rank # type: ignore # noqa
148
+ for lora_alpha in self.lora_alphas
149
+ ]
150
+
151
+ @classmethod
152
+ def pack(
153
+ cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
154
+ ) -> "PackedLoRALayerWeights":
155
+ """Pack a list of LoRAs into a single LoRA.
156
+
157
+ If LoRA is None, it signifies that the submodule does not have a LoRA.
158
+ """
159
+ first_lora = next(lora for lora in loras if lora is not None)
160
+ for lora in loras:
161
+ if lora is None:
162
+ continue
163
+ lora.optimize()
164
+ rank = first_lora.rank
165
+ module_name = first_lora.module_name
166
+ obj = cls(
167
+ module_name,
168
+ rank,
169
+ [lora.lora_alpha if lora is not None else None for lora in loras],
170
+ [lora.lora_a if lora is not None else None for lora in loras],
171
+ [lora.lora_b if lora is not None else None for lora in loras],
172
+ [lora.bias if lora is not None else None for lora in loras],
173
+ scaling=[
174
+ 1 if lora is not None else None # type: ignore
175
+ for lora in loras
176
+ ])
177
+ return obj
178
+
179
+ def optimize(self) -> "PackedLoRALayerWeights":
180
+ """Optimize the LoRA by merging the scaling into lora_b."""
181
+ for i in range(len(self.lora_b)):
182
+ if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
183
+ continue
184
+ self.lora_b[i] *= self.scaling[i] # type: ignore
185
+ self.scaling[i] = 1 # type: ignore
186
+ return self
187
+
188
+ @property
189
+ def input_dim(self) -> int:
190
+ raise NotImplementedError()
191
+
192
+ @property
193
+ def output_dim(self) -> int:
194
+ raise NotImplementedError()
195
+
196
+ @property
197
+ def is_packed(self) -> bool:
198
+ return True
.venv/lib/python3.11/site-packages/vllm/lora/models.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import copy
4
+ import math
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
9
+
10
+ import safetensors.torch
11
+ import torch
12
+ from torch import nn
13
+
14
+ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
15
+ AdapterModelManager)
16
+ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
17
+ get_adapter, list_adapters,
18
+ remove_adapter, set_adapter_mapping)
19
+ from vllm.config import LoRAConfig
20
+ from vllm.logger import init_logger
21
+ from vllm.lora.layers import (BaseLayerWithLoRA,
22
+ LinearScalingRotaryEmbeddingWithLora,
23
+ LoRAMapping)
24
+ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
25
+ from vllm.lora.peft_helper import PEFTHelper
26
+ from vllm.lora.punica_wrapper import get_punica_wrapper
27
+ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
28
+ is_regex_target_modules,
29
+ parse_fine_tuned_lora_name, replace_submodule)
30
+ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
31
+ from vllm.model_executor.models.module_mapping import MultiModelKeys
32
+ from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
33
+ from vllm.utils import is_pin_memory_available
34
+
35
+ logger = init_logger(__name__)
36
+
37
+ _GLOBAL_LORA_ID = 0
38
+
39
+
40
+ @dataclass
41
+ class LongContextLoRAContext:
42
+ """Context for lora adapters that support long context."""
43
+ # The scaling factors to support long context lora fine tuned models.
44
+ scaling_factors: List[float]
45
+ # dimension to apply rotary embedding.
46
+ rot_dim: int
47
+ # offsets to the sin_cos_cache for each lora_id loaded.
48
+ # This value is dynamically modified.
49
+ offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
50
+
51
+
52
+ def get_lora_id():
53
+ global _GLOBAL_LORA_ID
54
+ _GLOBAL_LORA_ID += 1
55
+ return _GLOBAL_LORA_ID
56
+
57
+
58
+ class LoRAModel(AdapterModel):
59
+ """A LoRA fine-tuned model."""
60
+
61
+ def __init__(
62
+ self,
63
+ lora_model_id: int,
64
+ rank: int,
65
+ loras: Dict[str, LoRALayerWeights],
66
+ scaling_factor: Optional[float] = None,
67
+ ) -> None:
68
+ """
69
+ Args:
70
+ lora_model_id: The integer id for the lora model.
71
+ rank: lora rank.
72
+ loras: module name -> weights for lora-replaced layers.
73
+ scaling_factor: Scaling factor to support long context lora model.
74
+ None if the lora is not tuned for long context support.
75
+ """
76
+ self.id = lora_model_id
77
+ # Scaling factor for long context lora model. None if it is not
78
+ # fine tuned for the long context.
79
+ self.scaling_factor = scaling_factor
80
+ assert (
81
+ lora_model_id
82
+ > 0), f"a valid lora id should be greater than 0, got {self.id}"
83
+ self.rank = rank
84
+ self.loras: Dict[str, LoRALayerWeights] = loras
85
+
86
+ def clone(self, lora_model_id: int) -> "LoRAModel":
87
+ """Return a copy of the object with different ids.
88
+
89
+ Will share the underlying tensors."""
90
+ return self.__class__(
91
+ lora_model_id,
92
+ rank=self.rank,
93
+ loras=self.loras.copy(),
94
+ )
95
+
96
+ @property
97
+ def extra_vocab_size(self) -> int:
98
+ return max(lora.extra_vocab_size
99
+ for lora in self.loras.values()) if self.loras else 0
100
+
101
+ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
102
+ """Get LoRA for a given module by name"""
103
+ return self.loras.get(module_name, None)
104
+
105
+ # (yard1): TODO see if we can derive target_embedding_padding automatically
106
+ @classmethod
107
+ def from_lora_tensors(
108
+ cls,
109
+ lora_model_id: int,
110
+ tensors: Dict[str, torch.Tensor],
111
+ peft_helper: PEFTHelper,
112
+ device: str = "cuda",
113
+ dtype: Optional[torch.dtype] = None,
114
+ embeddings: Optional[Dict[str, torch.Tensor]] = None,
115
+ target_embedding_padding: Optional[int] = None,
116
+ embedding_modules: Optional[Dict[str, str]] = None,
117
+ embedding_padding_modules: Optional[List[str]] = None,
118
+ weights_mapper: Optional[WeightsMapper] = None,
119
+ ) -> "LoRAModel":
120
+ """Create a LoRAModel from a dictionary of tensors."""
121
+ pin_memory = str(device) == "cpu" and is_pin_memory_available()
122
+ loras: Dict[str, LoRALayerWeights] = {}
123
+ for tensor_name, tensor in tensors.items():
124
+ module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
125
+ tensor_name, weights_mapper)
126
+ if module_name not in loras:
127
+ lora_embeddings_tensor = None
128
+ if embeddings:
129
+ assert embedding_modules is not None
130
+ embeddings_module = next(
131
+ (k for k in embedding_modules if k in module_name),
132
+ None)
133
+ if embeddings_module:
134
+ lora_embeddings_tensor = embeddings[
135
+ embedding_modules[embeddings_module]].to(
136
+ device=device, dtype=dtype)
137
+ if pin_memory:
138
+ lora_embeddings_tensor = (
139
+ lora_embeddings_tensor.pin_memory())
140
+ loras[module_name] = LoRALayerWeights.from_config(
141
+ module_name, peft_helper, lora_embeddings_tensor)
142
+
143
+ if is_bias:
144
+ loras[module_name].bias = tensor.to(device=device,
145
+ dtype=dtype).t()
146
+ bias = tensor.to(device=device, dtype=dtype).t()
147
+ if pin_memory:
148
+ bias = bias.pin_memory()
149
+ loras[module_name].bias = bias
150
+ elif is_lora_a:
151
+ loras[module_name].lora_a = tensor.to(device=device,
152
+ dtype=dtype).t()
153
+ if pin_memory:
154
+ loras[module_name].lora_a = loras[
155
+ module_name].lora_a.pin_memory()
156
+ else:
157
+ loras[module_name].lora_b = tensor.to(device=device,
158
+ dtype=dtype).t()
159
+ assert embedding_padding_modules is not None
160
+ if any(name in module_name
161
+ for name in embedding_padding_modules
162
+ ) and target_embedding_padding is not None:
163
+ lora_b = loras[module_name].lora_b
164
+ assert target_embedding_padding >= lora_b.shape[1]
165
+ addition = target_embedding_padding - lora_b.shape[1]
166
+ loras[module_name].lora_b = torch.nn.functional.pad(
167
+ lora_b, (0, addition))
168
+ if pin_memory:
169
+ loras[module_name].lora_b = loras[
170
+ module_name].lora_b.pin_memory()
171
+
172
+ for lora in loras.values():
173
+ lora.optimize()
174
+
175
+ return cls(lora_model_id,
176
+ peft_helper.r,
177
+ loras,
178
+ scaling_factor=peft_helper.vllm_long_context_scaling_factor)
179
+
180
+ @classmethod
181
+ def from_local_checkpoint(
182
+ cls,
183
+ lora_dir: str,
184
+ expected_lora_modules: List[str],
185
+ peft_helper: PEFTHelper,
186
+ *,
187
+ lora_model_id: Optional[int] = None,
188
+ device: str = "cuda",
189
+ dtype: Optional[torch.dtype] = None,
190
+ target_embedding_padding: Optional[int] = None,
191
+ embedding_modules: Optional[Dict[str, str]] = None,
192
+ embedding_padding_modules: Optional[List[str]] = None,
193
+ weights_mapper: Optional[WeightsMapper] = None,
194
+ ) -> "LoRAModel":
195
+ """Create a LoRAModel from a local checkpoint.
196
+
197
+ Args:
198
+ lora_dir: The local path that has lora data.
199
+ expected_lora_modules: Name of modules that are expected to be
200
+ replaced by lora.
201
+ peft_helper: Loaded lora configuration information.
202
+ lora_model_id: Lora model id. If not given, automatically set by
203
+ a global counter.
204
+ device: Device where the lora model is loaded.
205
+ dtype: dtype of the lora model weights.
206
+
207
+ Returns:
208
+ Loaded LoRA Model.
209
+ """
210
+ lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
211
+ lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
212
+ new_embeddings_tensor_path = os.path.join(
213
+ lora_dir, "new_embeddings.safetensors")
214
+ new_embeddings_bin_file_path = os.path.join(lora_dir,
215
+ "new_embeddings.bin")
216
+
217
+ unexpected_modules: List[Union[list[str], str]]
218
+ if os.path.isfile(lora_tensor_path):
219
+ tensors: Dict[str, torch.Tensor] = {}
220
+ # Find unexpected modules.
221
+ # Use safetensor key as a source of truth to find expected modules.
222
+ # in peft if you have target_modules A, B, C and C does not exist
223
+ # in the model it won’t error and model will be trained with A, B
224
+ # loraified. C won’t exist in the safetensor but it will exist in
225
+ # the target_modules of the adapter_config.json.
226
+ unexpected_modules = []
227
+ with safetensors.safe_open(lora_tensor_path,
228
+ framework="pt") as f: # type: ignore
229
+ for lora_module in f.keys(): # noqa
230
+ module_name, _, _ = parse_fine_tuned_lora_name(
231
+ lora_module, weights_mapper)
232
+ part_name = module_name.split(".")[-1]
233
+ if part_name not in expected_lora_modules:
234
+ unexpected_modules.append(module_name)
235
+ if unexpected_modules:
236
+ raise ValueError(
237
+ f"While loading {lora_dir}, expected"
238
+ f" target modules in {expected_lora_modules}"
239
+ f" but received {unexpected_modules}."
240
+ f" Please verify that the loaded LoRA module is correct"
241
+ )
242
+ # Load tensors if there are only expected modules.
243
+ for module in f.keys(): # noqa
244
+ tensors[module] = f.get_tensor(module)
245
+ elif os.path.isfile(lora_bin_file_path):
246
+ # When a bin file is provided, we rely on config to find unexpected
247
+ # modules.
248
+ unexpected_modules = []
249
+ target_modules = peft_helper.target_modules
250
+ if not isinstance(target_modules, list):
251
+ target_modules = [target_modules]
252
+ for module in target_modules:
253
+ # Compatible with more modules,
254
+ # such as:layers.11.self_attn.k_proj
255
+ part_name = module.split(".")[-1]
256
+ if part_name not in expected_lora_modules:
257
+ unexpected_modules.append(module)
258
+ # loaded lora's target modules must be a subset of
259
+ # expected_lora_modules. It is not reliable. See
260
+ # https://github.com/vllm-project/vllm/pull/5909. But there's no
261
+ # other better mechanism.
262
+ if unexpected_modules and not is_regex_target_modules(
263
+ peft_helper.target_modules, expected_lora_modules):
264
+ raise ValueError(
265
+ f"While loading {lora_dir}, expected"
266
+ f" target modules in {expected_lora_modules}"
267
+ f" but received {unexpected_modules}."
268
+ f" Please verify that the loaded LoRA module is correct")
269
+ tensors = torch.load(lora_bin_file_path, map_location=device)
270
+ else:
271
+ raise ValueError(f"{lora_dir} doesn't contain tensors")
272
+
273
+ embeddings = None
274
+ if os.path.isfile(new_embeddings_tensor_path):
275
+ embeddings = safetensors.torch.load_file(
276
+ new_embeddings_tensor_path)
277
+ elif os.path.isfile(new_embeddings_bin_file_path):
278
+ embeddings = torch.load(new_embeddings_bin_file_path,
279
+ map_location=device,
280
+ weights_only=True)
281
+
282
+ return cls.from_lora_tensors(
283
+ lora_model_id=get_lora_id()
284
+ if lora_model_id is None else lora_model_id,
285
+ tensors=tensors,
286
+ peft_helper=peft_helper,
287
+ device=device,
288
+ dtype=dtype,
289
+ embeddings=embeddings,
290
+ target_embedding_padding=target_embedding_padding,
291
+ embedding_modules=embedding_modules,
292
+ embedding_padding_modules=embedding_padding_modules,
293
+ weights_mapper=weights_mapper)
294
+
295
+
296
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models.

    Owns the mapping between registered LoRAs (CPU side) and the fixed
    number of GPU "slots" that can be active in a forward pass, and
    rewrites the wrapped model's submodules into LoRA-capable layers.
    """

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device the punica wrapper buffers live on.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        # The registered-adapter capacity must cover every GPU slot.
        assert self.capacity >= self.lora_slots
        # Round up to a multiple of 8 — presumably a kernel alignment
        # requirement; TODO confirm against the punica kernels.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # GPU slot index -> active LoRA int id (None == free slot).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
                                                 max_batches=self.max_num_seqs,
                                                 device=self.device)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        """Maximum number of LoRAs that may be registered (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of LoRAs that can be active at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots`` for the generic adapter interface."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_adapters:
            # Already active — nothing to do.
            return False
        # NOTE: the generator's loop variable shadows the ``lora_id``
        # argument, but generator-expression scope keeps the outer name
        # intact; only the slot index from the pair is used below.
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        # Push this LoRA's weights into slot ``index`` of every wrapped
        # module; modules the adapter does not cover get the slot cleared.
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        # Free the GPU slot occupied by ``lora_id``; a missing id is a no-op.
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos-cache offset for a long-context LoRA, if any."""
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        """Register ``lora``, merging its packed modules in place first."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager."
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every matching submodule with its LoRA-capable wrapper."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))

            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Track a LoRA-wrapped module so adapters can be (de)activated on it."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            bias_enabled = self.lora_config.bias_enabled
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    # Embedding-like layers: dimensions come from the base
                    # layer's vocab/embedding sizes when available, otherwise
                    # from the raw weight shape.
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                # Packed module (e.g. fused qkv): build one dummy sub-LoRA per
                # replacement and pack them into a single entry.
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` is (a suffix match of) a supported module."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                [module_name.startswith(prefix) for prefix in prefix_lst])
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the expanded sub-module names of a packed (fused) module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack the per-sub-module LoRAs of fused modules into single entries."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            # Normalize missing sub-loras to None before packing.
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate the LoRA, freeing its GPU slot if it held one."""
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register a LoRAModel; returns whether it was newly added."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` (change detection is delegated to the helper)."""
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister the LoRA."""
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        """Return a mapping of registered LoRA id -> LoRAModel."""
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Return the registered LoRAModel for ``adapter_id``, if any."""
        return get_adapter(adapter_id, self._registered_adapters)
663
+
664
+
665
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of LoRAModels; evictions invoke a deactivation callback."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        # ``deactivate_lora_fn`` is called with the evicted LoRA's int id.
        super().__init__(capacity, deactivate_lora_fn)
670
+
671
+
672
class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Replace the base class's plain dicts with LRU caches whose eviction
        # hooks deactivate the evicted LoRA.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the LRU active LoRA if slots are full."""
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered LoRA, if any."""
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        # Mark the LoRA as non-evictable among registered adapters.
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        # Mark the LoRA as non-evictable among active adapters.
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
741
+
742
+
743
def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter manager wrapping ``model``.

    Args:
        model: the model to adapt; must declare ``supported_lora_modules``.
        max_num_seqs: maximum sequences per batch.
        max_num_batched_tokens: maximum tokens per batch.
        vocab_size: vocabulary size of the model.
        lora_config: the LoRA configuration.
        device: device for the manager's buffers.
        lora_manager_cls: manager class to instantiate.

    Raises:
        ValueError: if the model does not support LoRA.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(model=model,
                            max_num_seqs=max_num_seqs,
                            max_num_batched_tokens=max_num_batched_tokens,
                            vocab_size=vocab_size,
                            lora_config=lora_config,
                            device=device,
                            **kwargs)
.venv/lib/python3.11/site-packages/vllm/lora/ops/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/lora/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (186 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
4
+ from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
5
+ sgmv_expand, sgmv_expand_slice,
6
+ sgmv_shrink)
7
+
8
+ __all__ = [
9
+ "bgmv_expand",
10
+ "bgmv_expand_slice",
11
+ "bgmv_shrink",
12
+ "sgmv_expand",
13
+ "sgmv_expand_slice",
14
+ "sgmv_shrink",
15
+ ]
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (546 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/__pycache__/lora_ops.cpython-311.pyc ADDED
Binary file (5.25 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/torch_ops/lora_ops.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+
5
+
6
def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    """Sequence-grouped "expand" built on the batched reference kernel.

    Repeats each sequence's LoRA index once per token of that sequence
    (via ``seq_len_tensor``) and delegates to ``bgmv_expand``. The extra
    bookkeeping arguments mirror the Triton kernel signature and are not
    used by this pure-PyTorch fallback.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand(inputs, lora_b_weights, output_tensor, per_token_indices,
                add_inputs)
21
+
22
+
23
def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """Batched "expand" LoRA reference kernel in pure PyTorch.

    For each row ``b``, gathers weight matrix
    ``lora_b_weights[lora_indices_tensor[b]]`` and computes
    ``W @ inputs[b]``, then either accumulates into (``add_inputs=True``)
    or overwrites the leading columns of ``output_tensor``.
    """
    target_dtype = output_tensor.dtype
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    # Weights may carry a singleton slice dim: (b, 1, o, i) -> (b, o, i).
    if gathered.dim() == 4:
        gathered = gathered.squeeze(dim=1)
    partial = torch.einsum("bi, boi -> bo",
                           inputs.to(dtype=target_dtype), gathered)

    # A single computed row broadcasting into a larger output only
    # contributes that one row.
    if partial.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]

    window = output_tensor[:, :partial.shape[1]]
    if add_inputs:
        window += partial[:row_limit, :]
    else:
        window[:] = partial[:row_limit, :]
43
+
44
+
45
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """Sequence-grouped "shrink" built on the batched reference kernel.

    Expands per-sequence LoRA indices to per-token indices (one repeat per
    token, driven by ``seq_len_tensor``) and delegates to ``bgmv_shrink``.
    The remaining bookkeeping arguments mirror the Triton kernel signature
    and are unused in this pure-PyTorch fallback.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_shrink(inputs, lora_a_weights, output_tensor, per_token_indices,
                scaling)
62
+
63
+
64
def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """Batched "shrink" LoRA reference kernel in pure PyTorch.

    For each row ``b``, gathers the weight matrix selected by
    ``lora_indices_tensor[b]``, projects ``inputs[b]`` through it, scales
    the result by ``scaling`` and writes it into the leading columns of
    ``output_tensor``.
    """
    target_dtype = output_tensor.dtype
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if gathered.dim() == 4:
        # Drop singleton slice dim: (b, 1, r, i) -> (b, r, i).
        gathered = gathered.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo",
                             inputs.to(dtype=target_dtype), gathered)
    output_tensor[:, :projected.shape[1]] = scaling * projected
77
+
78
+
79
def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    """Sequence-grouped sliced "expand" built on the batched kernel.

    Repeats each sequence's LoRA index per token (via ``seq_len_tensor``)
    and delegates to ``bgmv_expand_slice``, which writes into the column
    window ``[slice_offset, slice_offset + slice_size)``. Extra bookkeeping
    arguments mirror the Triton kernel signature and are unused here.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand_slice(inputs, lora_b_weights, output_tensor,
                      per_token_indices, slice_offset, slice_size, add_inputs)
96
+
97
+
98
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Batched sliced "expand" LoRA reference kernel in pure PyTorch.

    Like ``bgmv_expand`` but writes the per-row results into the column
    window ``output_tensor[:, slice_offset:slice_offset + slice_size]``,
    accumulating when ``add_inputs`` is True.
    """
    target_dtype = output_tensor.dtype
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if gathered.dim() == 4:
        # Drop singleton slice dim: (b, 1, o, i) -> (b, o, i).
        gathered = gathered.squeeze(dim=1)
    partial = torch.einsum("bi, boi -> bo",
                           inputs.to(dtype=target_dtype), gathered)

    window = output_tensor[:, slice_offset:slice_offset + slice_size]
    if add_inputs:
        window += partial
    else:
        window[:] = partial
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
4
+ from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
5
+ from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
6
+ from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
7
+ from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
8
+
9
+ __all__ = [
10
+ "bgmv_expand",
11
+ "bgmv_expand_slice",
12
+ "bgmv_shrink",
13
+ "sgmv_expand",
14
+ "sgmv_shrink",
15
+ ]
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (708 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand.cpython-311.pyc ADDED
Binary file (7.21 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_expand_slice.cpython-311.pyc ADDED
Binary file (7.45 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/lora/ops/triton_ops/__pycache__/bgmv_shrink.cpython-311.pyc ADDED
Binary file (6.52 kB). View file