koichi12 commited on
Commit
e6a64f6
·
verified ·
1 Parent(s): c27e68a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/custom_op.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/parameter.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/pooling_metadata.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/sampling_metadata.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/utils.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__init__.py +141 -0
  8. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/utils.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/xgrammar_decoding.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py +229 -0
  16. .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/utils.py +237 -0
  17. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__init__.py +20 -0
  18. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/__init__.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/loader.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/openvino.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/utils.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py +1441 -0
  25. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/neuron.py +212 -0
  26. .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/utils.py +162 -0
  27. .venv/lib/python3.11/site-packages/vllm/model_executor/models/arctic.py +582 -0
  28. .venv/lib/python3.11/site-packages/vllm/model_executor/models/bart.py +1000 -0
  29. .venv/lib/python3.11/site-packages/vllm/model_executor/models/bert.py +534 -0
  30. .venv/lib/python3.11/site-packages/vllm/model_executor/models/blip2.py +736 -0
  31. .venv/lib/python3.11/site-packages/vllm/model_executor/models/bloom.py +385 -0
  32. .venv/lib/python3.11/site-packages/vllm/model_executor/models/chameleon.py +1161 -0
  33. .venv/lib/python3.11/site-packages/vllm/model_executor/models/chatglm.py +801 -0
  34. .venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek.py +503 -0
  35. .venv/lib/python3.11/site-packages/vllm/model_executor/models/eagle.py +214 -0
  36. .venv/lib/python3.11/site-packages/vllm/model_executor/models/falcon.py +529 -0
  37. .venv/lib/python3.11/site-packages/vllm/model_executor/models/florence2.py +266 -0
  38. .venv/lib/python3.11/site-packages/vllm/model_executor/models/fuyu.py +399 -0
  39. .venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma.py +458 -0
  40. .venv/lib/python3.11/site-packages/vllm/model_executor/models/glm4_vision_encoder.py +312 -0
  41. .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt2.py +339 -0
  42. .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_bigcode.py +359 -0
  43. .venv/lib/python3.11/site-packages/vllm/model_executor/models/granitemoe.py +461 -0
  44. .venv/lib/python3.11/site-packages/vllm/model_executor/models/h2ovl.py +553 -0
  45. .venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics3.py +713 -0
  46. .venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2.py +495 -0
  47. .venv/lib/python3.11/site-packages/vllm/model_executor/models/internvl.py +962 -0
  48. .venv/lib/python3.11/site-packages/vllm/model_executor/models/jamba.py +632 -0
  49. .venv/lib/python3.11/site-packages/vllm/model_executor/models/llama.py +601 -0
  50. .venv/lib/python3.11/site-packages/vllm/model_executor/models/llava.py +845 -0
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (633 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/custom_op.cpython-311.pyc ADDED
Binary file (6.88 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/parameter.cpython-311.pyc ADDED
Binary file (21.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/pooling_metadata.cpython-311.pyc ADDED
Binary file (3.29 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/sampling_metadata.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.32 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__init__.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from vllm.logger import init_logger
8
+ from vllm.model_executor.guided_decoding.utils import (
9
+ convert_lark_to_gbnf, grammar_is_likely_lark,
10
+ has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
11
+ from vllm.platforms import CpuArchEnum
12
+
13
+ if TYPE_CHECKING:
14
+ from transformers import PreTrainedTokenizer
15
+
16
+ from vllm.config import ModelConfig
17
+ from vllm.logits_process import LogitsProcessor
18
+ from vllm.sampling_params import GuidedDecodingParams
19
+
20
+ logger = init_logger(__name__)
21
+
22
+
23
def maybe_backend_fallback(
        guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
    """Rewrite ``guided_params.backend`` when the chosen backend cannot
    handle the requested guidance, falling back to one that can.

    The object is mutated in place and also returned. Note the fallthrough:
    a request moved from lm-format-enforcer to xgrammar above is re-checked
    by the xgrammar block below, and may end up on outlines.
    """
    # lm-format-enforce doesn't support grammar, fallback to xgrammar
    if guided_params.backend == "lm-format-enforcer":
        if guided_params.grammar is not None:
            logger.warning(
                "lm-format-enforcer does not support grammar guided decoding. "
                "Falling back to use xgrammar instead.")
            guided_params.backend = "xgrammar"

        # lm-format-enforcer doesn't support some JSON schema features
        elif (guided_params.json is not None
              and has_lmf_unsupported_json_features(guided_params.json)):
            logger.warning(
                "lm-format-enforcer does not support advanced JSON schema "
                "features like patterns or numeric ranges. "
                "Falling back to use outlines instead.")
            guided_params.backend = "outlines"

    if guided_params.backend == "xgrammar":
        # xgrammar only has x86 wheels for linux, fallback to outlines
        from vllm.platforms import current_platform
        if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
            logger.warning("xgrammar is only supported on x86 CPUs. "
                           "Falling back to use outlines instead.")
            guided_params.backend = "outlines"

        # xgrammar doesn't support regex or choice, fallback to outlines
        if guided_params.regex is not None or guided_params.choice is not None:
            logger.warning(
                "xgrammar only supports json or grammar guided decoding. "
                "Falling back to use outlines instead.")
            guided_params.backend = "outlines"

        # xgrammar doesn't support some JSON schema features
        elif (guided_params.json is not None
              and has_xgrammar_unsupported_json_features(guided_params.json)):
            logger.warning(
                "xgrammar does not support advanced JSON schema features like "
                "patterns or numeric ranges. "
                "Falling back to use outlines instead.")
            guided_params.backend = "outlines"

        # xgrammar only supports GBNF grammars, so we must convert Lark.
        # We must check if the grammar is likely Lark and if that
        # grammar is convertible to GBNF
        elif (guided_params.grammar is not None
              and grammar_is_likely_lark(guided_params.grammar)):
            try:
                # Conversion result is discarded here; this is only a
                # feasibility probe. The xgrammar backend redoes it.
                convert_lark_to_gbnf(guided_params.grammar)
            except Exception:
                logger.warning(
                    "xgrammar does not support Lark grammars and the "
                    "grammar failed to convert to GBNF. "
                    "Falling back to use outlines instead.")
                guided_params.backend = "outlines"

    if (guided_params.backend == "outlines"
            and guided_params.json_object is not None):
        # outlines doesn't support json_object, fallback to xgrammar
        logger.warning("outlines does not support json_object. "
                       "Falling back to use xgrammar instead.")
        guided_params.backend = "xgrammar"

    return guided_params
88
+
89
+
90
async def get_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig) -> LogitsProcessor | None:
    """Build the logits processor for the configured guided-decoding backend.

    Args:
        guided_params: Guided decoding request; its ``backend`` may be
            rewritten by :func:`maybe_backend_fallback` before dispatch.
        tokenizer: The model's tokenizer.
        model_config: Model configuration (needed by the xgrammar backend).

    Returns:
        A backend-specific logits processor, or ``None`` when the backend
        produces none for this request.

    Raises:
        ValueError: If the backend name is not recognized.
    """
    guided_params = maybe_backend_fallback(guided_params)
    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'xgrammar':
        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config)

    # Fix: the message previously read "'outlines," with a missing
    # closing quote.
    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar'")
115
+
116
+
117
def get_local_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig) -> LogitsProcessor | None:
    """Synchronous counterpart of ``get_guided_decoding_logits_processor``.

    Args:
        guided_params: Guided decoding request; its ``backend`` may be
            rewritten by :func:`maybe_backend_fallback` before dispatch.
        tokenizer: The model's tokenizer.
        model_config: Model configuration (needed by the xgrammar backend).

    Returns:
        A backend-specific logits processor, or ``None`` when the backend
        produces none for this request.

    Raises:
        ValueError: If the backend name is not recognized.
    """
    guided_params = maybe_backend_fallback(guided_params)
    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_local_outlines_guided_decoding_logits_processor)
        return get_local_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'xgrammar':
        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config)

    # Fix: the message previously read "'outlines," with a missing
    # closing quote.
    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar'")
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (5.97 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-311.pyc ADDED
Binary file (3.57 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-311.pyc ADDED
Binary file (5.59 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/xgrammar_decoding.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Copyright 2024- the Outlines developers
4
+ # This file is adapted from
5
+ # https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ import copy
19
+ import json
20
+ from collections import defaultdict
21
+ from functools import lru_cache
22
+ from typing import Callable, DefaultDict, Dict, List, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+ from outlines import grammars
27
+ from outlines.caching import cache
28
+ from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
29
+ RegexGuide, Write)
30
+ from outlines.fsm.parsing import PartialLark
31
+ from outlines_core.fsm.json_schema import build_regex_from_schema
32
+ from pydantic import BaseModel
33
+ from transformers import PreTrainedTokenizerBase
34
+
35
+
36
class BaseLogitsProcessor:
    """Bias logits so only tokens allowed by an outlines ``Guide`` survive.

    Per-sequence FSM state is keyed by ``hash(tuple(input_ids))`` so the
    processor can be shared across sequences without explicit sequence ids.
    """

    def __init__(self, guide: Guide):
        # The outlines guide (RegexGuide or CFGGuide) that drives masking.
        self._guide: Guide = guide
        # CFGState is used for the FSM state for CFGGuide
        # defaultdict(int) means an unseen sequence starts at state 0.
        self._fsm_state: DefaultDict[int, Union[int,
                                                CFGState]] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) > 0:
            # Advance the FSM: look up the state stored for the prefix
            # (all tokens but the last) and step it by the last token.
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self._fsm_state[seq_id] = self._guide.get_next_state(
                state=self._fsm_state[last_seq_id], token_id=last_token)
        else:
            # Note: this is a hack.
            # Lark pickling does not work properly (silent failure),
            # which breaks the RPC (which uses python pickleing).
            # We need to find a better solution.
            # On the first time this is called, we simply re-create
            # the Lark object.
            if isinstance(self._guide, CFGGuide):
                self._guide.parser = PartialLark(
                    self._guide.cfg_string,
                    parser="lalr",
                    import_paths=[grammars.GRAMMAR_PATH],
                )
                self._fsm_state[seq_id] = CFGState(
                    parser_state=self._guide.parser.parse(""), prev_token=None)

        instruction = self._guide.get_next_instruction(
            state=self._fsm_state[seq_id])

        # Exact type checks (not isinstance): Generate/Write subclassing
        # is not expected here.
        if type(instruction) == Generate:  # noqa: E721
            allowed_tokens = instruction.tokens
        elif type(instruction) == Write:  # noqa: E721
            # TODO: support fast forward tokens
            allowed_tokens = [instruction.tokens[0]]
        else:
            raise TypeError(
                f"Unsupported instruction type {type(instruction)}")

        # Start from an all--inf mask, then zero out the allowed ids.
        mask = torch.full((scores.shape[-1], ),
                          -torch.inf,
                          device=scores.device)
        # The tokenizer may support more token ids than the model can generate,
        # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
        # but scores.shape == torch.Size([128256])
        # Using NumPy is faster for filtering token ids
        allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
        allowed_tokens = allowed_tokens.masked_select(
            allowed_tokens < scores.shape[-1])
        mask.index_fill_(0, allowed_tokens, 0)
        # In-place update: callers receive the same (mutated) tensor back.
        scores.add_(mask)
        return scores
96
+
97
+
98
class RegexLogitsProcessor(BaseLogitsProcessor):
    """Logits processor that constrains generation to match a regex."""

    @classmethod
    @cache()
    def _get_guide(cls, regex_string: str,
                   tokenizer: PreTrainedTokenizerBase) -> Guide:
        # Cached: compiling the guide is expensive and the result can be
        # reused across requests with the same (regex, tokenizer) pair.
        adapted = _adapt_tokenizer(tokenizer)
        return RegexGuide.from_regex(regex_string, adapted)

    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        guide = RegexLogitsProcessor._get_guide(regex_string, tokenizer)
        super().__init__(guide)
120
+
121
+
122
class JSONLogitsProcessor(RegexLogitsProcessor):
    """Logits processor that constrains generation to a JSON schema."""

    def __init__(self, schema: Union[str, Dict, BaseModel],
                 tokenizer: PreTrainedTokenizerBase,
                 whitespace_pattern: Union[str, None]):
        r"""Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to
            generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact
            string literals)
            Example: allow only a single space or newline with
            `whitespace_pattern=r"[\n ]?"`
        """
        # Normalize the schema into a JSON string before building the regex.
        if isinstance(schema, type(BaseModel)):
            # A pydantic model class: serialize its generated JSON schema.
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, Dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either "
                f"a Pydantic object, a dictionary or a string that contains "
                f"the JSON Schema specification")
        pattern = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(pattern, tokenizer)
155
+
156
+
157
class CFGLogitsProcessor(BaseLogitsProcessor):
    """Logits processor that constrains generation to a context-free grammar."""

    @classmethod
    @cache()
    def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
        # Cached: guide construction is expensive and reusable for the
        # same (grammar, tokenizer) pair.
        return CFGGuide(cfg, _adapt_tokenizer(tokenizer))

    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the context free grammar generation.

        Parameters
        ----------
        cfg
            A string that represents a context-free grammar
        tokenizer
            The model's tokenizer

        """
        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
        # The cached guide is shared between processors; take a copy so
        # per-sequence parser state does not leak across requests.
        self._guide = self._guide.copy()
178
+
179
+
180
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
    """Adapt vLLM's tokenizer to use to compile the FSM.

    The API of Outlines tokenizers is slightly different to that of
    `transformers`. The decoder of outlines, returns a list whereas
    the decode of vLLM returns an str. To sync the vLLM decoder with
    outlines internal api, the decoder should be adapted. In addition
    we need to handle the missing spaces to Llama's tokenizer to be
    able to compile FSMs for this model.

    Returns a deep copy of *tokenizer*, monkey-patched in place; the
    caller's tokenizer is never mutated.
    """
    # Already adapted (e.g. a cache miss on an adapted copy) -- return as-is.
    if getattr(tokenizer, "_outlines_adapted", False):
        return tokenizer

    # Work on a copy so the original tokenizer keeps its decode() signature.
    tokenizer = copy.deepcopy(tokenizer)

    # Attributes outlines expects on its tokenizer objects.
    tokenizer.vocabulary = tokenizer.get_vocab()
    tokenizer.special_tokens = set(tokenizer.all_special_tokens)

    def convert_token_to_string(token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        # ("<0x20>" is the byte-fallback encoding of an ASCII space).
        if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
                or token == "<0x20>"):
            return " " + string

        return string

    def change_decoder(
        decoder: Callable[[List[int]],
                          str]) -> Callable[[List[int]], List[str]]:
        """Sync vLLM's decoder with the outlines by returning list."""

        def new_decoder(inp_tokens: List[int]) -> List[str]:
            # Outlines may pass a singleton batch [[ids]]; unwrap it.
            if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
                    and isinstance(inp_tokens[0], list)):
                inp_tokens = inp_tokens[0]
            return [decoder(inp_tokens)]

        return new_decoder

    tokenizer.convert_token_to_string = convert_token_to_string
    tokenizer.decode = change_decoder(tokenizer.decode)
    # Marker so a second pass through this function is a no-op.
    setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010

    return tokenizer
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import re
4
+
5
+
6
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Unsupported features are string ``pattern`` constraints, numeric
    range/multiple constraints, and most array cardinality keywords.
    The schema is walked iteratively; nested dicts and dicts inside
    lists are all inspected.
    """
    numeric_keys = ("minimum", "maximum", "exclusiveMinimum",
                    "exclusiveMaximum", "multipleOf")
    array_keys = ("uniqueItems", "contains", "minContains", "maxContains",
                  "minItems", "maxItems")

    # Explicit stack instead of recursion; order of visits is irrelevant
    # because we only test for the *presence* of an unsupported feature.
    pending = [schema]
    while pending:
        node = pending.pop()
        if not isinstance(node, dict):
            continue

        # String pattern restrictions are unsupported.
        if "pattern" in node:
            return True

        node_type = node.get("type")
        # Numeric ranges / multiples are unsupported.
        if node_type in ("integer", "number") and any(key in node
                                                      for key in numeric_keys):
            return True
        # Array cardinality/uniqueness keywords are unsupported.
        if node_type == "array" and any(key in node for key in array_keys):
            return True

        # Queue nested schemas found in values and in lists of values.
        for value in node.values():
            if isinstance(value, dict):
                pending.append(value)
            elif isinstance(value, list):
                pending.extend(item for item in value
                               if isinstance(item, dict))

    return False
45
+
46
+
47
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
        "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },
    """
    # Iterative walk over every nested dict (including dicts inside
    # lists) looking for a "pattern" keyword anywhere in the schema.
    pending = [schema]
    while pending:
        node = pending.pop()
        if not isinstance(node, dict):
            continue

        if "pattern" in node:
            return True

        for value in node.values():
            if isinstance(value, dict):
                pending.append(value)
            elif isinstance(value, list):
                pending.extend(item for item in value
                               if isinstance(item, dict))

    return False
81
+
82
+
83
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not isinstance(grammar_str, str) or not grammar_str:
        return False

    # Strip both '#' and '//' comments, then look for the GBNF rule
    # operator '::=' on any non-empty line; its presence means GBNF.
    stripped = (re.sub(r'(#|//).*$', '', raw).strip()
                for raw in grammar_str.split('\n'))
    return not any('::=' in line for line in stripped if line)
113
+
114
+
115
def convert_lark_to_gbnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to GBNF format.

    GBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in GBNF format

    Raises:
        ValueError: On non-string input, empty grammar, malformed rules,
            mismatched quotes, or references to undefined rules.

    Examples:
        >>> print(convert_lark_to_gbnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        # Simple parity check; does not handle escaped quotes.
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                # Leading '?' marks a Lark inline rule; GBNF has no
                # equivalent, so it is stripped from the name.
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                # An explicit 'start' rule wins over the first rule seen.
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # Lark single-quoted terminals become GBNF double-quoted.
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            # Wrap with line context; inner messages may already carry it.
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from torch import nn
4
+
5
+ from vllm.config import VllmConfig
6
+ from vllm.model_executor.model_loader.loader import (BaseModelLoader,
7
+ get_model_loader)
8
+ from vllm.model_executor.model_loader.utils import (
9
+ get_architecture_class_name, get_model_architecture)
10
+
11
+
12
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
    """Instantiate and load the model described by *vllm_config*."""
    model_loader = get_model_loader(vllm_config.load_config)
    return model_loader.load_model(vllm_config=vllm_config)
15
+
16
+
17
+ __all__ = [
18
+ "get_model", "get_model_loader", "BaseModelLoader",
19
+ "get_architecture_class_name", "get_model_architecture"
20
+ ]
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.04 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/loader.cpython-311.pyc ADDED
Binary file (73.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/openvino.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-311.pyc ADDED
Binary file (24.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.49 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-311.pyc ADDED
Binary file (35.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py ADDED
@@ -0,0 +1,1441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # ruff: noqa: SIM117
4
+ import collections
5
+ import copy
6
+ import dataclasses
7
+ import fnmatch
8
+ import glob
9
+ import inspect
10
+ import itertools
11
+ import math
12
+ import os
13
+ import warnings
14
+ from abc import ABC, abstractmethod
15
+ from contextlib import contextmanager
16
+ from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
17
+ Tuple, cast)
18
+
19
+ import gguf
20
+ import huggingface_hub
21
+ import numpy as np
22
+ import torch
23
+ from huggingface_hub import HfApi
24
+ from torch import nn
25
+ from transformers import AutoModelForCausalLM
26
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
27
+
28
+ from vllm.attention import Attention
29
+ from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
30
+ VllmConfig, set_current_vllm_config)
31
+ from vllm.distributed import (get_tensor_model_parallel_rank,
32
+ get_tensor_model_parallel_world_size)
33
+ from vllm.envs import VLLM_USE_MODELSCOPE
34
+ from vllm.logger import init_logger
35
+ from vllm.model_executor.layers.linear import (LinearBase,
36
+ MergedColumnParallelLinear,
37
+ QKVParallelLinear,
38
+ ReplicatedLinear,
39
+ RowParallelLinear)
40
+ from vllm.model_executor.layers.quantization.base_config import (
41
+ QuantizeMethodBase)
42
+ from vllm.model_executor.model_loader.tensorizer import (
43
+ TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
44
+ serialize_vllm_model, tensorizer_weights_iterator)
45
+ from vllm.model_executor.model_loader.utils import (ParamMapping,
46
+ configure_quant_config,
47
+ get_model_architecture,
48
+ set_default_torch_dtype)
49
+ from vllm.model_executor.model_loader.weight_utils import (
50
+ download_safetensors_index_file_from_hf, download_weights_from_hf,
51
+ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
52
+ get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
53
+ initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
54
+ runai_safetensors_weights_iterator, safetensors_weights_iterator)
55
+ from vllm.model_executor.utils import set_weight_attrs
56
+ from vllm.platforms import current_platform
57
+ from vllm.transformers_utils.s3_utils import glob as s3_glob
58
+ from vllm.transformers_utils.utils import is_s3
59
+ from vllm.utils import is_pin_memory_available
60
+
61
+
62
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU parameters to `target_device`.

    Used when CPU offloading is enabled: weight-processing steps expect
    parameters on the target device, so CPU-resident parameters are moved
    there for the duration of the `with` block and moved back afterwards.
    Parameters already on the target device are left untouched.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Maps parameter name -> original device, for the parameters we moved.
    original_device_states: Dict[str, torch.device] = {}

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        # Restore parameters to their original devices, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                if original_device.type == "cpu":
                    # `torch.empty_like` does not support `pin_memory` argument
                    cpu_data = torch.empty_strided(
                        size=p.data.size(),
                        stride=p.data.stride(),
                        dtype=p.data.dtype,
                        layout=p.data.layout,
                        device="cpu",
                        pin_memory=pin_memory,
                    )
                    cpu_data.copy_(p.data)
                    p.data = cpu_data
                else:
                    p.data = p.data.to(original_device)
            # New parameters or parameters already on target device are untouched
103
+
104
+
105
+ logger = init_logger(__name__)
106
+
107
+
108
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
) -> nn.Module:
    """Initialize a model with the given configurations.

    New-style model classes receive the whole ``vllm_config``; old-style
    classes get a best-effort subset of config kwargs plus a deprecation
    warning.
    """
    model_config = vllm_config.model_config
    model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    ctor_signature = inspect.signature(model_class.__init__)
    accepted_params = [p.name for p in ctor_signature.parameters.values()]
    if "vllm_config" in accepted_params and "prefix" in accepted_params:
        # New-style model class: pass the full config object through.
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # Old-style model class: forward only the config pieces its
    # constructor declares, preserving the historical argument set.
    legacy_values = {
        "prefix": prefix,
        "config": model_config.hf_config,
        "cache_config": vllm_config.cache_config,
        "quant_config": vllm_config.quant_config,
        "lora_config": vllm_config.lora_config,
        "scheduler_config": vllm_config.scheduler_config,
    }
    kwargs = {
        name: value
        for name, value in legacy_values.items() if name in accepted_params
    }
    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**kwargs)
154
+
155
+
156
class BaseModelLoader(ABC):
    """Base class for model loaders.

    Concrete subclasses implement a specific loading strategy (local files,
    dummy weights, tensorizer, sharded state, ...).
    """

    def __init__(self, load_config: LoadConfig):
        # Loader-wide options: load format, download dir, extra config, etc.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError
171
+
172
+
173
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader takes no extra options; reject any that were passed.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=self.load_config.ignore_patterns,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns (weights folder, list of weight files, use_safetensors).
        """
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in order; the first pattern that matches any
        # file wins.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield weights from the primary checkpoint and from any
        `secondary_weights` sources the model declares on itself."""
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Download the checkpoint files without building the model."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model on the target device, stream in its weights,
        and run any post-load weight processing (quantization, MLA)."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self._get_all_weights(model_config, model))
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    # TODO(lucas): see if there is a way to unify the signatures
                    # of process_weights_after_loading
                    module.process_weights_after_loading(model_config.dtype)
        return model.eval()
415
+
416
+
417
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader takes no extra options; reject any that were passed.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and fill it with random weights (no checkpoint).

        Useful for benchmarking, where weight values don't matter."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # NOTE(woosuk): For accurate performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(model)

                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        # When quant methods need to process weights after loading
                        # (for repacking, quantizing, etc), they expect parameters
                        # to be on the global target device. This scope is for the
                        # case where cpu offloading is used, where we will move the
                        # parameters onto device for processing and back off after.
                        with device_loading_context(
                                module, torch.device(device_config.device)):
                            quant_method.process_weights_after_loading(module)
                    if isinstance(module, Attention) and \
                        hasattr(module, "process_weights_after_loading"):
                        # When attention modules need to process weights after
                        # currently only used by MLA
                        module.process_weights_after_loading(model_config.dtype)
        return model.eval()
456
+
457
+
458
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Accept either a ready-made TensorizerConfig or a plain dict of
        # its constructor kwargs.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        # Sanity-check the tensorizer config against the model and the
        # parallel setup before attempting to load.
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (name, tensor) pairs from the tensorizer stream."""
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models."""

        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]

                # Copy so we don't mutate the loader-level config.
                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening the stream is sufficient to fetch/validate the artifact.
        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model, choosing the vLLM-tensorized fast path when the
        artifact was serialized by vLLM, otherwise the CPU fallback path."""
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # The URI is a %-style template; substitute this worker's TP rank
            # so each rank reads its own shard.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized(vllm_config=vllm_config)
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize `model` with tensorizer per `tensorizer_config`."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
556
+
557
+
558
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    # Filename template; `rank` is the TP rank, `part` the shard part index.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # BUGFIX: report the keys that are actually unexpected. The
            # supported "pattern" key has already been popped from the copy,
            # so `extra_config` now holds exactly the unknown keys (the old
            # message printed the original mapping's keys, including
            # "pattern").
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group tensors by (device, storage pointer) so only true aliases
        # are compared against each other.
        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # One-past-the-end address of the tensor's data.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local checkpoint directory, downloading if necessary."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and load this TP rank's shard files into it.

        Raises ValueError if no shard files exist for this rank, or if any
        model parameter is missing from the loaded shards.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
                    if isinstance(module, Attention) and \
                        hasattr(module, "process_weights_after_loading"):
                        # When attention modules need to process weights after
                        # currently only used by MLA
                        module.process_weights_after_loading(
                            model_config.dtype)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s",
                                tensor.shape,
                                key,
                                param_shape,
                            )
                        param_data.copy_(tensor)
                        # Drop loaded keys so leftovers signal missing params.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this rank's (deduplicated) state dict to `path`, splitting
        into parts of at most `max_size` bytes when requested."""
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                # Current part is full: flush it and start a new one.
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
731
+
732
+
733
+ class BitsAndBytesModelLoader(BaseModelLoader):
734
+ """Model loader to load model weights with BitAndBytes quantization."""
735
+
736
+ possible_config_file_names = ["adapter_config.json"]
737
+
738
+ def __init__(self, load_config: LoadConfig):
739
+ super().__init__(load_config)
740
+
741
+ # Save the module names without sharding.
742
+ self.unsharded_weights_modules: List[str] = []
743
+ # Save the module names that are sharded by column.
744
+ self.column_sharded_weights_modules: List[str] = []
745
+ # Store all module names (from transformers) that support
746
+ # BNB quantization.
747
+ self.target_modules: List[str] = []
748
+ # mapping weight names from transformers to vllm.
749
+ self.weight_mapper: Callable = lambda name: name
750
+
751
+ def _get_weight_files(
752
+ self,
753
+ model_name_or_path: str,
754
+ allowed_patterns: List[str],
755
+ revision: Optional[str] = None,
756
+ ) -> Tuple[List[str], str]:
757
+ """Retrieve weight files. Download the files if necessary.
758
+
759
+ Return the weight files and the file pattern."""
760
+ is_local = os.path.isdir(model_name_or_path)
761
+
762
+ if is_local:
763
+ for pattern in allowed_patterns:
764
+ weight_files = glob.glob(
765
+ os.path.join(model_name_or_path, pattern))
766
+ if weight_files:
767
+ return weight_files, pattern
768
+ else:
769
+ hf_api = HfApi()
770
+ repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
771
+ for pattern in allowed_patterns:
772
+ matching_files = fnmatch.filter(repo_files, pattern)
773
+ if matching_files:
774
+ hf_folder = download_weights_from_hf(
775
+ model_name_or_path,
776
+ self.load_config.download_dir,
777
+ [pattern],
778
+ revision,
779
+ ignore_patterns=self.load_config.ignore_patterns,
780
+ )
781
+ return glob.glob(os.path.join(hf_folder, pattern)), pattern
782
+
783
+ raise RuntimeError(
784
+ f"No model weights found in: `{model_name_or_path}`")
785
+
786
+ def _prepare_weights(self, model_name_or_path: str,
787
+ revision: Optional[str]) -> Tuple[List[str], bool]:
788
+ """Prepare weight files for the model."""
789
+
790
+ allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
791
+
792
+ hf_weights_files, matched_pattern = self._get_weight_files(
793
+ model_name_or_path, allowed_patterns, revision)
794
+
795
+ if matched_pattern != "*.safetensors":
796
+ hf_weights_files = filter_files_not_needed_for_inference(
797
+ hf_weights_files)
798
+
799
+ if len(hf_weights_files) == 0:
800
+ raise RuntimeError(
801
+ f"Cannot find any model weights with `{model_name_or_path}`")
802
+
803
+ return hf_weights_files, matched_pattern == "*.safetensors"
804
+
805
+ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
806
+ if use_safetensors:
807
+ iterator = safetensors_weights_iterator(hf_weights_files)
808
+ else:
809
+ iterator = pt_weights_iterator(hf_weights_files)
810
+ for org_name, param in iterator:
811
+ # mapping weight names from transformers to vllm while preserving
812
+ # original names.
813
+ mapped_name = self.weight_mapper(org_name)
814
+ yield org_name, mapped_name, param
815
+
816
+ def _get_quantized_weights_iterator(
817
+ self,
818
+ model_name_or_path: str,
819
+ revision: Optional[str],
820
+ pre_quant: bool,
821
+ load_8bit: bool,
822
+ ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
823
+ Any]]:
824
+ """Get an iterator to the model weights with bitsandbytes quantization,
825
+ as well as the quantization state dictionary."""
826
+
827
+ # only load the bitsandbytes module when needed
828
+ try:
829
+ import bitsandbytes
830
+
831
+ if bitsandbytes.__version__ < "0.45.0":
832
+ raise ImportError("bitsandbytes version is wrong. Please "
833
+ "install bitsandbytes>=0.45.0.")
834
+ except ImportError as err:
835
+ raise ImportError("Please install bitsandbytes>=0.45.0 via "
836
+ "`pip install bitsandbytes>=0.45.0` to use "
837
+ "bitsandbytes quantizer.") from err
838
+
839
+ hf_weights_files, use_safetensors = self._prepare_weights(
840
+ model_name_or_path, revision)
841
+
842
+ quant_state_dict: Dict[str, Any] = {}
843
+
844
+ if pre_quant:
845
+ if load_8bit:
846
+ return self._quantized_8bit_generator(
847
+ hf_weights_files, use_safetensors,
848
+ quant_state_dict), quant_state_dict
849
+ else:
850
+ return self._quantized_4bit_generator(
851
+ hf_weights_files, use_safetensors,
852
+ quant_state_dict), quant_state_dict
853
+
854
+ return self._unquantized_generator(hf_weights_files, use_safetensors,
855
+ quant_state_dict), quant_state_dict
856
+
857
+ def _is_8bit_weight_name(self, weight_name: str):
858
+ quantized_suffix = {".scb", ".weight_format"}
859
+ return any(weight_name.lower().endswith(suffix)
860
+ for suffix in quantized_suffix)
861
+
862
+ def _is_4bit_weight_name(self, weight_name: str):
863
+ quantized_suffix = {
864
+ "absmax",
865
+ "quant_map",
866
+ "nested_absmax",
867
+ "nested_quant_map",
868
+ "bitsandbytes",
869
+ }
870
+ suffix = weight_name.split(".")[-1]
871
+ return any(q_suffix in suffix for q_suffix in quantized_suffix)
872
+
873
+ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
874
+ quant_state_dict) -> Generator:
875
+ for (
876
+ org_weight_name,
877
+ mapped_weight_name,
878
+ weight_tensor,
879
+ ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
880
+ if not mapped_weight_name.lower().endswith(".scb"):
881
+ continue
882
+
883
+ weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
884
+ quant_state_dict[weight_key] = weight_tensor
885
+
886
+ for (
887
+ org_weight_name,
888
+ mapped_weight_name,
889
+ weight_tensor,
890
+ ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
891
+ if self._is_8bit_weight_name(mapped_weight_name):
892
+ continue
893
+
894
+ if mapped_weight_name in quant_state_dict:
895
+ set_weight_attrs(weight_tensor, {"load_in_8bit": True})
896
+ yield org_weight_name, weight_tensor
897
+ else:
898
+ yield org_weight_name, weight_tensor
899
+
900
+ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
901
+ quant_state_dict) -> Generator:
902
+ from bitsandbytes.functional import QuantState
903
+
904
+ # First iterate over all quant state weights
905
+ weight_iterator = self._hf_weight_iter(hf_weights_files,
906
+ use_safetensors)
907
+ temp_state_dict = {}
908
+ for (
909
+ org_weight_name,
910
+ mapped_weight_name,
911
+ weight_tensor,
912
+ ) in weight_iterator:
913
+ if not self._is_4bit_weight_name(mapped_weight_name):
914
+ continue
915
+ # bitsandbytes library requires
916
+ # weight.quant_state.bitsandbytes__* in CPU
917
+ if "quant_state.bitsandbytes" in mapped_weight_name:
918
+ temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
919
+ else:
920
+ temp_state_dict[mapped_weight_name] = weight_tensor
921
+
922
+ # Closure to parse quant_state for each prequant weight
923
+ def _parse_quant_state(param_name: str,
924
+ temp_state_dict: Dict) -> QuantState:
925
+ quant_state = {}
926
+ for k in temp_state_dict:
927
+ if param_name + "." in k:
928
+ quant_state[k] = temp_state_dict[k]
929
+
930
+ return QuantState.from_dict(quant_state, device="cuda")
931
+
932
+ # Second iterate over all prequant and normal weights
933
+ # pre quantized weights would have a quant_state
934
+ for (
935
+ org_weight_name,
936
+ mapped_weight_name,
937
+ weight_tensor,
938
+ ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
939
+ if self._is_4bit_weight_name(mapped_weight_name):
940
+ continue
941
+
942
+ if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
943
+ in temp_state_dict) or (
944
+ f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
945
+ in temp_state_dict):
946
+ quant_state = _parse_quant_state(mapped_weight_name,
947
+ temp_state_dict)
948
+ quant_state_dict[mapped_weight_name] = quant_state
949
+ yield org_weight_name, weight_tensor
950
+ else:
951
+ yield org_weight_name, weight_tensor
952
+
953
+ def _unquantized_generator(self, hf_weights_files, use_safetensors,
954
+ quant_state_dict) -> Generator:
955
+ from bitsandbytes.functional import quantize_4bit
956
+
957
+ tp_size = get_tensor_model_parallel_world_size()
958
+ tp_rank = get_tensor_model_parallel_rank()
959
+
960
+ for (
961
+ org_weight_name,
962
+ mapped_weight_name,
963
+ weight_tensor,
964
+ ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
965
+ if any(target_module in mapped_weight_name
966
+ for target_module in self.target_modules
967
+ ) and mapped_weight_name.endswith(".weight"):
968
+ # Without sharding
969
+ if any(
970
+ mapped_weight_name.startswith(module)
971
+ for module in self.unsharded_weights_modules):
972
+ weight_sub_tensor = weight_tensor
973
+ # Shard by column
974
+ elif any(
975
+ mapped_weight_name.startswith(module)
976
+ for module in self.column_sharded_weights_modules):
977
+ total_size = weight_tensor.size(-1)
978
+ start_index = total_size // tp_size * tp_rank
979
+ end_index = total_size // tp_size * (tp_rank + 1)
980
+ weight_sub_tensor = weight_tensor[...,
981
+ start_index:end_index]
982
+ # Weights have fused on disk. In this case, we assume that the
983
+ # weight and module use same name.
984
+ elif any(
985
+ mapped_weight_name.startswith(module)
986
+ for module in self.maybe_fused_weights_modules):
987
+ # special case for fused weights
988
+ # get the size of each shard weight tensor
989
+ total_shard_sizes = next(
990
+ (sizes for module, sizes in
991
+ self.maybe_fused_weights_modules.items()
992
+ if mapped_weight_name.startswith(module)))
993
+ total_size = weight_tensor.size(0)
994
+ assert total_size == sum(total_shard_sizes)
995
+ # get the start/end index of each shard weight tensor
996
+ total_start_index = list(
997
+ itertools.accumulate([0] + total_shard_sizes))[:-1]
998
+ shard_weights_index = [(
999
+ idx + size // tp_size * tp_rank,
1000
+ idx + size // tp_size * (tp_rank + 1),
1001
+ ) for idx, size in zip(total_start_index,
1002
+ total_shard_sizes)]
1003
+ # slice and reorder the weight tensor
1004
+ weight_tensor = [
1005
+ weight_tensor[start_index:end_index, ...]
1006
+ for start_index, end_index in shard_weights_index
1007
+ ]
1008
+ weight_sub_tensor = torch.cat(weight_tensor, dim=0)
1009
+ # Shard by row
1010
+ else:
1011
+ total_size = weight_tensor.size(0)
1012
+ start_index = total_size // tp_size * tp_rank
1013
+ end_index = total_size // tp_size * (tp_rank + 1)
1014
+ weight_sub_tensor = weight_tensor[start_index:end_index,
1015
+ ...]
1016
+
1017
+ # bitsandbytes requires data in GPU
1018
+ if weight_sub_tensor.is_cuda:
1019
+ loaded_weight = weight_sub_tensor
1020
+ else:
1021
+ loaded_weight = weight_sub_tensor.cuda()
1022
+
1023
+ # remove the following after the issue is fixed:
1024
+ # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
1025
+ if loaded_weight.is_contiguous() is False:
1026
+ loaded_weight = loaded_weight.contiguous()
1027
+
1028
+ with set_default_torch_dtype(torch.float32):
1029
+ processed_weight, quant_state = quantize_4bit(
1030
+ loaded_weight,
1031
+ compress_statistics=True,
1032
+ quant_type="nf4",
1033
+ )
1034
+
1035
+ quant_state_dict[mapped_weight_name] = quant_state
1036
+ else:
1037
+ processed_weight = weight_tensor
1038
+ yield org_weight_name, processed_weight
1039
+
1040
+ def _get_bnb_target_modules(self, model: nn.Module) -> None:
1041
+
1042
+ for name, module in model.named_modules():
1043
+ if isinstance(module, (LinearBase, )):
1044
+ if modules_info := self.modules_mapping.get_sub_modules(name):
1045
+ # Map vllm's names to transformers's names.
1046
+ rep_name, sub_modules = modules_info
1047
+ for sub_name in sub_modules:
1048
+ self.target_modules.append(
1049
+ name.replace(rep_name, sub_name))
1050
+ # Add original module name even if the module has stacked map,
1051
+ # in case model has a mixture of disk-merged and disk-splitted
1052
+ # weights with same last name.
1053
+ self.target_modules.append(name)
1054
+
1055
+ assert (self.target_modules
1056
+ ), "vllm currently does not support BNB quantization for"
1057
+ f" {type(model).__name__}"
1058
+
1059
+ def _load_weights(self, model_config: ModelConfig,
1060
+ model: nn.Module) -> None:
1061
+ if not hasattr(model, "load_weights"):
1062
+ raise AttributeError(
1063
+ "The required method 'load_weights' is not defined in class"
1064
+ f" {type(model).__name__}.")
1065
+
1066
+ if not hasattr(model, "packed_modules_mapping"):
1067
+ raise AttributeError(
1068
+ f"Model {type(model).__name__} does not support BitsAndBytes "
1069
+ "quantization yet. No 'packed_modules_mapping' found.")
1070
+
1071
+ self.modules_mapping = ParamMapping(
1072
+ copy.deepcopy(model.packed_modules_mapping))
1073
+
1074
+ # For some models like Molmo, we need to use hf_to_vllm_mapper
1075
+ # to ensure correct loading of weights.
1076
+ if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
1077
+ self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
1078
+
1079
+ # Modules whose weights might have fused on disk
1080
+ # we need their output_sizes to make shard in flight correctly with TP
1081
+ self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
1082
+ self._get_bnb_target_modules(model)
1083
+ for name, module in model.named_modules():
1084
+ # Some modules like `ReplicatedLinear` should not have their weights
1085
+ # sharded. The reason for implementing it this way is to avoid new
1086
+ # static variable in the model implementation.
1087
+ if isinstance(module, (ReplicatedLinear, )):
1088
+ self.unsharded_weights_modules.append(name)
1089
+ # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
1090
+ # fused weights on disk. We need to use the output sizes of these
1091
+ # modules to shard the weights correctly.
1092
+ elif isinstance(module,
1093
+ (QKVParallelLinear, MergedColumnParallelLinear)):
1094
+ self.maybe_fused_weights_modules[name] = module.output_sizes
1095
+ # In TP, these weights are partitioned along the column
1096
+ # dimension (dim=-1)
1097
+ elif isinstance(module, (RowParallelLinear, )):
1098
+ self.column_sharded_weights_modules.append(name)
1099
+
1100
+ self.model_type = type(model).__name__
1101
+
1102
+ logger.info("Loading weights with BitsAndBytes quantization. "
1103
+ " May take a while ...")
1104
+
1105
+ quant_config = getattr(model_config.hf_config, "quantization_config",
1106
+ None)
1107
+
1108
+ pre_quant = False
1109
+ if quant_config is not None:
1110
+ quant_method = quant_config.get("quant_method")
1111
+ if quant_method == "bitsandbytes":
1112
+ pre_quant = True
1113
+ else:
1114
+ raise ValueError(
1115
+ f"BitsAndBytes loader does not support {quant_method} "
1116
+ "quantization")
1117
+
1118
+ # The quant_states in pre_quantized models cannot work with a split
1119
+ # weight tensor. So TP does not work with pre_quantized bnb models.
1120
+ if pre_quant and get_tensor_model_parallel_world_size() > 1:
1121
+ raise ValueError(
1122
+ "Prequant BitsAndBytes models with tensor parallelism is not "
1123
+ "supported. Please try with pipeline parallelism.")
1124
+
1125
+ load_8bit = False
1126
+ if pre_quant:
1127
+ load_8bit = quant_config.get("load_in_8bit", False)
1128
+
1129
+ qweight_iterator, quant_state_dict = (
1130
+ self._get_quantized_weights_iterator(model_config.model,
1131
+ model_config.revision,
1132
+ pre_quant, load_8bit))
1133
+
1134
+ weights_to_load = {name for name, _ in model.named_parameters()}
1135
+ loaded_weights = model.load_weights(qweight_iterator)
1136
+ # Some models may have weights loading tracker unimplemented.
1137
+ if loaded_weights is not None:
1138
+ weights_not_loaded = weights_to_load - loaded_weights
1139
+ if weights_not_loaded:
1140
+ raise ValueError("Following weights were not initialized from "
1141
+ f"checkpoint: {weights_not_loaded}")
1142
+
1143
+ torch.cuda.empty_cache()
1144
+
1145
+ param_dict = dict(model.named_parameters())
1146
+ stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
1147
+ # TODO: Change this lazy import to normal import
1148
+ # after the checks are updated to run on a new version
1149
+ from vllm.model_executor.models.utils import is_pp_missing_parameter
1150
+
1151
+ for quant_param_name in quant_state_dict:
1152
+ if is_pp_missing_parameter(quant_param_name, model):
1153
+ continue
1154
+
1155
+ non_stacked_param_name = quant_param_name
1156
+
1157
+ shard_index = 0
1158
+ for shard_name, (
1159
+ weight_name,
1160
+ index,
1161
+ ) in self.modules_mapping.inverse_packed_mapping.items():
1162
+ # Some models, such as MiniCPM V2.5/2.6, contain both
1163
+ # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
1164
+ # from being incorrectly identified as being present in
1165
+ # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
1166
+ shard_pos = quant_param_name.find(shard_name)
1167
+ can_correct_rename = (shard_pos
1168
+ > 0) and (quant_param_name[shard_pos - 1]
1169
+ == ".")
1170
+ # If the quant_param_name is packed, it won't occur in the
1171
+ # param_dict before renaming.
1172
+ new_quant_param_name = quant_param_name.replace(
1173
+ shard_name, weight_name)
1174
+ need_rename = (quant_param_name not in param_dict) \
1175
+ and (new_quant_param_name in param_dict)
1176
+ if can_correct_rename and need_rename:
1177
+ shard_index = index
1178
+ quant_param_name = new_quant_param_name
1179
+ break
1180
+
1181
+ # Models like Clip/Siglip may skip some layers in initialization,
1182
+ # causing unused quant_param_name in state_dict.
1183
+ if quant_param_name not in param_dict:
1184
+ continue
1185
+
1186
+ if quant_param_name not in stacked_quant_state_dict:
1187
+ stacked_quant_state_dict[quant_param_name] = {}
1188
+
1189
+ stacked_quant_state_dict[quant_param_name][shard_index] = (
1190
+ quant_state_dict[non_stacked_param_name])
1191
+
1192
+ # save quant_states and offsets as the attributes of the parameters
1193
+ for param_name, param in param_dict.items():
1194
+ if param_name in stacked_quant_state_dict:
1195
+ quant_states = stacked_quant_state_dict[param_name]
1196
+ set_weight_attrs(param, {"bnb_quant_state": quant_states})
1197
+
1198
+ pack_ratio = getattr(param, "pack_factor", -1)
1199
+ if pack_ratio == -1:
1200
+ raise ValueError(
1201
+ f"pack_factor not set for parameter {param_name}.")
1202
+
1203
+ num_elements = [0] * len(quant_states)
1204
+ for seq, quant_state in quant_states.items():
1205
+ num_elements[seq] = (math.prod(quant_state.shape) //
1206
+ pack_ratio)
1207
+
1208
+ offsets = np.concatenate(([0], np.cumsum(num_elements)))
1209
+ set_weight_attrs(param, {"bnb_shard_offsets": offsets})
1210
+
1211
+ if load_8bit:
1212
+ set_weight_attrs(
1213
+ param, {"matmul_state": [None] * len(quant_states)})
1214
+
1215
+ def download_model(self, model_config: ModelConfig) -> None:
1216
+ self._prepare_weights(model_config.model, model_config.revision)
1217
+
1218
+ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
1219
+ device_config = vllm_config.device_config
1220
+ model_config = vllm_config.model_config
1221
+ with set_default_torch_dtype(model_config.dtype):
1222
+ with torch.device(device_config.device):
1223
+ model = _initialize_model(vllm_config=vllm_config)
1224
+
1225
+ self._load_weights(model_config, model)
1226
+
1227
+ return model.eval()
1228
+
1229
+
1230
+ class GGUFModelLoader(BaseModelLoader):
1231
+ """
1232
+ Model loader that can load GGUF files. This is useful for loading models
1233
+ that are quantized with GGUF and saved in the GGUF format. This loader
1234
+ supports loading both full models and sharded models.
1235
+ """
1236
+
1237
+ def __init__(self, load_config: LoadConfig):
1238
+ super().__init__(load_config)
1239
+ if load_config.model_loader_extra_config:
1240
+ raise ValueError(f"Model loader extra config is not supported for "
1241
+ f"load format {load_config.load_format}")
1242
+
1243
+ def _prepare_weights(self, model_name_or_path: str):
1244
+ if os.path.isfile(model_name_or_path):
1245
+ return model_name_or_path
1246
+ else:
1247
+ raise ValueError(f"{model_name_or_path} is not a file.")
1248
+
1249
+ def _get_gguf_weights_map(self, model_config: ModelConfig):
1250
+ """
1251
+ GGUF uses this naming convention for their tensors from HF checkpoint:
1252
+ `blk.N.BB.weight` and `blk.N.BB.bias`
1253
+ where N signifies the block number of a layer, and BB signifies the
1254
+ attention/mlp layer components.
1255
+ See "Standardized tensor names" in
1256
+ https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
1257
+ """
1258
+ config = model_config.hf_config
1259
+ model_type = config.model_type
1260
+ # hack: ggufs have a different name than transformers
1261
+ if model_type == "cohere":
1262
+ model_type = "command-r"
1263
+ arch = None
1264
+ for key, value in gguf.MODEL_ARCH_NAMES.items():
1265
+ if value == model_type:
1266
+ arch = key
1267
+ break
1268
+ if arch is None:
1269
+ raise RuntimeError(f"Unknown gguf model_type: {model_type}")
1270
+ num_layers = config.num_hidden_layers
1271
+ name_map = gguf.get_tensor_name_map(arch, num_layers)
1272
+ with torch.device("meta"):
1273
+ dummy_model = AutoModelForCausalLM.from_config(config)
1274
+ state_dict = dummy_model.state_dict()
1275
+
1276
+ gguf_to_hf_name_map = {}
1277
+ for hf_name in state_dict:
1278
+ name, suffix = hf_name.rsplit(".", 1)
1279
+ gguf_name = name_map.get_name(name)
1280
+ gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
1281
+ return gguf_to_hf_name_map
1282
+
1283
+ def _get_weights_iterator(
1284
+ self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
1285
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
1286
+ return gguf_quant_weights_iterator(model_name_or_path,
1287
+ gguf_to_hf_name_map)
1288
+
1289
+ def download_model(self, model_config: ModelConfig) -> None:
1290
+ self._prepare_weights(model_config.model)
1291
+
1292
+ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
1293
+ device_config = vllm_config.device_config
1294
+ model_config = vllm_config.model_config
1295
+ local_model_path = self._prepare_weights(model_config.model)
1296
+ gguf_weights_map = self._get_gguf_weights_map(model_config)
1297
+ # we can only know if tie word embeddings after mapping weights
1298
+ if "lm_head.weight" in get_gguf_extra_tensor_names(
1299
+ local_model_path, gguf_weights_map):
1300
+ model_config.hf_config.update({"tie_word_embeddings": True})
1301
+
1302
+ with set_default_torch_dtype(model_config.dtype):
1303
+ with torch.device(device_config.device):
1304
+ model = _initialize_model(vllm_config=vllm_config)
1305
+ model.load_weights(
1306
+ self._get_weights_iterator(local_model_path, gguf_weights_map))
1307
+ return model
1308
+
1309
+
1310
+ class RunaiModelStreamerLoader(BaseModelLoader):
1311
+ """
1312
+ Model loader that can load safetensors
1313
+ files from local FS or S3 bucket.
1314
+ """
1315
+
1316
+ def __init__(self, load_config: LoadConfig):
1317
+ super().__init__(load_config)
1318
+ if load_config.model_loader_extra_config:
1319
+ extra_config = load_config.model_loader_extra_config
1320
+
1321
+ if ("concurrency" in extra_config
1322
+ and isinstance(extra_config.get("concurrency"), int)):
1323
+ os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
1324
+ extra_config.get("concurrency"))
1325
+
1326
+ if ("memory_limit" in extra_config
1327
+ and isinstance(extra_config.get("memory_limit"), int)):
1328
+ os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
1329
+ extra_config.get("memory_limit"))
1330
+
1331
+ runai_streamer_s3_endpoint = os.getenv(
1332
+ 'RUNAI_STREAMER_S3_ENDPOINT')
1333
+ aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
1334
+ if (runai_streamer_s3_endpoint is None
1335
+ and aws_endpoint_url is not None):
1336
+ os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
1337
+
1338
+ def _prepare_weights(self, model_name_or_path: str,
1339
+ revision: Optional[str]) -> List[str]:
1340
+ """Prepare weights for the model.
1341
+
1342
+ If the model is not local, it will be downloaded."""
1343
+ is_s3_path = is_s3(model_name_or_path)
1344
+ is_local = os.path.isdir(model_name_or_path)
1345
+ safetensors_pattern = "*.safetensors"
1346
+ index_file = SAFE_WEIGHTS_INDEX_NAME
1347
+
1348
+ hf_folder = (model_name_or_path if
1349
+ (is_local or is_s3_path) else download_weights_from_hf(
1350
+ model_name_or_path,
1351
+ self.load_config.download_dir,
1352
+ [safetensors_pattern],
1353
+ revision,
1354
+ ignore_patterns=self.load_config.ignore_patterns,
1355
+ ))
1356
+
1357
+ if is_s3_path:
1358
+ hf_weights_files = s3_glob(path=hf_folder,
1359
+ allow_pattern=[safetensors_pattern])
1360
+ else:
1361
+ hf_weights_files = glob.glob(
1362
+ os.path.join(hf_folder, safetensors_pattern))
1363
+
1364
+ if not is_local and not is_s3_path:
1365
+ download_safetensors_index_file_from_hf(
1366
+ model_name_or_path, index_file, self.load_config.download_dir,
1367
+ revision)
1368
+
1369
+ if not hf_weights_files:
1370
+ raise RuntimeError(
1371
+ f"Cannot find any safetensors model weights with "
1372
+ f"`{model_name_or_path}`")
1373
+
1374
+ return hf_weights_files
1375
+
1376
+ def _get_weights_iterator(
1377
+ self, model_or_path: str,
1378
+ revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
1379
+ """Get an iterator for the model weights based on the load format."""
1380
+ hf_weights_files = self._prepare_weights(model_or_path, revision)
1381
+ return runai_safetensors_weights_iterator(hf_weights_files)
1382
+
1383
+ def download_model(self, model_config: ModelConfig) -> None:
1384
+ """Download model if necessary"""
1385
+ self._prepare_weights(model_config.model, model_config.revision)
1386
+
1387
+ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
1388
+ """Perform streaming of the model to destination"""
1389
+ device_config = vllm_config.device_config
1390
+ model_config = vllm_config.model_config
1391
+
1392
+ target_device = torch.device(device_config.device)
1393
+ with set_default_torch_dtype(model_config.dtype):
1394
+ with target_device:
1395
+ model = _initialize_model(vllm_config=vllm_config)
1396
+
1397
+ model_weights = model_config.model
1398
+ if hasattr(model_config, "model_weights"):
1399
+ model_weights = model_config.model_weights
1400
+ model.load_weights(
1401
+ self._get_weights_iterator(model_weights,
1402
+ model_config.revision))
1403
+
1404
+ for _, module in model.named_modules():
1405
+ quant_method = getattr(module, "quant_method", None)
1406
+ if quant_method is not None:
1407
+ with device_loading_context(module, target_device):
1408
+ quant_method.process_weights_after_loading(module)
1409
+ if isinstance(module, Attention) and \
1410
+ hasattr(module, "process_weights_after_loading"):
1411
+ # When attention modules need to process weights after
1412
+ # currently only used by MLA
1413
+ module.process_weights_after_loading(model_config.dtype)
1414
+ return model.eval()
1415
+
1416
+
1417
+ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1418
+ """Get a model loader based on the load format."""
1419
+
1420
+ if isinstance(load_config.load_format, type):
1421
+ return load_config.load_format(load_config)
1422
+
1423
+ if load_config.load_format == LoadFormat.DUMMY:
1424
+ return DummyModelLoader(load_config)
1425
+
1426
+ if load_config.load_format == LoadFormat.TENSORIZER:
1427
+ return TensorizerLoader(load_config)
1428
+
1429
+ if load_config.load_format == LoadFormat.SHARDED_STATE:
1430
+ return ShardedStateLoader(load_config)
1431
+
1432
+ if load_config.load_format == LoadFormat.BITSANDBYTES:
1433
+ return BitsAndBytesModelLoader(load_config)
1434
+
1435
+ if load_config.load_format == LoadFormat.GGUF:
1436
+ return GGUFModelLoader(load_config)
1437
+
1438
+ if load_config.load_format == LoadFormat.RUNAI_STREAMER:
1439
+ return RunaiModelStreamerLoader(load_config)
1440
+
1441
+ return DefaultModelLoader(load_config)
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/neuron.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Utilities for selecting and loading neuron models."""
3
+ import copy
4
+ import importlib
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from transformers import PretrainedConfig
11
+
12
+ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
13
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
14
+ from vllm.model_executor.layers.quantization import get_quantization_config
15
+ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
16
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
17
+ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
18
+ SequenceOutput)
19
+
20
+ TORCH_DTYPE_TO_NEURON_AMP = {
21
+ "auto": "f32",
22
+ "half": "f16",
23
+ "float16": "f16",
24
+ "bfloat16": "bf16",
25
+ "float": "f32",
26
+ "float32": "f32",
27
+ torch.float16: "f16",
28
+ torch.bfloat16: "bf16",
29
+ torch.float32: "f32",
30
+ }
31
+
32
+ # Models supported by Neuron.
33
+ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
34
+ "LlamaForCausalLM": ("transformers_neuronx.llama.model",
35
+ "LlamaForSampling", "LlamaForCausalLM"),
36
+ "MistralForCausalLM": ("transformers_neuronx.mistral.model",
37
+ "MistralForSampling", "MistralForCausalLM")
38
+ }
39
+
40
+
41
+ class NeuronCausalLM(nn.Module):
42
+
43
+ def __init__(self,
44
+ config: PretrainedConfig,
45
+ on_device_sampling_disabled: bool = False) -> None:
46
+ super().__init__()
47
+ self.config = config
48
+ self.logits_processor = LogitsProcessor(config.vocab_size,
49
+ logits_as_input=True)
50
+
51
+ self.on_device_sampling_disabled = on_device_sampling_disabled
52
+ if self.on_device_sampling_disabled:
53
+ # Use default sampler
54
+ self.sampler = Sampler()
55
+
56
+ # Lazy initialized
57
+ self.model: nn.Module
58
+
59
+ def forward(
60
+ self,
61
+ input_ids: torch.Tensor,
62
+ positions: torch.Tensor,
63
+ input_block_ids: torch.Tensor,
64
+ ) -> torch.Tensor:
65
+ logits = self.model(input_ids,
66
+ cache_ids=positions,
67
+ start_ids=input_block_ids)
68
+ return logits
69
+
70
+ def compute_logits(self, hidden_states: torch.Tensor,
71
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
72
+ logits = self.logits_processor(None, hidden_states, sampling_metadata)
73
+ return logits
74
+
75
+ def sample(
76
+ self,
77
+ logits: torch.Tensor,
78
+ sampling_metadata: SamplingMetadata,
79
+ ) -> Optional[SamplerOutput]:
80
+
81
+ if self.on_device_sampling_disabled:
82
+ next_tokens = self.sampler(logits, sampling_metadata)
83
+ return next_tokens
84
+
85
+ # On-device sampling outputs the token ids directly.
86
+ sampled_token_ids = logits.flatten()
87
+ next_tokens = []
88
+ sample_idx = 0
89
+ for seq_group in sampling_metadata.seq_groups:
90
+ samples = []
91
+ for seq_id in seq_group.seq_ids:
92
+ token_id = sampled_token_ids[sample_idx].item()
93
+ samples.append(
94
+ SequenceOutput(parent_seq_id=seq_id,
95
+ output_token=token_id,
96
+ logprobs={token_id: Logprob(token_id)}))
97
+ sample_idx += 1
98
+ next_tokens.append(
99
+ CompletionSequenceGroupOutput(samples=samples,
100
+ prompt_logprobs=None))
101
+
102
+ return SamplerOutput(outputs=next_tokens)
103
+
104
+ def load_weights(self, model_name_or_path: str, **kwargs):
105
+ arch = _get_model_architecture(self.config)
106
+ neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
107
+ _NEURON_SUPPORTED_MODELS[arch])
108
+ neuronx_module = importlib.import_module(neuronx_module_path)
109
+ neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
110
+
111
+ self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
112
+ **kwargs)
113
+ self.model.to_neuron()
114
+
115
+
116
+ def _get_model_architecture(config: PretrainedConfig) -> str:
117
+ architectures = getattr(config, "architectures", [])
118
+ for arch in architectures:
119
+ if arch in _NEURON_SUPPORTED_MODELS:
120
+ return arch
121
+ raise ValueError(
122
+ f"Model architectures {architectures} are not supported on Neuron "
123
+ f"for now. Supported architectures: "
124
+ f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
125
+
126
+
127
+ def _get_buckets(env: str, default_value: List[int]) -> List[int]:
128
+ env_value = os.getenv(env)
129
+ if env_value is None:
130
+ return default_value
131
+ buckets_remove_empty = filter(
132
+ lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
133
+ buckets_int = map(int, buckets_remove_empty)
134
+ buckets_list = list(buckets_int)
135
+ return buckets_list
136
+
137
+
138
+ def _get_default_neuron_config(model_config: ModelConfig,
139
+ parallel_config: ParallelConfig,
140
+ scheduler_config: SchedulerConfig):
141
+ from transformers_neuronx.config import ContinuousBatchingConfig
142
+ from transformers_neuronx.constants import LAYOUT_BSH
143
+
144
+ continuous_batching_config = ContinuousBatchingConfig(
145
+ batch_size_for_shared_caches=scheduler_config.max_num_seqs)
146
+ quant_config = dict(
147
+ dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
148
+ quantize_method="vector_dynamic")
149
+ neuron_quantization_config_builder = lambda quant: get_quantization_config(
150
+ quant).from_config(quant_config).get_quant_method(None, "")
151
+ # TODO: Add Paged attention config to the default neuron arguments.
152
+ default_neuron_args = dict(
153
+ collectives_layout=LAYOUT_BSH,
154
+ attention_layout=LAYOUT_BSH,
155
+ fuse_qkv=True,
156
+ quant=neuron_quantization_config_builder(model_config.quantization)
157
+ if model_config.quantization else None,
158
+ continuous_batching=continuous_batching_config,
159
+ weight_tiling=bool(model_config.quantization),
160
+ on_device_generation=_get_neuron_on_device_generation_config(
161
+ model_config))
162
+ return default_neuron_args
163
+
164
+
165
+ def _get_neuron_on_device_generation_config(model_config: ModelConfig):
166
+ if not _is_neuron_on_device_sampling_disabled(model_config):
167
+ return copy.deepcopy(model_config.neuron_sampling_params)
168
+ return None
169
+
170
+
171
+ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
172
+ return not getattr(model_config, "neuron_sampling_params", None)
173
+
174
+
175
+ def _get_neuron_config_after_override(default_neuron_config,
176
+ overridden_neuron_config):
177
+ from transformers_neuronx.config import NeuronConfig
178
+ overridden_neuron_config = overridden_neuron_config or {}
179
+ default_neuron_config.update(overridden_neuron_config)
180
+ return NeuronConfig(**default_neuron_config)
181
+
182
+
183
+ def get_neuron_model(model_config: ModelConfig,
184
+ parallel_config: ParallelConfig,
185
+ scheduler_config: SchedulerConfig) -> nn.Module:
186
+
187
+ # Create a model instance.
188
+ model = NeuronCausalLM(
189
+ model_config.hf_config,
190
+ _is_neuron_on_device_sampling_disabled(model_config))
191
+
192
+ default_neuron_config_args = _get_default_neuron_config(
193
+ model_config, parallel_config, scheduler_config)
194
+
195
+ neuron_config = _get_neuron_config_after_override(
196
+ default_neuron_config_args, model_config.override_neuron_config)
197
+
198
+ context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
199
+ [scheduler_config.max_model_len])
200
+ n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
201
+ [scheduler_config.max_model_len])
202
+
203
+ # Load the weights from the cached or downloaded files.
204
+ model.load_weights(model_config.model,
205
+ tp_degree=parallel_config.tensor_parallel_size,
206
+ amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
207
+ neuron_config=neuron_config,
208
+ context_length_estimate=context_length_estimates,
209
+ n_positions=n_positions,
210
+ batch_size=scheduler_config.max_num_seqs)
211
+
212
+ return model.eval()
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Utilities for selecting and loading models."""
3
+ import contextlib
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+ import transformers
9
+ from torch import nn
10
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
11
+
12
+ from vllm.config import ModelConfig, ModelImpl
13
+ from vllm.logger import init_logger
14
+ from vllm.model_executor.layers.quantization.base_config import (
15
+ QuantizationConfig)
16
+ from vllm.model_executor.models import ModelRegistry
17
+ from vllm.model_executor.models.adapters import (as_classification_model,
18
+ as_embedding_model,
19
+ as_reward_model)
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to *dtype*.

    The previous default dtype is restored on exit, even when the wrapped
    block raises, so a failed model load cannot leak the overridden dtype
    into the rest of the process (the original skipped restoration on
    exception).
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
31
+
32
+
33
def is_transformers_impl_compatible(
        arch: str,
        module: Optional[transformers.PreTrainedModel] = None) -> bool:
    """Check whether the Transformers implementation of *arch* can back vLLM.

    Uses *module* when given, otherwise looks the class up on the
    ``transformers`` package; returns False when neither resolves.
    """
    resolved = module or getattr(transformers, arch, None)
    if resolved is None:
        return False
    # Newer Transformers classes expose an explicit compatibility hook;
    # fall back to the flex-attention capability flag otherwise.
    if hasattr(resolved, "supports_backend"):
        return resolved.is_backend_compatible()
    return resolved._supports_flex_attn
43
+
44
+
45
def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    """Rewrite unsupported architectures to "TransformersModel" in place.

    For each architecture, verify that its Transformers implementation is
    usable by vLLM and substitute the generic "TransformersModel" entry.
    Raises ValueError when the user forced the Transformers implementation
    (or no vLLM implementation exists) and the Transformers one is
    incompatible.  Returns the mutated *architectures* list.
    """
    for i, arch in enumerate(architectures):
        if arch == "TransformersModel":
            continue
        custom_module = None
        auto_map = getattr(model_config.hf_config, "auto_map", None)
        if auto_map is not None and "AutoModel" in auto_map:
            # Remote-code models: resolve the custom AutoModel class so the
            # compatibility check runs against it, not a stock class.
            custom_module = get_class_from_dynamic_module(
                model_config.hf_config.auto_map["AutoModel"],
                model_config.model)
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersModel"
        if model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersModel"
    return architectures
75
+
76
+
77
def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for *model_config*.

    Handles the quantized-Mixtral special case, falls back to the generic
    Transformers implementation for architectures without a vLLM one, and
    wraps the class for non-generation tasks (embed/classify/reward).
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    is_vllm_supported = any(arch in vllm_supported_archs
                            for arch in architectures)
    if (not is_vllm_supported
            or model_config.model_impl == ModelImpl.TRANSFORMERS):
        architectures = resolve_transformers_fallback(model_config,
                                                      architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    # Adapt generation models to pooling-style tasks when requested.
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
109
+
110
+
111
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name resolved for *model_config*."""
    _, arch_name = get_model_architecture(model_config)
    return arch_name
113
+
114
+
115
@dataclass
class ParamMapping:
    """
    Bidirectional bookkeeping for packed (fused) model parameters.

    ``packed_mapping`` maps a fused parameter name to the names of its
    constituent shards.  ``inverse_packed_mapping`` is derived in
    ``__post_init__`` and maps each shard name back to
    ``(fused_name, shard_index)``.
    """
    packed_mapping: Dict[str, List[str]]
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        for fused_name, shard_names in self.packed_mapping.items():
            # A self-referential entry (e.g. {"W_pack": ["W_pack"]}) packs
            # nothing, so it contributes no inverse entries.
            if shard_names == [fused_name]:
                continue
            for shard_index, shard_name in enumerate(shard_names):
                self.inverse_packed_mapping[shard_name] = (fused_name,
                                                           shard_index)

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Return the first ``(fused_name, shard_names)`` pair whose fused
        name is a suffix of *module_name*, or None when nothing matches."""
        matches = [(fused_name, shard_names)
                   for fused_name, shard_names in self.packed_mapping.items()
                   if module_name.endswith(fused_name)]
        return matches[0] if matches else None
143
+
144
+
145
def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Share the model class's ``packed_modules_mapping`` with *quant_config*
    so the quantization layer can match fused modules.

    The mapping is handed over by reference, so in-place updates made later
    by ``model_class.__new__`` (e.g. chatglm, qwen) remain visible to the
    quant config.
    """
    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
    if packed_mapping is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return
    # pass packed_modules_mapping by reference to quant_config
    quant_config.packed_modules_mapping = packed_mapping
.venv/lib/python3.11/site-packages/vllm/model_executor/models/arctic.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Inference-only Snowflake Arctic model."""
3
+ from typing import Iterable, List, Optional, Set, Tuple, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from vllm.attention import Attention, AttentionMetadata
9
+ from vllm.compilation.decorators import support_torch_compile
10
+ from vllm.config import CacheConfig, VllmConfig
11
+ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
12
+ get_tensor_model_parallel_world_size,
13
+ tensor_model_parallel_all_reduce)
14
+ from vllm.logger import init_logger
15
+ from vllm.model_executor.layers.activation import SiluAndMul
16
+ from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
17
+ from vllm.model_executor.layers.layernorm import RMSNorm
18
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
19
+ QKVParallelLinear,
20
+ ReplicatedLinear,
21
+ RowParallelLinear)
22
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
+ from vllm.model_executor.layers.quantization import QuantizationConfig
24
+ from vllm.model_executor.layers.quantization.deepspeedfp import (
25
+ DeepSpeedFPConfig, DeepSpeedFPParameter)
26
+ from vllm.model_executor.layers.rotary_embedding import get_rope
27
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
28
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
29
+ ParallelLMHead, VocabParallelEmbedding)
30
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
32
+ from vllm.model_executor.utils import set_weight_attrs
33
+ from vllm.sequence import IntermediateTensors
34
+ from vllm.transformers_utils.configs.arctic import ArcticConfig
35
+
36
+ from .interfaces import SupportsPP
37
+ from .utils import (extract_layer_index, is_pp_missing_parameter,
38
+ make_empty_intermediate_tensors_factory, make_layers,
39
+ maybe_prefix)
40
+
41
+ logger = init_logger(__name__)
42
+
43
+
44
class ArcticMLP(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Serves both as a per-expert MLP inside the MoE layer (``expert_id`` is
    the expert's index) and as the always-on residual MLP
    (``is_residual_mlp=True``), which sizes the FFN's inner dimension to
    ``hidden_size`` instead of ``intermediate_size``.
    """

    def __init__(self,
                 config: ArcticConfig,
                 expert_id: int = -1,
                 is_residual_mlp: bool = False,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        self.ffn_dim = config.intermediate_size if not is_residual_mlp \
            else self.hidden_size

        # w13 fuses the w1 (gate) and w3 (up) projections into one GEMM.
        self.w13 = MergedColumnParallelLinear(self.hidden_size,
                                              [self.ffn_dim] * 2,
                                              bias=False,
                                              quant_config=quant_config)
        self.w2 = RowParallelLinear(self.ffn_dim,
                                    self.hidden_size,
                                    bias=False,
                                    reduce_results=reduce_results,
                                    quant_config=quant_config)
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states):
        # SiluAndMul splits the fused gate/up output and computes
        # silu(gate) * up.
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states
79
+
80
+
81
class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    Depending on the layer index this module is either a dense MLP
    (non-MoE layers) or a sparse top-k mixture of experts.  Expert weights
    may be stored DeepSpeedFP-quantized; they are dequantized (selectively
    when only a few tokens are routed) before the fused-experts kernel.
    """

    def __init__(self,
                 config: ArcticConfig,
                 tp_size: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # Each TP rank holds a 1/tp_size shard of every expert's FFN.
        self.intermediate_size = config.intermediate_size // self.tp_size

        # MoE layers occur every `moe_layer_frequency` layers.
        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        # Some other parameters
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(config,
                                 quant_config=quant_config,
                                 reduce_results=reduce_results,
                                 prefix=f"{prefix}.mlp")
        else:
            # Router: replicated (not sharded) across TP ranks.
            self.gate = ReplicatedLinear(self.hidden_size,
                                         self.num_experts,
                                         bias=False,
                                         params_dtype=self.params_dtype,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.gate")
            if self.is_quant:
                # ws fuses all experts' w1/w3; w2s holds all experts' w2.
                self.ws = DeepSpeedFPParameter(
                    torch.Size((self.num_experts, 2 * self.intermediate_size,
                                self.hidden_size)),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
                self.w2s = DeepSpeedFPParameter(
                    torch.Size((self.num_experts, self.hidden_size,
                                self.intermediate_size)),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
            else:
                self.ws = nn.Parameter(
                    torch.empty(self.num_experts,
                                2 * self.intermediate_size,
                                self.hidden_size,
                                device="cuda",
                                dtype=self.params_dtype))
                self.w2s = nn.Parameter(
                    torch.empty(self.num_experts,
                                self.hidden_size,
                                self.intermediate_size,
                                device="cuda",
                                dtype=self.params_dtype))
            # Route checkpoint shards through this module's custom loader.
            set_weight_attrs(self.ws, {
                "weight_loader": self.weight_loader,
            })
            set_weight_attrs(self.w2s, {
                "weight_loader": self.weight_loader,
            })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy one expert's w1/w2/w3 checkpoint tensor into the fused
        parameter, slicing out this TP rank's shard."""
        tp_rank = get_tensor_model_parallel_rank()
        # For quantized params, write into a dequantized working copy.
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            # w2 is sharded along its input (second) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            # Re-quantize the updated copy back into the parameter.
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and run the fused kernel."""
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        do_normalize = self.top_k > 1
        topk_weights, topk_ids = fused_topk(hidden_states,
                                            router_logits,
                                            self.top_k,
                                            renormalize=do_normalize)
        # topk_ids: (num_tokens, k)
        if self.is_quant:
            if 2 * num_tokens <= self.num_experts:
                # If much fewer tokens than experts, use selective dequantize.
                ws_dequantized = self.ws.ds_selective_dequantize(
                    topk_ids.flatten())
                w2s_dequantized = self.w2s.ds_selective_dequantize(
                    topk_ids.flatten())
                # We gathered the experts to the tokens so update the mapping.
                topk_ids = torch.arange(
                    0,
                    topk_ids.numel(),
                    device=topk_ids.device,
                ).reshape(topk_ids.shape)
            else:
                ws_dequantized = self.ws.ds_dequantize()
                w2s_dequantized = self.w2s.ds_dequantize()

        final_hidden_states = fused_experts(
            hidden_states,
            ws_dequantized if self.is_quant else self.ws,
            w2s_dequantized if self.is_quant else self.w2s,
            topk_weights,
            topk_ids,
            inplace=True)
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        if self.is_moe_layer:
            final_hidden_states = self.local_moe_fused(hidden_states)
        else:
            final_hidden_states = self.mlp(hidden_states)
        return final_hidden_states
218
+
219
+
220
class ArcticAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings,
    sharded across tensor-parallel ranks."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Each TP rank gets an equal share of the KV heads.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: KV heads must divide tp_size so
            # they can be replicated evenly.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=False,
                                          quant_config=quant_config)
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Fused QKV projection, then split into per-role chunks.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
294
+
295
+
296
class ArcticDecoderLayer(nn.Module):
    """One Arctic layer: self-attention followed by a dense-or-MoE FFN,
    with an optional parallel residual-MLP branch on MoE layers."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        # The parallel residual branch only exists on MoE layers.
        self.use_residual = config.use_residual and is_moe_layer
        self.self_attn = ArcticAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            # With the residual branch active, partial sums from both
            # branches are combined before a single all-reduce in forward().
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        if self.use_residual:
            self.residual_layernorm = RMSNorm(config.hidden_size,
                                              eps=config.rms_norm_eps)
            self.residual_mlp = ArcticMLP(config,
                                          is_residual_mlp=True,
                                          reduce_results=False,
                                          prefix=f"{prefix}.residual_mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            # Parallel residual: the dense residual MLP consumes the
            # post-attention stream while the MoE branch is computed from
            # the *pre-attention* input (residual_input).  Both produce
            # partial (non-reduced) results that are summed and all-reduced
            # once before rejoining the main residual stream.
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states
366
+
367
+
368
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: embeddings, pipeline-partitioned layers and
    the final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size)
        # make_layers assigns only this pipeline rank's slice of the layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Non-first pipeline ranks receive activations from upstream.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            # Hand activations to the next pipeline stage.
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states
424
+
425
+
426
class ArcticForCausalLM(nn.Module, SupportsPP):
    """Arctic causal language model: backbone plus LM head, logits
    processing, sampling and checkpoint loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.unpadded_vocab_size = config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, translating stacked/fused parameter
        names (fused qkv, residual-MLP w13, MoE expert ws/w2s) to the
        names used by this module.  Returns the set of parameter names
        that were loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        mlp_params_mapping: List[Tuple[str, str, int]] = []
        expert_params_mapping: List[Tuple[str, str, int]] = []
        num_layers = self.config.num_hidden_layers

        for layer in range(num_layers):
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w1.weight", 0))
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w3.weight", 1))
            if layer % 2 == 0:
                # MLP layers
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0))
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1))
            else:
                # MoE layers
                for expert_id in range(self.config.num_local_experts):
                    expert_params_mapping.append(
                        ("ws", f"experts.{expert_id}.w1.weight", expert_id))
                    expert_params_mapping.append(
                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id))
                    expert_params_mapping.append(
                        ("ws", f"experts.{expert_id}.w3.weight", expert_id))

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading.")
        # Each for/else cascade tries one mapping family; the else branch
        # (reached only when the inner loop does not break) tries the next.
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id \
                            in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
                                      loaded_weight,
                                      weight_name,
                                      expert_id=shard_id)
                        break
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]

                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bart.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Derived from BART implementation posted on HuggingFace; license below:
4
+ #
5
+ # coding=utf-8
6
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
7
+ # All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch BART model."""
21
+ import math
22
+ from typing import Iterable, List, Optional, Tuple
23
+
24
+ import torch
25
+ from torch import nn
26
+ from transformers import BartConfig
27
+ from transformers.utils import logging
28
+
29
+ from vllm.attention import Attention, AttentionMetadata, AttentionType
30
+ from vllm.config import CacheConfig, LoRAConfig, VllmConfig
31
+ from vllm.distributed import get_tensor_model_parallel_world_size
32
+ from vllm.model_executor.layers.activation import get_act_fn
33
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
34
+ QKVParallelLinear,
35
+ RowParallelLinear)
36
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
+ from vllm.model_executor.layers.quantization.base_config import (
38
+ QuantizationConfig)
39
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
40
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
41
+ ParallelLMHead, VocabParallelEmbedding)
42
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
44
+ from vllm.sequence import IntermediateTensors
45
+
46
+ from .utils import maybe_prefix
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
def get_bsz_seq_len(input_ids):
    """Infer ``(batch_size, seq_len)`` from an input-ids tensor.

    A 1-D tensor is treated as a single flattened sequence (batch size 1);
    for higher-rank tensors the first two dimensions are returned.
    """
    if input_ids.dim() == 1:
        return 1, input_ids.numel()
    return input_ids.shape[:2]
58
+
59
+
60
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """Learned positional embeddings with BART's legacy index offset.

    BART historically reserved the first two slots of its position table
    (an artifact of its padding_idx handling), so position ids are shifted
    by 2 and the table is enlarged accordingly. Other models do not need
    this hack.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Positions 0 and 1 are reserved: the real table is two entries
        # larger and every lookup is shifted by 2.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for the (offset-shifted) position ids."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
79
+
80
+
81
class BartScaledWordEmbedding(VocabParallelEmbedding):
    """Token embedding whose output is scaled by a constant factor.

    Behaves exactly like ``VocabParallelEmbedding`` except that every
    looked-up vector is multiplied by ``embed_scale``.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the embedding scale."""
        embeddings = super().forward(input_ids)
        return embeddings * self.embed_scale
96
+
97
+
98
class BartParallelLMHead(ParallelLMHead):
    """LM head whose output is divided by the embedding scale.

    This is the inverse of :class:`BartScaledWordEmbedding`: where the
    input embedding multiplies by ``embed_scale``, this head divides by it.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Project hidden states and undo the embedding scale."""
        projected = super().forward(input_ids)
        return projected / self.embed_scale
115
+
116
+
117
class BartEncoderAttention(nn.Module):
    """Tensor-parallel self-attention used inside BART encoder layers.

    Q/K/V come from one fused column-parallel projection (``qkv_proj``)
    and the attended output goes through a row-parallel projection
    (``out_proj``); the attention computation itself is delegated to
    vLLM's ``Attention`` layer with ``AttentionType.ENCODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        # "total" = across all tensor-parallel ranks; per-rank counts are
        # derived below.  KV head count equals Q head count for BART.
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank widths of the q and k/v slices inside the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into its q, k, v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
195
+
196
+
197
class BartDecoderSelfAttention(nn.Module):
    """Tensor-parallel causal self-attention used inside BART decoder layers.

    Structurally identical to :class:`BartEncoderAttention`, but the
    underlying ``Attention`` layer is built with
    ``AttentionType.DECODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        # "total" = across all tensor-parallel ranks; per-rank counts are
        # derived below.  KV head count equals Q head count for BART.
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank widths of the q and k/v slices inside the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.DECODER)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into its q, k, v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
275
+
276
+
277
class BartCrossAttention(nn.Module):
    """Tensor-parallel encoder-decoder (cross) attention for BART.

    Queries are projected from the decoder hidden states; keys and values
    are projected from the encoder hidden states (when given). Attention
    is delegated to vLLM's ``Attention`` layer with
    ``AttentionType.ENCODER_DECODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        # "total" = across all tensor-parallel ranks; per-rank counts are
        # derived below.  KV head count equals Q head count for BART.
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank widths of the q and k/v slices inside the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_DECODER)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel

        When ``encoder_hidden_states`` is ``None``, k and v are passed to
        the attention layer as ``None`` (decode steps reuse cached KV).
        """

        # (afeldman-nm 2024/07/22) TODO:
        # Need a more efficient solution for q/k/v
        # NOTE(review): the fused qkv projection is run on both inputs and
        # two-thirds of each result is discarded (q from the decoder side,
        # k/v from the encoder side).
        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        if encoder_hidden_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                                    dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
370
+
371
+
372
class BartEncoderLayer(nn.Module):
    """One BART encoder layer: self-attention followed by a two-layer
    feed-forward network, each sub-block wrapped in a residual connection
    with post-LayerNorm (BART is a post-norm architecture).
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Activation is taken from the checkpoint config (e.g. "gelu").
        self.activation_fn = get_act_fn(config.activation_function)

        ffn_hidden_size = self.embed_dim
        ffn_intermediate_size = config.encoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        # Fix: removed the unused `self.act = get_act_fn("gelu")` attribute —
        # forward() only ever calls self.activation_fn (configured above),
        # so the hard-coded extra activation module was dead code.
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
            kv_cache:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Encoder layer output torch.Tensor
        """
        # --- Self-attention sub-block (residual + post-LayerNorm) ---
        residual = hidden_states
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-block (residual + post-LayerNorm) ---
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 overflow guard: clamp to finite range if inf/nan appeared.
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)

        return hidden_states
453
+
454
+
455
class BartDecoderLayer(nn.Module):
    """One BART decoder layer: causal self-attention, cross-attention over
    the encoder output, then a two-layer feed-forward network; each
    sub-block uses a residual connection with post-LayerNorm.
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        '''
        afeldman-nm: personally I would call this "cross-attention",
        however I left the name as "encoder_attn" to maintain consistency
        with the name of the pretrained weights.
        '''
        # NOTE(review): cache_config/quant_config are not forwarded to the
        # cross-attention module here — confirm whether that is intentional.
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        # Fix: this is a *decoder* layer, so the FFN width must come from
        # decoder_ffn_dim (the original read encoder_ffn_dim — a copy-paste
        # slip from BartEncoderLayer). Standard BART configs set the two
        # equal, so behavior is unchanged for released checkpoints.
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states
                torch.Tensor of *decoder* input embeddings.
            kv_cache:
                KV cache tensor
            attn_metadata:
                vLLM Attention metadata structure
            encoder_hidden_states
                torch.Tensor of *encoder* input embeddings.
        Returns:
            Decoder layer output torch.Tensor
        """
        # --- Self-attention sub-block ---
        residual = decoder_hidden_states

        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Cross-attention sub-block ---
        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-block ---
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
564
+
565
+
566
class BartEncoder(nn.Module):
    """Transformer encoder: ``config.encoder_layers`` stacked
    :class:`BartEncoderLayer` modules over scaled token + learned
    positional embeddings.

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional externally shared embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        d_model = config.d_model
        self.max_source_positions = config.max_position_embeddings

        # Scale token embeddings by sqrt(d_model) when the config asks for it.
        scale = math.sqrt(d_model) if config.scale_embedding else 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    d_model,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Tie to an externally provided embedding table.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            d_model,
        )

        self.layers = nn.ModuleList([
            BartEncoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{i}")
            for i in range(config.encoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(d_model)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            positions
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Encoder output torch.Tensor
        """
        token_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions).to(token_embeds.device)

        # Sum token + positional embeddings, then LayerNorm.
        hidden_states = self.layernorm_embedding(token_embeds + pos_embeds)

        # Run the layer stack; each layer gets its own KV cache slot.
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
            )

        return hidden_states
647
+
648
+
649
class BartDecoder(nn.Module):
    """Transformer decoder: ``config.decoder_layers`` stacked
    :class:`BartDecoderLayer` modules over scaled token + learned
    positional embeddings.

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional externally shared embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        # Scale token embeddings by sqrt(d_model) when the config asks for it.
        scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Tie to an externally provided embedding table.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList([
            BartDecoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{i}")
            for i in range(config.decoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(self, decoder_input_ids: torch.Tensor,
                decoder_positions: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings (or None on decode steps)
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Decoder output torch.Tensor
        """
        token_embeds = self.embed_tokens(decoder_input_ids)
        pos_embeds = self.embed_positions(decoder_positions).to(
            token_embeds.device)

        # Sum token + positional embeddings, then LayerNorm.
        hidden_states = self.layernorm_embedding(token_embeds + pos_embeds)

        # Run the layer stack; each layer gets its own KV cache slot and
        # attends over the (shared) encoder output.
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(
                decoder_hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
737
+
738
+
739
class BartModel(nn.Module):
    """BART encoder-decoder backbone (embeddings + both stacks, no LM head)."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id

        # Account for extra LoRA vocabulary slots, when LoRA is enabled.
        extra_vocab = 0
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = BartEncoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """
        encoder_hidden_states = None
        if encoder_input_ids.numel() > 0:
            # The encoder only runs when there is at least one encoder
            # token (i.e. the prefill step); decode steps skip it.
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
813
+
814
+
815
class BartForConditionalGeneration(nn.Module):
    """BART with a (tied) language-modeling head for vLLM serving.

    Wraps :class:`BartModel` with a scaled LM head, logits processing,
    sampling, and HF-checkpoint weight loading (including the shared
    encoder/decoder/lm_head embedding).
    """
    # Checkpoint names are prefixed with this; stripped in _rename_key.
    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.config = config
        self.model = BartModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # Same sqrt(d_model) factor the input embeddings use; the LM head
        # divides by it (see BartParallelLMHead).
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the computed logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    # Maps HF per-projection weight names onto the fused qkv_proj parameter
    # and the shard ("q"/"k"/"v") each one loads into.
    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    # Legacy-name substitutions applied to every checkpoint key.
    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }

    def _rename_key(self, key: str):
        """Strip the base-model prefix and apply params_mapping substitutions."""
        prefix = f"{self.base_model_prefix}."
        key = key[len(prefix):] if key.startswith(prefix) else key

        for src, dst in self.params_mapping.items():
            key = key.replace(src, dst)

        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> Tuple[str, Optional[str]]:
        """Rewrite q/k/v projection names to the fused parameter.

        Returns the (possibly rewritten) name plus the shard id, or
        ``(name, None)`` when the weight is not part of a stacked param.
        """
        for key, mapping in self.stacked_params_mapping.items():
            if key in name:
                name = name.replace(key, mapping["param_name"])
                return name, mapping["shard_id"]
        return name, None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF checkpoint weights into the vLLM parameter layout.

        The single shared embedding weight (shared/encoder/decoder/lm_head —
        the assert enforces that at most one such tensor appears in the
        checkpoint) is captured during the loop and copied into all three
        tied parameters afterwards.
        """

        model_params_dict = dict(self.model.named_parameters())
        top_params_dict = dict(self.named_parameters())

        weights_tuple_list = list(weights)

        shared_embedding_weight = None
        shared_embedding_shard_id = None

        for name, loaded_weight in weights_tuple_list:

            name = self._rename_key(name)
            name, shard_id = self._rename_stacked_param(name)

            if ('shared.weight' in name
                    or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
                    or 'lm_head.weight' in name):
                assert shared_embedding_weight is None, (
                    "Conflicting embedding weights.")
                shared_embedding_weight = loaded_weight
                shared_embedding_shard_id = shard_id
            else:
                # Skip the specific downstream task weight.
                if name.startswith('cls.'):
                    continue
                # use Pooler instead.
                if name.startswith('pooler.'):
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in model_params_dict:
                    continue

                param = model_params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if shard_id:
                    weight_loader(param, loaded_weight, shard_id)
                else:
                    weight_loader(param, loaded_weight)

        # Assign shared weight values
        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
                                           default_weight_loader)

        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
                                           default_weight_loader)

        lm_head_in_param = top_params_dict['lm_head.weight']
        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
                                           default_weight_loader)

        assert shared_embedding_weight is not None

        if shared_embedding_shard_id:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
        else:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bert.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Iterable, List, Optional, Set, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ from transformers import BertConfig
8
+
9
+ from vllm.attention import Attention, AttentionMetadata, AttentionType
10
+ from vllm.compilation.decorators import support_torch_compile
11
+ from vllm.config import CacheConfig, PoolerConfig, VllmConfig
12
+ from vllm.distributed import get_tensor_model_parallel_world_size
13
+ from vllm.model_executor.layers.activation import get_act_fn
14
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
15
+ QKVParallelLinear,
16
+ RowParallelLinear)
17
+ from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
18
+ PoolingType)
19
+ from vllm.model_executor.layers.quantization import QuantizationConfig
20
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
21
+ VocabParallelEmbedding)
22
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
+ from vllm.model_executor.pooling_metadata import PoolingMetadata
24
+ from vllm.sequence import IntermediateTensors, PoolerOutput
25
+ from vllm.transformers_utils.config import (
26
+ get_cross_encoder_activation_function)
27
+
28
+ from .interfaces import SupportsCrossEncoding
29
+ from .utils import WeightsMapper, maybe_prefix
30
+
31
+
32
class BertEmbedding(nn.Module):
    """Sum of word, token-type and absolute position embeddings, LayerNorm'ed.

    Only the "absolute" position-embedding scheme is supported; any other
    configured scheme raises at construction time.
    """

    def __init__(self, config: BertConfig):

        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        # Kept as a parameter so checkpoint position_ids load cleanly.
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed token ids; token_type_ids default to zeros (one segment)."""
        token_embeds = self.word_embeddings(input_ids)
        pos_embeds = self.position_embeddings(position_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_ids.size(),
                                         dtype=torch.long,
                                         device=token_embeds.device)
        type_embeds = self.token_type_embeddings(token_type_ids)

        # Same summation order as upstream BERT: word + type + position.
        return self.LayerNorm(token_embeds + type_embeds + pos_embeds)
79
+
80
+
81
class BertPooler(nn.Module):
    """First-token ("CLS") pooling: dense projection followed by tanh."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" by taking only the hidden state of the first token.
        cls_state = hidden_states[0, :]
        return self.activation(self.dense(cls_state))
95
+
96
+
97
@support_torch_compile
class BertEncoder(nn.Module):
    """Sequential stack of BertLayer blocks (one per hidden layer)."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.layer = nn.ModuleList([
            BertLayer(config=config,
                      cache_config=cache_config,
                      quant_config=quant_config,
                      prefix=f"{prefix}.layer.{layer_idx}")
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Each layer consumes the kv cache at its own index.
        for idx, bert_layer in enumerate(self.layer):
            hidden_states = bert_layer(hidden_states, kv_caches[idx],
                                       attn_metadata)
        return hidden_states
123
+
124
+
125
class BertLayer(nn.Module):
    """One transformer block: attention -> intermediate MLP -> output."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention")

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate")

        self.output = BertOutput(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 layer_norm_eps=config.layer_norm_eps,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ):
        # Residual adds happen inside the sub-blocks (BertSelfOutput /
        # BertOutput), so the flow here is a plain pipeline.
        attn_out = self.attention(hidden_states, kv_cache, attn_metadata)
        mlp_out = self.intermediate(attn_out)
        return self.output(mlp_out, attn_out)
165
+
166
+
167
class BertAttention(nn.Module):
    """BERT attention sublayer.

    Composes the tensor-parallel self-attention module with the
    projection + residual + LayerNorm output block.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # FIX: this prefix previously read f"{prefix}.output" (copy-paste
        # from the output block below). The prefix must mirror the module
        # attribute path ("self") so that quantization configs keyed by
        # parameter name resolve to the correct scope.
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention, then project + add residual + normalize."""
        self_output = self.self(hidden_states, kv_cache, attn_metadata)
        return self.output(self_output, hidden_states)
199
+
200
+
201
class BertSelfAttention(nn.Module):
    """Tensor-parallel multi-head self-attention for one BERT layer.

    Heads are partitioned evenly across the tensor-parallel group; the
    fused QKV projection produces per-rank shards that are split and fed
    to vLLM's encoder-only Attention kernel.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Heads must divide evenly across tensor-parallel ranks.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        # BERT uses standard MHA: as many KV heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the Q and KV slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        # ENCODER_ONLY: bidirectional attention, no causal mask.
        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to fused QKV, split per-rank shards, run attention."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        output = self.attn(q, k, v, kv_cache, attn_metadata)
        return output
256
+
257
+
258
class BertSelfOutput(nn.Module):
    """Attention output projection with residual add and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = RowParallelLinear(input_size=hidden_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        # RowParallelLinear returns (output, bias); bias is unused here.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
278
+
279
+
280
class BertIntermediate(nn.Module):
    """Up-projection of the BERT MLP followed by the activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = ColumnParallelLinear(input_size=hidden_size,
                                          output_size=intermediate_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.dense")
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ColumnParallelLinear returns (output, bias); bias is unused.
        up, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(up)
300
+
301
+
302
class BertOutput(nn.Module):
    """Down-projection of the BERT MLP with residual add and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.dense = RowParallelLinear(input_size=intermediate_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        # RowParallelLinear returns (output, bias); bias is unused here.
        down, _ = self.dense(hidden_states)
        return self.LayerNorm(down + input_tensor)
325
+
326
+
327
class BertModel(nn.Module):
    """BERT backbone: embedding layer followed by the transformer encoder.

    An optional BertPooler (first-token pooling) is attached when
    ``add_pooling_layer`` is True (used by the cross-encoder model).
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 add_pooling_layer: bool = False):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Precomputed embeddings bypass the embedding layer entirely.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            assert hasattr(attn_metadata, "seq_lens_tensor")
            hidden_states = self.embeddings(
                input_ids=input_ids,
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids)
        return self.encoder(hidden_states, kv_caches, attn_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing separate Q/K/V projections into
        the stacked ``qkv_proj`` parameter.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Drop pooler weights when no pooling layer was instantiated.
            if self.pooler is None and "pooler" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked weight: fall back to the default loader.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
398
+
399
+
400
class BertEmbeddingModel(nn.Module):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        pooler_config = vllm_config.model_config.pooler_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self._pooler = self._build_pooler(pooler_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids=input_ids,
                          position_ids=positions,
                          kv_caches=kv_caches,
                          attn_metadata=attn_metadata,
                          intermediate_tensors=intermediate_tensors,
                          inputs_embeds=inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Strip the "model." prefix, then drop LM-head weights: the
        # embedding model has no language-modeling head.
        remapped = self.hf_to_vllm_mapper.apply(weights)
        filtered = ((name, tensor) for name, tensor in remapped
                    if not name.startswith("lm_head."))
        self.model.load_weights(filtered)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        # Plain BertEmbedding; subclasses may swap the embedding class.
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # Default to CLS pooling with L2 normalization for embeddings.
        return Pooler.from_config_with_defaults(pooler_config,
                                                pooling_type=PoolingType.CLS,
                                                normalize=True,
                                                softmax=False)
460
+
461
+
462
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """BERT cross-encoder for sequence classification / reranking.

    Wraps a pooled BertModel with a linear classification head; scores
    are produced through a CrossEncodingPooler that applies the head and
    the model's configured activation function.

    Attributes:
        bert: BertModel (with pooling layer) used for forward operations.
        classifier: linear head mapping the pooled output to num_labels.
        _pooler: CrossEncodingPooler combining classifier and BertPooler.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

        self.num_labels = config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding,
                              add_pooling_layer=True)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._pooler = CrossEncodingPooler(config, self.classifier,
                                           self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route "bert."-prefixed weights to the backbone; load the
        classifier head weights here."""

        self_weights = []

        def weight_filter():
            # Yield backbone weights (prefix stripped); stash the rest.
            for name, weight in weights:
                if name.startswith("bert."):
                    yield (name[len("bert."):], weight)
                else:
                    self_weights.append((name, weight))

        self.bert.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        # Only classifier weights are loaded at this level; anything else
        # left in self_weights is silently ignored.
        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.bert(input_ids=input_ids,
                         position_ids=positions,
                         kv_caches=kv_caches,
                         inputs_embeds=inputs_embeds,
                         intermediate_tensors=intermediate_tensors,
                         attn_metadata=attn_metadata,
                         token_type_ids=token_type_ids)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/blip2.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import cached_property
4
+ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
5
+ TypedDict, Union)
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
10
+ apply_chunking_to_forward)
11
+
12
+ from vllm.attention import AttentionMetadata
13
+ from vllm.config import CacheConfig, VllmConfig
14
+ from vllm.model_executor.layers.activation import get_act_fn
15
+ from vllm.model_executor.layers.quantization import QuantizationConfig
16
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
17
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
18
+ from vllm.multimodal import MULTIMODAL_REGISTRY
19
+ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
20
+ NestedTensors)
21
+ from vllm.multimodal.parse import MultiModalDataItems
22
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
23
+ BaseProcessingInfo, PromptReplacement,
24
+ PromptReplacementDetails)
25
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
26
+ from vllm.sequence import IntermediateTensors
27
+
28
+ from .blip import BlipVisionModel
29
+ from .interfaces import SupportsMultiModal, SupportsPP
30
+ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
31
+ maybe_prefix, merge_multimodal_embeddings)
32
+
33
+ # We use this internally as placeholders since there is no image token
34
+ # defined on the HuggingFace repo
35
+ _IMAGE_TOKEN_ID = 50265
36
+
37
+
38
class Blip2ImagePixelInputs(TypedDict):
    """Raw pixel input for the BLIP-2 vision tower."""
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Discriminated union over the "type" key of the two input variants.
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
+
55
+
56
class Blip2QFormerMultiHeadAttention(nn.Module):
    """Eager multi-head (self- or cross-) attention for the BLIP-2 Q-Former.

    In cross-attention mode, keys/values are projected from the vision
    encoder hidden states (``encoder_hidden_size``); otherwise from the
    Q-Former's own hidden states.

    Note: ``quant_config`` and ``cache_config`` are accepted for interface
    parity with other attention modules but are not used in this body.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # Cross-attention keys/values come from the vision encoder width.
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (..., seq, all_head) -> (batch, heads, seq, head_size)
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        # Cross-attention is selected dynamically by whether encoder
        # hidden states are supplied.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Scaled dot-product attention computed eagerly.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer
140
+
141
+
142
class Blip2QFormerSelfOutput(nn.Module):
    """Attention output projection: dense -> dropout -> residual LayerNorm."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
161
+
162
+
163
class Blip2QFormerAttention(nn.Module):
    """Q-Former attention sublayer: multi-head attention plus output block.

    When ``is_cross_attention`` is set, keys/values are taken from the
    vision encoder hidden states.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # FIX: return annotation was Tuple[torch.Tensor], but a single
        # tensor is returned (the output block folds in the residual).
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output
196
+
197
+
198
class Blip2QFormerIntermediate(nn.Module):
    """Feed-forward up-projection with the configured activation."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
210
+
211
+
212
class Blip2QFormerOutput(nn.Module):
    """Feed-forward down-projection: dense -> dropout -> residual LayerNorm."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
231
+
232
+
233
class Blip2QFormerLayer(nn.Module):
    """One Q-Former block: self-attention, optional cross-attention on the
    query tokens, then (chunked) feed-forward.

    Cross-attention to the vision encoder is instantiated only every
    ``cross_attention_frequency`` layers.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        # Cross-attention only on every cross_attention_frequency-th layer.
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attn_output = self.attention(hidden_states)

        if query_length > 0:
            # First query_length positions are the learned query tokens;
            # only those attend to the vision encoder output.
            query_attention_output = attn_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attn_output.shape[1] > query_length:
                # NOTE(review): this text branch calls feed_forward_chunk,
                # which uses self.intermediate / self.output — attributes
                # not defined in this class. It would raise AttributeError
                # if reached; presumably unreached for image-only inputs.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attn_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attn_output,
            )

        return layer_output

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        # NOTE(review): self.intermediate / self.output are not defined in
        # this class (only *_query variants are) — TODO confirm this path
        # is never exercised.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        # Feed-forward applied to the query-token positions.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
320
+
321
+
322
class Blip2QFormerEncoder(nn.Module):
    """Stack of Q-Former layers applied sequentially."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=idx)
            for idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        """Thread the hidden states through every layer in order."""
        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )
        return hidden_states
359
+
360
+
361
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
    """Q-Former: turns learned query embeddings plus image-encoder features
    into a fixed number of output embeddings for the language model."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        """Normalize the query embeddings and run them through the encoder."""
        num_queries = query_embeds.shape[1]

        embeddings = self.dropout(self.layernorm(query_embeds))

        return self.encoder(
            embeddings,
            encoder_hidden_states=encoder_hidden_states,
            query_length=num_queries,
        )
400
+
401
+
402
class Blip2ProcessingInfo(BaseProcessingInfo):
    """Static multimodal-processing metadata for BLIP-2."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Blip2Config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image per prompt is supported.
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        # Each image is encoded into exactly `num_query_tokens` embeddings.
        return self.get_hf_config().num_query_tokens
420
+
421
+
422
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
    """Builds dummy processor inputs used for memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        vision_config = self.info.get_hf_config().vision_config
        num_images = mm_counts.get("image", 0)

        # Use the model's full image resolution so profiling covers the
        # worst case.
        side = vision_config.image_size
        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text="",
            mm_data={"image": dummy_images},
        )
446
+
447
+
448
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
    """Maps raw (prompt, image) inputs to model inputs for BLIP-2."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            # HF processor always adds placeholders even when there's no image
            # so for text-only prompts we tokenize directly instead.
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both raw pixels and precomputed embeddings are batched per image.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        # One "<image>" placeholder per Q-Former query token, inserted
        # before BOS so the KV cache reserves space for the image embeddings.
        image_token_id = vocab["<image>"]
        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[bos_token_id],
                replacement=PromptReplacementDetails(
                    full=image_tokens + [bos_token_id],
                    features=image_tokens,
                ),
            )
        ]
504
+
505
+
506
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                        info=Blip2ProcessingInfo,
                                        dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """BLIP-2: vision encoder -> Q-Former -> linear projection -> LM."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config, quant_config)

        # Learned query embeddings consumed by the Q-Former (one shared set,
        # expanded per batch element at runtime).
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        # Projects Q-Former outputs into the language model embedding space.
        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the wrapped LM's sampler when it exposes one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that pixel values are (batch, 3, H, W) at the configured
        image size; raise ValueError otherwise."""
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        """Extract either pixel values or precomputed image embeddings from
        the multimodal kwargs; returns None when neither is present."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        """Run validated pixel values through the vision encoder."""
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        """Convert image input into LM-space embeddings: precomputed
        embeddings pass through; pixels go vision -> Q-Former -> projection."""

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        # Splice image embeddings into the text embeddings at the
        # placeholder-token positions.
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run forward pass for BLIP-2.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        Tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since it corresponds to the number of query
        embeddings outputted by the Q-Former and inputted to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bloom.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """Inference-only BLOOM model compatible with HuggingFace weights."""
20
+ import math
21
+ from typing import Iterable, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from transformers import BloomConfig
26
+
27
+ from vllm.attention import Attention, AttentionMetadata
28
+ from vllm.compilation.decorators import support_torch_compile
29
+ from vllm.config import CacheConfig, VllmConfig
30
+ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
31
+ get_tensor_model_parallel_world_size)
32
+ from vllm.model_executor.layers.activation import get_act_fn
33
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
34
+ QKVParallelLinear,
35
+ RowParallelLinear)
36
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
+ from vllm.model_executor.layers.quantization import QuantizationConfig
38
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
40
+ ParallelLMHead, VocabParallelEmbedding)
41
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
43
+ from vllm.sequence import IntermediateTensors
44
+
45
+ from .interfaces import SupportsPP
46
+ from .utils import (is_pp_missing_parameter,
47
+ make_empty_intermediate_tensors_factory, make_layers,
48
+ maybe_prefix)
49
+
50
+
51
+ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
52
+ closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
53
+ base = torch.tensor(
54
+ 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
55
+ dtype=torch.float32,
56
+ )
57
+ powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
58
+ slopes = torch.pow(base, powers)
59
+
60
+ if closest_power_of_2 != total_num_heads:
61
+ extra_base = torch.tensor(
62
+ 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
63
+ dtype=torch.float32,
64
+ )
65
+ num_remaining_heads = min(closest_power_of_2,
66
+ total_num_heads - closest_power_of_2)
67
+ extra_powers = torch.arange(start=1,
68
+ end=1 + 2 * num_remaining_heads,
69
+ step=2,
70
+ dtype=torch.int32)
71
+ slopes = torch.cat(
72
+ [slopes, torch.pow(extra_base, extra_powers)], dim=0)
73
+ return slopes
74
+
75
+
76
class BloomAttention(nn.Module):
    """Tensor-parallel multi-head self-attention with ALiBi biases."""

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Each TP rank only keeps the ALiBi slopes of its own head slice.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(
            self.total_num_heads)[head_start:head_end].tolist()

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.head_dim**-0.5,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Positions are unused: ALiBi encodes position via attention biases.
        del position_ids
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output
138
+
139
+
140
class BloomMLP(nn.Module):
    """Feed-forward block: 4x up-projection, GELU, down-projection."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = 4 * hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up, _ = self.dense_h_to_4h(x)
        down, _ = self.dense_4h_to_h(self.gelu_impl(up))
        return down
166
+
167
+
168
class BloomBlock(nn.Module):
    """One BLOOM transformer layer: LN -> attention -> LN -> MLP, with
    configurable residual-connection placement."""

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        # When set, the residual branch starts *after* the layernorm
        # instead of before it (a BLOOM checkpoint configuration option).
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Pre-attention layernorm and residual selection.
        normed = self.input_layernorm(hidden_states)
        residual = (normed if self.apply_residual_connection_post_layernorm
                    else hidden_states)

        attn_out = self.self_attention(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        ) + residual

        # Pre-MLP layernorm and residual selection.
        normed = self.post_attention_layernorm(attn_out)
        residual = (normed if self.apply_residual_connection_post_layernorm
                    else attn_out)

        return self.mlp(normed) + residual
227
+
228
+
229
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM transformer backbone: embeddings + blocks + final layernorm,
    with pipeline-parallel support."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks (only this PP rank's slice is materialized).
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: BloomBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # BLOOM layernorms the token embeddings before the first block.
        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # First PP rank computes embeddings; later ranks receive the hidden
        # states from the previous rank via intermediate_tensors.
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        # Non-last ranks forward the hidden states to the next PP rank.
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
294
+
295
+
296
class BloomForCausalLM(nn.Module, SupportsPP):
    """Causal-LM head on top of BloomModel (logits + sampling + loading)."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        # Tie the output head to the input embeddings when configured.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF-format weights, re-interleaving BLOOM's fused QKV."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Skipped: presumably derivable from the tied embeddings —
            # TODO(review): confirm for untied checkpoints.
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            # Skip parameters that live on other pipeline-parallel ranks.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/chameleon.py ADDED
@@ -0,0 +1,1161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import cached_property
4
+ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
5
+ Tuple, TypedDict, Union)
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
11
+ ChameleonVQVAEConfig)
12
+
13
+ from vllm.attention import Attention, AttentionMetadata
14
+ from vllm.config import CacheConfig, VllmConfig
15
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
16
+ from vllm.logger import init_logger
17
+ from vllm.model_executor.layers.activation import SiluAndMul
18
+ from vllm.model_executor.layers.layernorm import RMSNorm
19
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
20
+ QKVParallelLinear,
21
+ RowParallelLinear)
22
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
+ from vllm.model_executor.layers.quantization import QuantizationConfig
24
+ from vllm.model_executor.layers.rotary_embedding import get_rope
25
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
27
+ ParallelLMHead, VocabParallelEmbedding)
28
+ from vllm.model_executor.model_loader.weight_utils import (
29
+ default_weight_loader, row_parallel_weight_loader)
30
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
31
+ from vllm.model_executor.utils import set_weight_attrs
32
+ from vllm.multimodal import MULTIMODAL_REGISTRY
33
+ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
34
+ NestedTensors)
35
+ from vllm.multimodal.parse import MultiModalDataItems
36
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
37
+ BaseProcessingInfo, PromptReplacement,
38
+ PromptReplacementDetails)
39
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
40
+ from vllm.sequence import IntermediateTensors
41
+
42
+ from .interfaces import SupportsMultiModal, SupportsPP
43
+ from .utils import (is_pp_missing_parameter,
44
+ make_empty_intermediate_tensors_factory, make_layers,
45
+ maybe_prefix, merge_multimodal_embeddings)
46
+
47
+ logger = init_logger(__name__)
48
+
49
+
50
class ChameleonImagePixelInputs(TypedDict):
    """Validated pixel inputs handed from the input processor to the model."""
    # Discriminator tag for this (currently only) image-input variant.
    type: Literal["pixel_values"]
    # Raw pixel tensor produced by the HF image processor.
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
54
+
55
+
56
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Static processing metadata for Chameleon (HF configs, token budgets)."""

    def get_hf_config(self):
        """Return the HF ``ChameleonConfig`` for this model."""
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self):
        """Return the HF ``ChameleonProcessor`` for this model."""
        return self.ctx.get_hf_processor(ChameleonProcessor)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Chameleon accepts at most one image per prompt."""
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Upper bound on placeholder tokens emitted for one image."""
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        """Number of discrete tokens a single image expands to."""
        hf_processor = self.get_hf_processor()
        return hf_processor.image_seq_length
77
+
78
+
79
class ChameleonDummyInputsBuilder(
        BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds placeholder multimodal inputs used for memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create a dummy prompt plus square dummy images.

        The image edge length comes from the VQ model resolution so the
        dummy data matches what the HF processor expects.
        """
        hf_config = self.info.get_hf_config()
        side = hf_config.vq_config.resolution
        num_images = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data={"image": dummy_images},
        )
103
+
104
+
105
class ChameleonMultiModalProcessor(
        BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Maps raw (prompt, image) data to model-ready multimodal inputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # Fast path: a text-only prompt skips the HF processor and is
        # tokenized directly, with the same token-level post-processing
        # the processor would have applied.
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # HF processor adds sep token for chat mode; mirror that here so
        # tokenized-only prompts match processor output.
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        sep_token_id = vocab[tokenizer.sep_token]  # type: ignore

        return prompt_tokens + [sep_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # One pixel-value tensor per image, batched along the first dim.
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_start_id = vocab[processor.image_start_token]
        image_token_id = vocab[processor.image_token]
        image_end_id = vocab[processor.image_end_token]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        # Expand each "<image>" placeholder into <start><img>*N<end>;
        # only the inner <img> run carries image feature embeddings.
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_start_id] + image_tokens + [image_end_id]),
                    features=image_tokens,
                ),
            )
        ]
171
+
172
+
173
class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm over the last (head_dim) axis with a manual affine step.

    Statistics are computed over only the final dimension, while the
    learned weight/bias keep the full multi-dim shape passed at
    construction and are applied by hand afterwards.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Restrict normalization to the last axis only.
        self.normalized_shape = (hidden_size[-1], )

        # Both parameters use the row-parallel loader — presumably sharded
        # across TP ranks; verify against the checkpoint layout.
        for param in (self.weight, self.bias):
            set_weight_attrs(param,
                             {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        # Parameter-free normalization; the affine transform is applied
        # explicitly so it uses the full-shape weight/bias.
        normed = F.layer_norm(hidden_states,
                              self.normalized_shape,
                              None,
                              None,
                              eps=1e-5)
        return normed * self.weight + self.bias
192
+
193
+
194
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """SwiGLU feed-forward block: fused gate/up projection -> SiLU -> down."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel GEMM.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        # SiluAndMul fuses the activation with the gating multiply, so
        # only SiLU is supported.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
225
+
226
+
227
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Multi-head attention based on Llama's, plus Chameleon's per-head
    q/k layernorm applied before rotary embeddings."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Query heads are always evenly partitioned across TP ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        # Chameleon-specific: per-head layernorm on q and k, with the
        # (num_heads, head_dim) affine parameters kept at per-rank size.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm: expose the per-head structure so the norm
        # runs over head_dim for each head independently.
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        # Flatten heads back into a single hidden dimension.
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # q/k are normalized per head *before* rotary embedding is applied.
        q, k = self._apply_qk_norm(q, k)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
324
+
325
+
326
class ChameleonDecoderLayer(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention / MLP with a
    residual stream threaded through the fused norm ops."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Forward the pre-scaling context length to rope-scaling schemes
        # that need it.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to MHA when the config declares no KV-head count.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # First layer has no carried residual yet; afterwards the fused
        # RMSNorm call adds the residual and renormalizes in one op.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
399
+
400
+
401
class ChameleonSwinDecoderLayer(nn.Module):
    """Post-norm ("swin norm") layer variant, selected when
    ``config.swin_norm`` is set: each sublayer's *output* is normalized
    before the residual add, instead of its input."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Forward the pre-scaling context length to rope-scaling schemes
        # that need it.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to MHA when the config declares no KV-head count.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Post-norm: normalize the attention output, then add residual.
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual
473
+
474
+
475
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer over a learned codebook."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Commitment-loss weight; defaults to 0.25 when the config has none.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        # NCHW -> NHWC, then flatten every spatial position to one row.
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(flat**2, dim=1, keepdim=True) +
            torch.sum(self.embedding.weight**2, dim=1) -
            2 * torch.einsum("bd,dn->bn", flat,
                             self.embedding.weight.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # Codebook + beta-weighted commitment losses (VQ-VAE objective).
        codebook_loss = torch.mean((quantized.detach() - hidden_state)**2)
        commit_loss = torch.mean((quantized - hidden_state.detach())**2)
        loss = codebook_loss + self.beta * commit_loss

        # Straight-through estimator: gradients flow through the identity.
        quantized = hidden_state + (quantized - hidden_state).detach()

        # NHWC -> NCHW to match the input layout.
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, min_encoding_indices
516
+
517
+
518
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """2x spatial downsampling via a stride-2 conv with manual padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # Conv2d only offers symmetric padding, but the original VQGAN
        # pads one pixel on the right/bottom only — so pad by hand.
        padded = F.pad(hidden_states,
                       pad=(0, 1, 0, 1),
                       mode="constant",
                       value=0)
        return self.conv(padded)
537
+
538
+
539
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """VQGAN residual block: two GroupNorm+SiLU+Conv stages plus a
    (possibly projected) skip connection."""

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Default to a channel-preserving block when out_channels is unset.
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # When channel count changes, the skip path needs a projection:
        # either a 3x3 conv or a 1x1 ("network-in-network") conv.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # In-place SiLU/swish: x * sigmoid(x).
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Project the skip path only when the channel count changed.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
606
+
607
+
608
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head spatial self-attention over feature-map positions,
    with a residual connection around the whole block."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        batch_size, channels, height, width = query_states.shape
        hw = height * width

        # (B, HW, C) x (B, C, HW) -> (B, HW, HW) scaled attention logits.
        query_states = query_states.reshape(batch_size, channels,
                                            hw).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, hw)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # Weighted sum of values, folded back to the (B, C, H, W) map.
        value_states = value_states.reshape(batch_size, channels, hw)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        return residual + self.proj_out(attn_output)
666
+
667
+
668
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """VQGAN encoder: conv stem, a downsampling pyramid of resnet (and
    optional attention) blocks, a middle stage, and a conv head emitting
    the latent channels."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Prepend 1 so level i reads base_channels * multiplier[i-1]
        # (level 0 reads the stem's base_channels).
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                # Attention is added per resnet block at the configured
                # resolutions (curr_res is constant within a level, so
                # attn is either empty or as long as block).
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the spatial resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # Middle stage at the coarsest resolution: resnet, attention,
        # resnet (attention skipped for non-"vanilla" attn_type).
        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        # Doubled latent channels when the config requests mean+logvar
        # style latents.
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        # Match the module's parameter dtype before any conv runs.
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1])
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end: GroupNorm, in-place SiLU (x * sigmoid(x)), latent projection
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
773
+
774
+
775
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Frozen VQGAN tokenizer: encodes pixels into discrete codebook ids."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images into (quantized latents, embedding loss, codebook
        indices)."""
        latents = self.quant_conv(self.encoder(pixel_values))
        return self.quantize(latents)
795
+
796
+
797
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse vocabulary: token id -> token string."""
        return {tok_id: name for name, tok_id in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all ``IMGIMG...`` codebook tokens."""
        return sorted(tok_id for name, tok_id in self.vocab_map.items()
                      if name.startswith("IMGIMG"))

    @cached_property
    def bpe2img(self):
        """BPE token id -> VQGAN codebook index."""
        # Token names spell the codebook index in letters (A->0 ... J->9)
        # between the "IMGIMG" prefix and a trailing sentinel character.
        letter_to_digit = {chr(ord("A") + i): str(i) for i in range(10)}

        def decode_name(name: str) -> int:
            body = name[len("IMGIMG"):-1]
            return int("".join(
                letter_to_digit.get(ch, ch) for ch in body))

        return {
            tok: decode_name(self.val2name[tok])
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        """VQGAN codebook index -> BPE token id."""
        return {img: bpe for bpe, img in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        """Sorted key and value tensors for searching the bpe->img map."""
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense lookup table so img->bpe runs as one tensor index op."""
        table = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for img_id, bpe_id in self.img2bpe.items():
            table[img_id] = bpe_id
        return table

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Map a batch of VQGAN indices to BPE ids, preserving device."""
        device = img_batch.device
        # The lookup table lives on CPU; index there, then move back.
        return self.img2bpe_mapping_tensor[img_batch.to("cpu")].to(device)
852
+
853
+
854
class ChameleonModel(nn.Module):
    """Chameleon transformer backbone plus the frozen VQGAN image
    tokenizer, with pipeline-parallel layer slicing."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        # config.swin_norm selects the post-norm layer variant.
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer

        # make_layers assigns only this PP rank's slice of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(config=config,
                                         cache_config=cache_config,
                                         quant_config=quant_config,
                                         prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for the given ids."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
        special tokens.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        # Flatten spatial token grid to one sequence per image.
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # First PP rank embeds tokens (or receives precomputed embeds);
        # later ranks resume from the previous rank's hidden state.
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        # Run only this rank's layer slice; kv_caches is indexed from 0
        # for the local slice.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        # Non-final ranks hand off the running state to the next rank.
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
940
+
941
+
942
+ @MULTIMODAL_REGISTRY.register_processor(
943
+ ChameleonMultiModalProcessor,
944
+ info=ChameleonProcessingInfo,
945
+ dummy_inputs=ChameleonDummyInputsBuilder)
946
+ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
947
+ SupportsPP):
948
+
949
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
950
+ super().__init__()
951
+ config = vllm_config.model_config.hf_config
952
+ multimodal_config = vllm_config.model_config.multimodal_config
953
+ self.config = config
954
+ self.multimodal_config = multimodal_config
955
+ self.model = ChameleonModel(vllm_config=vllm_config,
956
+ prefix=maybe_prefix(prefix, "model"))
957
+ self.unpadded_vocab_size = config.vocab_size
958
+ self.lm_head = ParallelLMHead(
959
+ self.unpadded_vocab_size,
960
+ config.hidden_size,
961
+ )
962
+ if config.tie_word_embeddings:
963
+ self.lm_head.weight = self.model.embed_tokens.weight
964
+
965
+ logit_scale = getattr(config, "logit_scale", 1.0)
966
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
967
+ config.vocab_size, logit_scale)
968
+ self.sampler = get_sampler()
969
+ self.make_empty_intermediate_tensors = (
970
+ self.model.make_empty_intermediate_tensors)
971
+
972
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
973
+ vq_config: ChameleonVQVAEConfig = self.config.vq_config
974
+ expected_dims = (3, vq_config.resolution, vq_config.resolution)
975
+ actual_dims = tuple(data.shape[1:])
976
+
977
+ if actual_dims != expected_dims:
978
+ expected_expr = ("batch_size", *map(str, expected_dims))
979
+ raise ValueError(
980
+ f"The expected shape of pixel values is {expected_expr}. "
981
+ f"You supplied {tuple(data.shape)}.")
982
+
983
+ return data
984
+
985
+ def _parse_and_validate_image_input(
986
+ self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
987
+ pixel_values = kwargs.pop("pixel_values", None)
988
+
989
+ if pixel_values is None:
990
+ return None
991
+
992
+ if not isinstance(pixel_values, torch.Tensor):
993
+ raise ValueError("Incorrect type of pixel values. "
994
+ f"Got type: {type(pixel_values)}")
995
+
996
+ # Remove the N dimension until multiple images are supported.
997
+ pixel_values = pixel_values.squeeze(1)
998
+
999
+ return ChameleonImagePixelInputs(
1000
+ type="pixel_values",
1001
+ data=self._validate_pixel_values(pixel_values),
1002
+ )
1003
+
1004
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
1005
+ image_input = self._parse_and_validate_image_input(**kwargs)
1006
+ if image_input is None:
1007
+ return None
1008
+ assert self.model.vqmodel is not None
1009
+ image_tokens = self.model.get_image_tokens(image_input["data"].to(
1010
+ self.config.torch_dtype))
1011
+ vision_embeddings = self.model.get_input_embeddings(image_tokens)
1012
+ return vision_embeddings
1013
+
1014
+ def get_input_embeddings(
1015
+ self,
1016
+ input_ids: torch.Tensor,
1017
+ multimodal_embeddings: Optional[NestedTensors] = None,
1018
+ ) -> torch.Tensor:
1019
+
1020
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
1021
+ if multimodal_embeddings is not None:
1022
+ inputs_embeds = merge_multimodal_embeddings(
1023
+ input_ids, inputs_embeds, multimodal_embeddings,
1024
+ self.model.vocabulary_mapping.image_token_id)
1025
+ return inputs_embeds
1026
+
1027
+ def forward(
1028
+ self,
1029
+ input_ids: torch.Tensor,
1030
+ positions: torch.Tensor,
1031
+ kv_caches: List[torch.Tensor],
1032
+ attn_metadata: AttentionMetadata,
1033
+ intermediate_tensors: Optional[IntermediateTensors] = None,
1034
+ inputs_embeds: Optional[torch.Tensor] = None,
1035
+ **kwargs,
1036
+ ) -> Union[torch.Tensor, IntermediateTensors]:
1037
+
1038
+ if intermediate_tensors is not None:
1039
+ inputs_embeds = None
1040
+
1041
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
1042
+ # condition is for v0 compatibility.
1043
+ elif inputs_embeds is None:
1044
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1045
+ inputs_embeds = self.get_input_embeddings(input_ids,
1046
+ vision_embeddings)
1047
+ input_ids = None
1048
+
1049
+ hidden_states = self.model(input_ids,
1050
+ positions,
1051
+ kv_caches,
1052
+ attn_metadata,
1053
+ intermediate_tensors,
1054
+ inputs_embeds=inputs_embeds)
1055
+ return hidden_states
1056
+
1057
+ def compute_logits(
1058
+ self,
1059
+ hidden_states: torch.Tensor,
1060
+ sampling_metadata: SamplingMetadata,
1061
+ ) -> Optional[torch.Tensor]:
1062
+ logits = self.logits_processor(self.lm_head, hidden_states,
1063
+ sampling_metadata)
1064
+
1065
+ # Disallow image tokens which does not include special
1066
+ # begin-image and end-image tokens
1067
+ if logits is not None:
1068
+ image_tokens = self.model.vocabulary_mapping.image_tokens
1069
+ logits[:, image_tokens] = torch.finfo(logits.dtype).min
1070
+
1071
+ return logits
1072
+
1073
+ def sample(
1074
+ self,
1075
+ logits: torch.Tensor,
1076
+ sampling_metadata: SamplingMetadata,
1077
+ ) -> Optional[SamplerOutput]:
1078
+ next_tokens = self.sampler(logits, sampling_metadata)
1079
+ return next_tokens
1080
+
1081
+ def load_weights(self, weights: Iterable[Tuple[str,
1082
+ torch.Tensor]]) -> Set[str]:
1083
+ stacked_params_mapping = [
1084
+ # (param_name, shard_name, shard_id)
1085
+ (".qkv_proj", ".q_proj", "q"),
1086
+ (".qkv_proj", ".k_proj", "k"),
1087
+ (".qkv_proj", ".v_proj", "v"),
1088
+ (".gate_up_proj", ".gate_proj", 0),
1089
+ (".gate_up_proj", ".up_proj", 1),
1090
+ ]
1091
+ params_dict = dict(self.named_parameters())
1092
+ loaded_params: Set[str] = set()
1093
+ for name, loaded_weight in weights:
1094
+ if "rotary_emb.inv_freq" in name:
1095
+ continue
1096
+
1097
+ if ("rotary_emb.cos_cached" in name
1098
+ or "rotary_emb.sin_cached" in name):
1099
+ # Models trained using ColossalAI may include these tensors in
1100
+ # the checkpoint. Skip them.
1101
+ continue
1102
+
1103
+ # With tie_word_embeddings, we can skip lm_head.weight
1104
+ # The weight might appear unnecessarily in the files if the model is
1105
+ # processed with quantization, LoRA, fine-tuning, etc.
1106
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
1107
+ continue
1108
+
1109
+ use_default_weight_loading = False
1110
+ if "vqmodel" in name:
1111
+ if self.model.vqmodel is not None:
1112
+ # We only do sharding for language model and
1113
+ # not vqvae for now.
1114
+ use_default_weight_loading = True
1115
+ else:
1116
+ for (param_name, weight_name,
1117
+ shard_id) in stacked_params_mapping:
1118
+ if weight_name not in name:
1119
+ continue
1120
+ name = name.replace(weight_name, param_name)
1121
+ # Skip loading extra bias for GPTQ models.
1122
+ if name.endswith(".bias") and name not in params_dict:
1123
+ continue
1124
+ if is_pp_missing_parameter(name, self):
1125
+ continue
1126
+ param = params_dict[name]
1127
+ weight_loader = param.weight_loader
1128
+ weight_loader(param, loaded_weight, shard_id)
1129
+ break
1130
+ else:
1131
+ # Skip loading extra bias for GPTQ models.
1132
+ if name.endswith(".bias") and name not in params_dict:
1133
+ continue
1134
+ # Remapping the name of FP8 kv-scale.
1135
+ if name.endswith("kv_scale"):
1136
+ remapped_kv_scale_name = name.replace(
1137
+ ".kv_scale", ".attn.kv_scale")
1138
+ if remapped_kv_scale_name not in params_dict:
1139
+ logger.warning_once(
1140
+ "Found kv scale in the checkpoint (e.g. "
1141
+ f"{name}), but not found the expected name in "
1142
+ f"the model (e.g. {remapped_kv_scale_name}). "
1143
+ "kv-scale is not loaded.")
1144
+ continue
1145
+ else:
1146
+ name = remapped_kv_scale_name
1147
+ if is_pp_missing_parameter(name, self):
1148
+ continue
1149
+ param = params_dict[name]
1150
+ weight_loader = getattr(param, "weight_loader",
1151
+ default_weight_loader)
1152
+ weight_loader(param, loaded_weight)
1153
+ if use_default_weight_loading and name in params_dict:
1154
+ if is_pp_missing_parameter(name, self):
1155
+ continue
1156
+ param = params_dict[name]
1157
+ weight_loader = getattr(param, "weight_loader",
1158
+ default_weight_loader)
1159
+ weight_loader(param, loaded_weight)
1160
+ loaded_params.add(name)
1161
+ return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/chatglm.py ADDED
@@ -0,0 +1,801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/THUDM/CogAgent
5
+ """Inference-only CogAgent model compatible with THUDM weights."""
6
+ from argparse import Namespace
7
+ from array import array
8
+ from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
9
+ TypedDict)
10
+
11
+ import torch
12
+ from PIL import Image
13
+ from torch import nn
14
+ from torch.nn import LayerNorm
15
+
16
+ from vllm.attention import Attention, AttentionMetadata
17
+ from vllm.config import CacheConfig, VllmConfig
18
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
19
+ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
20
+ InputContext, token_inputs)
21
+ from vllm.logger import init_logger
22
+ from vllm.model_executor.layers.activation import SiluAndMul
23
+ from vllm.model_executor.layers.layernorm import RMSNorm
24
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
25
+ QKVParallelLinear,
26
+ RowParallelLinear)
27
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
28
+ from vllm.model_executor.layers.quantization import QuantizationConfig
29
+ from vllm.model_executor.layers.rotary_embedding import get_rope
30
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
31
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
32
+ ParallelLMHead, VocabParallelEmbedding)
33
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
+ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
35
+ from vllm.model_executor.models.module_mapping import MultiModelKeys
36
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
37
+ from vllm.multimodal import MULTIMODAL_REGISTRY
38
+ from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
39
+ NestedTensors)
40
+ from vllm.multimodal.utils import cached_get_tokenizer
41
+ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
42
+ SequenceData)
43
+ from vllm.transformers_utils.configs import ChatGLMConfig
44
+
45
+ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
46
+ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
47
+ make_empty_intermediate_tensors_factory, make_layers,
48
+ maybe_prefix)
49
+
50
+ logger = init_logger(__name__)
51
+
52
+
53
+ def calculate_image_placeholder(vision_config):
54
+ return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
55
+
56
+
57
+ def mm_input_mapper_for_glmv(
58
+ ctx: InputContext,
59
+ data: ModalityData[object],
60
+ ) -> Dict:
61
+ model_config = ctx.model_config
62
+ tokenizer = cached_get_tokenizer(
63
+ model_config.tokenizer,
64
+ trust_remote_code=model_config.trust_remote_code)
65
+ if tokenizer is None:
66
+ raise RuntimeError("No HuggingFace processor is available "
67
+ "to process the image object")
68
+ try:
69
+ raw_batch_data = tokenizer.apply_chat_template(
70
+ conversation=[{
71
+ "role": "user",
72
+ "image": data
73
+ }],
74
+ add_generation_prompt=True,
75
+ tokenize=True,
76
+ return_tensors="pt",
77
+ return_dict=True).data
78
+ except Exception:
79
+ logger.error("Failed to process image (%s)", data)
80
+ raise
81
+ pixel_values = raw_batch_data['images']
82
+
83
+ return MultiModalKwargs({'pixel_values': pixel_values})
84
+
85
+
86
+ def merge_glm_vision_embeddings(
87
+ input_ids: torch.Tensor,
88
+ inputs_embeds: torch.Tensor,
89
+ vision_embeddings: torch.Tensor,
90
+ boi_token_id: int,
91
+ eoi_token_id: int,
92
+ ) -> torch.Tensor:
93
+
94
+ boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
95
+ eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
96
+
97
+ mask = torch.zeros_like(input_ids, dtype=torch.bool)
98
+
99
+ for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
100
+ assert boi_pos < eoi_pos
101
+ mask[boi_pos:eoi_pos + 1] = True
102
+ inputs_embeds[mask] = vision_embeddings.view(-1,
103
+ vision_embeddings.shape[-1])
104
+ return inputs_embeds
105
+
106
+
107
+ class GLMImagePixelInputs(TypedDict):
108
+ pixel_values: torch.Tensor
109
+ """Shape: `(batch_size, num_channels, height, width)`"""
110
+
111
+
112
+ def get_max_glmv_image_tokens(ctx: InputContext):
113
+ hf_config = ctx.get_hf_config(ChatGLMConfig)
114
+
115
+ vision_config = getattr(hf_config, 'vision_config', None)
116
+ if vision_config is None:
117
+ return 1
118
+ elif isinstance(vision_config, dict):
119
+ return calculate_image_placeholder(vision_config)
120
+
121
+ msg = f"Unsupported vision config: {type(vision_config)}"
122
+ raise NotImplementedError(msg)
123
+
124
+
125
+ def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
126
+ mm_counts: Mapping[str, int]) -> DummyData:
127
+ hf_config = ctx.get_hf_config(ChatGLMConfig)
128
+ vision_config = getattr(hf_config, 'vision_config', None)
129
+
130
+ if vision_config is None:
131
+ token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
132
+ seq_data = SequenceData(token_ids)
133
+ return DummyData(seq_data, None)
134
+ elif isinstance(vision_config, dict):
135
+ image_size = vision_config["image_size"]
136
+ image_placeholder_length = calculate_image_placeholder(vision_config)
137
+ token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
138
+ [0] * image_placeholder_length +
139
+ [hf_config.eoi_token_id])
140
+ token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
141
+ [0] * (seq_len - image_placeholder_length - 2))
142
+ seq_data = SequenceData(token_ids)
143
+
144
+ mm_data = {
145
+ "image": Image.new("RGB", (image_size, image_size), color=0)
146
+ }
147
+
148
+ return DummyData(seq_data, mm_data)
149
+
150
+ msg = f"Unsupported vision config: {type(vision_config)}"
151
+ raise NotImplementedError(msg)
152
+
153
+
154
+ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
155
+ return [index for index, value in enumerate(input_ids) if value == target]
156
+
157
+
158
+ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
159
+ multi_modal_data = inputs.get("multi_modal_data")
160
+ if multi_modal_data is None or "image" not in multi_modal_data:
161
+ return inputs
162
+
163
+ hf_config = ctx.get_hf_config(ChatGLMConfig)
164
+ vision_config = getattr(hf_config, 'vision_config', None)
165
+
166
+ if vision_config is None:
167
+ return inputs
168
+ elif isinstance(vision_config, dict):
169
+ image_placeholder_length = calculate_image_placeholder(vision_config)
170
+ else:
171
+ msg = f"Unsupported vision config: {type(vision_config)}"
172
+ raise NotImplementedError(msg)
173
+
174
+ input_ids = inputs["prompt_token_ids"]
175
+
176
+ tokenizer = cached_get_tokenizer(
177
+ ctx.model_config.model,
178
+ trust_remote_code=ctx.model_config.trust_remote_code)
179
+
180
+ try:
181
+ raw_batch_data = tokenizer.apply_chat_template(
182
+ conversation=[{
183
+ "role": "user",
184
+ "image": multi_modal_data["image"],
185
+ "content": inputs['prompt'],
186
+ }],
187
+ add_generation_prompt=True,
188
+ tokenize=True,
189
+ return_tensors="pt",
190
+ return_dict=True,
191
+ ).data
192
+ except Exception:
193
+ logger.error("Failed to process content (%s)", inputs['prompt'])
194
+ raise
195
+ input_ids = raw_batch_data['input_ids'][0].tolist()
196
+
197
+ boi_token_id = hf_config.boi_token_id
198
+ eoi_token_id = hf_config.eoi_token_id
199
+ boi_positions = find_all_positions(input_ids, boi_token_id)
200
+ eoi_positions = find_all_positions(input_ids, eoi_token_id)
201
+
202
+ assert len(boi_positions) == len(eoi_positions)
203
+
204
+ new_input_ids = []
205
+ final_processed_position = 0
206
+
207
+ for boi_position, eoi_position in zip(boi_positions, eoi_positions):
208
+ assert boi_position < eoi_position
209
+ new_input_ids.extend(input_ids[final_processed_position:boi_position +
210
+ 1])
211
+ new_input_ids.extend([input_ids[boi_position + 1]] *
212
+ image_placeholder_length)
213
+ final_processed_position = eoi_position
214
+
215
+ new_input_ids.extend(input_ids[final_processed_position:])
216
+
217
+ prompt = inputs.get("prompt")
218
+ if prompt is None:
219
+ prompt = tokenizer.decode(new_input_ids)
220
+
221
+ return token_inputs(
222
+ prompt_token_ids=new_input_ids,
223
+ prompt=prompt,
224
+ multi_modal_data=multi_modal_data,
225
+ )
226
+
227
+
228
+ class GLMAttention(nn.Module):
229
+
230
+ def __init__(
231
+ self,
232
+ config: ChatGLMConfig,
233
+ cache_config: Optional[CacheConfig] = None,
234
+ quant_config: Optional[QuantizationConfig] = None,
235
+ prefix: str = "",
236
+ ):
237
+ super().__init__()
238
+ self.hidden_size = config.hidden_size
239
+ tp_size = get_tensor_model_parallel_world_size()
240
+ self.total_num_heads = config.num_attention_heads
241
+ assert self.total_num_heads % tp_size == 0
242
+ self.num_heads = self.total_num_heads // tp_size
243
+ self.multi_query_attention = config.multi_query_attention
244
+ self.total_num_kv_heads = (config.multi_query_group_num
245
+ if config.multi_query_attention else
246
+ config.num_attention_heads)
247
+ if self.total_num_kv_heads >= tp_size:
248
+ # Number of KV heads is greater than TP size, so we partition
249
+ # the KV heads across multiple tensor parallel GPUs.
250
+ assert self.total_num_kv_heads % tp_size == 0
251
+ else:
252
+ # Number of KV heads is less than TP size, so we replicate
253
+ # the KV heads across multiple tensor parallel GPUs.
254
+ assert tp_size % self.total_num_kv_heads == 0
255
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
256
+ self.head_dim = config.hidden_size // self.total_num_heads
257
+ self.q_size = self.num_heads * self.head_dim
258
+ self.kv_size = self.num_kv_heads * self.head_dim
259
+ self.scaling = self.head_dim**-0.5
260
+
261
+ self.query_key_value = QKVParallelLinear(
262
+ self.hidden_size,
263
+ self.head_dim,
264
+ self.total_num_heads,
265
+ self.total_num_kv_heads,
266
+ bias=config.add_bias_linear or config.add_qkv_bias,
267
+ quant_config=quant_config,
268
+ prefix=f"{prefix}.query_key_value",
269
+ )
270
+ self.dense = RowParallelLinear(
271
+ self.total_num_heads * self.head_dim,
272
+ config.hidden_size,
273
+ bias=config.add_bias_linear,
274
+ quant_config=quant_config,
275
+ prefix=f"{prefix}.dense",
276
+ )
277
+
278
+ # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
279
+ rope_ratio = getattr(config, "rope_ratio", 1.0)
280
+ max_positions = getattr(config, "seq_length", 8192)
281
+ # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
282
+ # which is equivalent to is_neox_style=True
283
+ is_neox_style = not config.original_rope
284
+ self.rotary_emb = get_rope(
285
+ self.head_dim,
286
+ rotary_dim=self.head_dim // 2,
287
+ max_position=max_positions,
288
+ base=10000 * rope_ratio,
289
+ is_neox_style=is_neox_style,
290
+ )
291
+ self.attn = Attention(self.num_heads,
292
+ self.head_dim,
293
+ self.scaling,
294
+ num_kv_heads=self.num_kv_heads,
295
+ cache_config=cache_config,
296
+ quant_config=quant_config,
297
+ prefix=f"{prefix}.attn")
298
+
299
+ def forward(
300
+ self,
301
+ hidden_states: torch.Tensor,
302
+ position_ids: torch.Tensor,
303
+ kv_cache: torch.Tensor,
304
+ attn_metadata: AttentionMetadata,
305
+ ) -> torch.Tensor:
306
+ qkv, _ = self.query_key_value(hidden_states)
307
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
308
+ q, k = self.rotary_emb(position_ids, q, k)
309
+ context_layer = self.attn(
310
+ q,
311
+ k,
312
+ v,
313
+ kv_cache,
314
+ attn_metadata,
315
+ )
316
+ attn_output, _ = self.dense(context_layer)
317
+ return attn_output
318
+
319
+
320
+ class GLMMLP(nn.Module):
321
+ """MLP.
322
+
323
+ MLP will take the input with h hidden state, project it to 4*h
324
+ hidden dimension, perform nonlinear transformation, and project the
325
+ state back into h hidden dimension.
326
+ """
327
+
328
+ def __init__(
329
+ self,
330
+ config: ChatGLMConfig,
331
+ quant_config: Optional[QuantizationConfig] = None,
332
+ prefix: str = "",
333
+ ):
334
+ super().__init__()
335
+
336
+ self.add_bias = config.add_bias_linear
337
+
338
+ # Project to 4h.
339
+ self.dense_h_to_4h = MergedColumnParallelLinear(
340
+ config.hidden_size,
341
+ [config.ffn_hidden_size] * 2,
342
+ bias=config.add_bias_linear,
343
+ quant_config=quant_config,
344
+ prefix=f"{prefix}.dense_h_to_4h",
345
+ )
346
+
347
+ self.activation_func = SiluAndMul()
348
+
349
+ # Project back to h.
350
+ self.dense_4h_to_h = RowParallelLinear(
351
+ config.ffn_hidden_size,
352
+ config.hidden_size,
353
+ bias=config.add_bias_linear,
354
+ quant_config=quant_config,
355
+ prefix=f"{prefix}.dense_4h_to_h",
356
+ )
357
+
358
+ def forward(self, hidden_states):
359
+ # [s, b, 4hp]
360
+ intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
361
+ intermediate_parallel = self.activation_func(intermediate_parallel)
362
+ # [s, b, h]
363
+ output, _ = self.dense_4h_to_h(intermediate_parallel)
364
+ return output
365
+
366
+
367
+ class GLMBlock(nn.Module):
368
+ """A single transformer layer.
369
+
370
+ Transformer layer takes input with size [s, b, h] and returns an
371
+ output of the same size.
372
+ """
373
+
374
+ def __init__(
375
+ self,
376
+ config: ChatGLMConfig,
377
+ cache_config: Optional[CacheConfig] = None,
378
+ quant_config: Optional[QuantizationConfig] = None,
379
+ prefix: str = "",
380
+ ):
381
+ super().__init__()
382
+ self.apply_residual_connection_post_layernorm = (
383
+ config.apply_residual_connection_post_layernorm)
384
+
385
+ self.fp32_residual_connection = config.fp32_residual_connection
386
+
387
+ layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
388
+ # Layernorm on the input data.
389
+ self.input_layernorm = layer_norm_func(config.hidden_size,
390
+ eps=config.layernorm_epsilon)
391
+
392
+ # Self attention.
393
+ self.self_attention = GLMAttention(config,
394
+ cache_config,
395
+ quant_config,
396
+ prefix=f"{prefix}.self_attention")
397
+ self.hidden_dropout = config.hidden_dropout
398
+
399
+ # Layernorm on the attention output
400
+ self.post_attention_layernorm = layer_norm_func(
401
+ config.hidden_size, eps=config.layernorm_epsilon)
402
+
403
+ # MLP
404
+ self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states: torch.Tensor,
409
+ position_ids: torch.Tensor,
410
+ kv_cache: torch.Tensor,
411
+ attn_metadata: AttentionMetadata,
412
+ ) -> torch.Tensor:
413
+ # hidden_states: [num_tokens, h]
414
+ # Layer norm at the beginning of the transformer layer.
415
+ layernorm_output = self.input_layernorm(hidden_states)
416
+ # Self attention.
417
+ attention_output = self.self_attention(
418
+ hidden_states=layernorm_output,
419
+ position_ids=position_ids,
420
+ kv_cache=kv_cache,
421
+ attn_metadata=attn_metadata,
422
+ )
423
+
424
+ # Residual connection.
425
+ if self.apply_residual_connection_post_layernorm:
426
+ residual = layernorm_output
427
+ else:
428
+ residual = hidden_states
429
+
430
+ layernorm_input = residual + attention_output
431
+
432
+ # Layer norm post the self attention.
433
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
434
+
435
+ # Second residual connection.
436
+ if self.apply_residual_connection_post_layernorm:
437
+ residual = layernorm_output
438
+ else:
439
+ residual = layernorm_input
440
+
441
+ output = self.mlp(layernorm_output) + residual
442
+
443
+ return output
444
+
445
+
446
+ class GLMTransformer(nn.Module):
447
+ """Transformer class."""
448
+
449
+ def __init__(
450
+ self,
451
+ config: ChatGLMConfig,
452
+ cache_config: Optional[CacheConfig] = None,
453
+ quant_config: Optional[QuantizationConfig] = None,
454
+ prefix: str = "",
455
+ ):
456
+ super().__init__()
457
+ self.post_layer_norm = config.post_layer_norm
458
+
459
+ # Number of layers.
460
+ self.num_layers = config.num_layers
461
+
462
+ # Transformer layers.
463
+ self.start_layer, self.end_layer, self.layers = make_layers(
464
+ self.num_layers,
465
+ lambda prefix: GLMBlock(
466
+ config, cache_config, quant_config, prefix=prefix),
467
+ prefix=f"{prefix}.layers",
468
+ )
469
+
470
+ if self.post_layer_norm:
471
+ layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
472
+ # Final layer norm before output.
473
+ self.final_layernorm = layer_norm_func(
474
+ config.hidden_size, eps=config.layernorm_epsilon)
475
+
476
+ self.make_empty_intermediate_tensors = (
477
+ make_empty_intermediate_tensors_factory(["hidden_states"],
478
+ config.hidden_size))
479
+
480
+ def forward(
481
+ self,
482
+ hidden_states: torch.Tensor,
483
+ position_ids: torch.Tensor,
484
+ kv_caches: List[torch.Tensor],
485
+ attn_metadata: AttentionMetadata,
486
+ ) -> torch.Tensor:
487
+ for i in range(self.start_layer, self.end_layer):
488
+ layer = self.layers[i]
489
+ hidden_states = layer(
490
+ hidden_states=hidden_states,
491
+ position_ids=position_ids,
492
+ kv_cache=kv_caches[i - self.start_layer],
493
+ attn_metadata=attn_metadata,
494
+ )
495
+ # Final layer norm.
496
+ if get_pp_group().is_last_rank and self.post_layer_norm:
497
+ hidden_states = self.final_layernorm(hidden_states)
498
+
499
+ return hidden_states
500
+
501
+
502
+ class ChatGLMModel(nn.Module):
503
+
504
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
505
+ super().__init__()
506
+
507
+ config = vllm_config.model_config.hf_config
508
+ cache_config = vllm_config.cache_config
509
+ quant_config = vllm_config.quant_config
510
+
511
+ self.config = config
512
+
513
+ self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
514
+ config.hidden_size,
515
+ quant_config=quant_config,
516
+ prefix=f"{prefix}.embedding")
517
+
518
+ self.num_layers = config.num_layers
519
+ self.multi_query_group_num = config.multi_query_group_num
520
+ self.kv_channels = config.kv_channels
521
+ self.encoder = GLMTransformer(config,
522
+ cache_config,
523
+ quant_config,
524
+ prefix=f"{prefix}.encoder")
525
+
526
+ self.output_layer = ParallelLMHead(config.padded_vocab_size,
527
+ config.hidden_size,
528
+ quant_config=quant_config,
529
+ prefix=f"{prefix}.output_layer")
530
+
531
+ vision_config_flag = getattr(config, 'vision_config', None)
532
+ if vision_config_flag is not None:
533
+ self.vision_config = Namespace(**config.vision_config)
534
+ self.vision = EVA2CLIPModel(self.config,
535
+ quant_config,
536
+ prefix=f"{prefix}.vision")
537
+ else:
538
+ self.vision = None
539
+
540
+ self.make_empty_intermediate_tensors = (
541
+ self.encoder.make_empty_intermediate_tensors)
542
+
543
+ def _parse_and_validate_image_input(
544
+ self, **kwargs: object) -> GLMImagePixelInputs:
545
+
546
+ pixel_values = kwargs.pop("pixel_values", None)
547
+ if pixel_values is not None and self.vision is not None:
548
+ if isinstance(pixel_values, torch.Tensor):
549
+ if pixel_values.ndim > 2:
550
+ pixel_values = torch.concat(list(pixel_values))
551
+ elif isinstance(pixel_values, list):
552
+ return torch.concat(pixel_values)
553
+ else:
554
+ raise TypeError("""pixel_values must be a torch.Tensor
555
+ or a list of torch.Tensor
556
+ """)
557
+ return GLMImagePixelInputs(pixel_values=pixel_values)
558
+
559
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
560
+ image_input = self._parse_and_validate_image_input(**kwargs)
561
+ if image_input["pixel_values"] is None:
562
+ return None
563
+ pixel_values = image_input["pixel_values"].to(
564
+ dtype=self.config.torch_dtype)
565
+ vision_embeddings = self.vision(pixel_values)
566
+ return vision_embeddings
567
+
568
+ def get_input_embeddings(
569
+ self,
570
+ input_ids: torch.Tensor,
571
+ multimodal_embeddings: Optional[NestedTensors] = None,
572
+ ) -> torch.Tensor:
573
+ inputs_embeds = self.embedding(input_ids)
574
+ if multimodal_embeddings is not None:
575
+ inputs_embeds = merge_glm_vision_embeddings(
576
+ input_ids=input_ids,
577
+ inputs_embeds=inputs_embeds,
578
+ vision_embeddings=multimodal_embeddings,
579
+ boi_token_id=self.config.boi_token_id,
580
+ eoi_token_id=self.config.eoi_token_id)
581
+ return inputs_embeds
582
+
583
+ def forward(
584
+ self,
585
+ input_ids: torch.Tensor,
586
+ positions: torch.Tensor,
587
+ kv_caches: List[torch.Tensor],
588
+ attn_metadata: AttentionMetadata,
589
+ intermediate_tensors: Optional[IntermediateTensors] = None,
590
+ inputs_embeds: Optional[torch.Tensor] = None,
591
+ **kwargs: object,
592
+ ) -> torch.Tensor:
593
+
594
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
595
+ # condition is for v0 compatibility.
596
+ if intermediate_tensors is None and inputs_embeds is None:
597
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
598
+ inputs_embeds = self.get_input_embeddings(input_ids,
599
+ vision_embeddings)
600
+ input_ids = None
601
+ else:
602
+ inputs_embeds = intermediate_tensors["hidden_states"]
603
+
604
+ # Run encoder.
605
+ hidden_states = self.encoder(
606
+ hidden_states=inputs_embeds,
607
+ position_ids=positions,
608
+ kv_caches=kv_caches,
609
+ attn_metadata=attn_metadata,
610
+ )
611
+
612
+ if not get_pp_group().is_last_rank:
613
+ return IntermediateTensors({"hidden_states": hidden_states})
614
+ return hidden_states
615
+
616
+ def load_weights(self, weights: Iterable[Tuple[str,
617
+ torch.Tensor]]) -> Set[str]:
618
+ stacked_params_mapping = [
619
+ # (param_name, shard_name, shard_id)
620
+ ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
621
+ ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
622
+ ]
623
+ params_dict = dict(self.named_parameters())
624
+ loaded_params: Set[str] = set()
625
+
626
+ for name, loaded_weight in weights:
627
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
628
+ if weight_name not in name:
629
+ continue
630
+ name = name.replace(weight_name, param_name)
631
+ # Skip loading extra bias for GPTQ models.
632
+ if name.endswith(".bias") and name not in params_dict:
633
+ continue
634
+ if is_pp_missing_parameter(name, self):
635
+ continue
636
+ param = params_dict[name]
637
+ weight_loader = param.weight_loader
638
+ weight_loader(param, loaded_weight, shard_id)
639
+ break
640
+ else:
641
+ if "rotary_pos_emb.inv_freq" in name:
642
+ continue
643
+ if name.endswith(".bias") and name not in params_dict:
644
+ continue
645
+ if is_pp_missing_parameter(name, self):
646
+ continue
647
+ param = params_dict[name]
648
+ weight_loader = getattr(param, "weight_loader",
649
+ default_weight_loader)
650
+ weight_loader(param, loaded_weight)
651
+ loaded_params.add(name)
652
+ return loaded_params
653
+
654
+
655
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared trunk for the ChatGLM causal-LM variants.

    Builds the ``ChatGLMModel`` transformer, optionally ties the output
    projection to the input embedding, and exposes the standard vLLM
    forward / compute_logits / sample / load_weights entry points.
    """

    # HF checkpoints nest the embedding under ".word_embeddings"; strip
    # that substring so checkpoint keys line up with this module tree.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        # ChatGLM configs name this "max_sequence_length"; fall back to
        # 8192 when the attribute is absent.
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(
                                            prefix, "transformer"))
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = get_sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Run the transformer stack and return the final hidden states."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the computed logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping HF names via the mapper."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
715
+
716
+
717
class ChatGLM(ChatGLMBaseModel):
    """Text-only ChatGLM variant; only carries LoRA configuration data."""

    # These modules are already stored fused in the checkpoint, so each
    # packed name maps to itself.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]

    embedding_modules = {}
    embedding_padding_modules = []
732
+
733
+
734
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
    """Multimodal (vision-language) ChatGLM variant."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        # The vision MLP fuses gate and up projections into one matmul.
        "merged_proj": ["gate_proj", "dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models.

        Splits the model into language-model, connector, and vision-tower
        prefixes for multimodal-aware components (e.g. LoRA targeting).
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer")
765
+
766
+
767
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """Dispatching facade: instantiating this class actually returns either
    the text-only ``ChatGLM`` or the multimodal ``ChatGLMV`` model, chosen
    from the HF config in ``__new__``.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        config = vllm_config.model_config.hf_config

        # Initialize VL: a vision_config marks the multimodal checkpoint.
        if hasattr(config, "vision_config"):  # noqa: SIM108
            instance_cls = ChatGLMV
        # Initialize LLM
        else:
            instance_cls = ChatGLM

        # quant_config references base class members,
        # so update values before init is called.
        # NOTE: __new__ runs on every instantiation; guard the list
        # updates so repeated instantiations do not keep appending
        # duplicate module names to the shared class-level lists.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        for module in instance_cls.supported_lora_modules:
            if module not in cls.supported_lora_modules:
                cls.supported_lora_modules.append(module)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        for module in instance_cls.embedding_padding_modules:
            if module not in cls.embedding_padding_modules:
                cls.embedding_padding_modules.append(module)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9
+ # and OPT implementations in this library. It has been modified from its
10
+ # original forms to accommodate minor architectural differences compared
11
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ """Inference-only Deepseek model."""
25
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
26
+
27
+ import torch
28
+ from torch import nn
29
+ from transformers import PretrainedConfig
30
+
31
+ from vllm.attention import Attention, AttentionMetadata
32
+ from vllm.config import CacheConfig, VllmConfig
33
+ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
34
+ get_tensor_model_parallel_world_size,
35
+ tensor_model_parallel_all_reduce)
36
+ from vllm.model_executor.layers.activation import SiluAndMul
37
+ from vllm.model_executor.layers.fused_moe import fused_moe
38
+ from vllm.model_executor.layers.layernorm import RMSNorm
39
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
+ QKVParallelLinear,
41
+ ReplicatedLinear,
42
+ RowParallelLinear)
43
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
+ from vllm.model_executor.layers.quantization import QuantizationConfig
45
+ from vllm.model_executor.layers.rotary_embedding import get_rope
46
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
48
+ ParallelLMHead, VocabParallelEmbedding)
49
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
50
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
51
+ from vllm.sequence import IntermediateTensors
52
+
53
+ from .interfaces import SupportsPP
54
+ from .utils import (extract_layer_index, is_pp_missing_parameter,
55
+ make_empty_intermediate_tensors_factory, make_layers,
56
+ maybe_prefix)
57
+
58
+
59
class DeepseekMLP(nn.Module):
    """SwiGLU feed-forward block, also used as the per-expert MLP.

    ``reduce_results=False`` lets the MoE layer defer the tensor-parallel
    all-reduce until after expert outputs have been combined.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fused gate + up projection: one matmul producing both halves.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # SiluAndMul consumes the concatenated [gate, up] projection.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
90
+
91
+
92
class DeepseekMoE(nn.Module):
    """Deepseek mixture-of-experts FFN with optional shared experts.

    Every rank holds all routed experts (each expert is itself
    tensor-parallel); token routing uses the fused_moe kernel.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        # Experts skip their own all-reduce (reduce_results=False); a
        # single all-reduce happens in forward() after combination.
        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False)
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        # Router gate: tiny replicated matmul, deliberately unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def pack_params(self):
        """Flatten per-expert weights into stacked tensors self.w1 / self.w2.

        The expert Parameters are re-pointed at slices of the flat
        buffers, so the fused_moe kernel and the original expert modules
        share the same storage (no copies on weight updates/loads).
        """
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # assumes hidden_states is 2-D (num_tokens, hidden_dim)
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Single all-reduce covers routed + shared expert partial sums.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
177
+
178
+
179
class DeepseekAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and GQA support."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
260
+
261
+
262
class DeepseekDecoderLayer(nn.Module):
    """One transformer block: self-attention plus a dense or MoE FFN."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # MoE layers start at first_k_dense_replace and then repeat every
        # moe_layer_freq layers; all other layers use the dense MLP.
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (hidden_states, residual) for fused add-norm chaining."""
        # Self Attention
        if residual is None:
            # First layer: seed the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
335
+
336
+
337
class DeepseekModel(nn.Module):
    """Deepseek transformer backbone with pipeline-parallel support."""

    # Weight files are safetensors-first for this model family.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers partitions the layer range for pipeline parallelism;
        # this rank only materializes layers [start_layer, end_layer).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layer slice; non-final PP ranks return
        IntermediateTensors instead of hidden states."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later PP ranks receive activations from the previous rank.
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
399
+
400
+
401
class DeepseekForCausalLM(nn.Module, SupportsPP):
    """Deepseek causal LM head: backbone + LM head + sampling hooks."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing q/k/v and gate/up shards.

        Returns the set of parameter names actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # RoPE inverse frequencies are recomputed, not loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked parameter: load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/eagle.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Iterable, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from vllm.attention.backends.abstract import AttentionMetadata
9
+ from vllm.config import VllmConfig
10
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
11
+ from vllm.model_executor.layers.sampler import SamplerOutput
12
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
13
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
14
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
15
+ from vllm.model_executor.models import ModelRegistry
16
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
17
+ from vllm.sequence import IntermediateTensors
18
+
19
+ from .utils import maybe_prefix
20
+
21
+
22
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for the first decoder layer's input layernorm.

    EAGLE removes the normalization, but the optional ``weight``/``bias``
    tensors are still kept (as parameters) so checkpoint loading and
    existing unit tests can find them.
    """

    def __init__(self, weight=None, bias=None):
        super().__init__()
        if weight is None:
            self.weight = None
        else:
            self.weight = nn.Parameter(weight)
        if bias is None:
            self.bias = None
        else:
            self.bias = nn.Parameter(bias)

    def forward(self, x):
        # No normalization: pass the input through untouched.
        return x
31
+
32
+
33
class DummyOutputNorm(nn.Module):
    """Identity replacement for the model's final norm.

    Mirrors the real norm layer's contract: with a residual it returns
    the ``(hidden_states, residual)`` pair, otherwise just the states.
    """

    def forward(self, x, residual):
        # No normalization is applied in either branch.
        return x if residual is None else (x, residual)
40
+
41
+
42
class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In reference, LlamaDecoderLayer implementation doesn't have
    input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
    Following this approach, our implementation also disables
    the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE whereas in reference
    decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces draft vocab to most
    frequently used tokens to give some additional speed-up by reducing
    sampling overhead. This is disabled unless the checkpoint file has
    explicit token_map tensor and config has an optional attribute
    truncated_vocab_size < vocab_size. To use this technique, one has to find
    the top-k most frequent tokens in target dataset and add that as a tensor
    in the draft checkpoint (using key token_map). Also, the draft config
    needs to have truncated_vocab_size (=k) as an attribute."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config

        # Resolve the concrete decoder model class (e.g. a Llama variant)
        # from the draft config's nested model architectures.
        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Projects [input_embeds ; previous_hidden_states] (2 * hidden)
        # back down to the model's hidden size.
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))

        # Modify layer normalization and residual connections as suggested
        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
        # While weights and biases are generally not needed,
        # they are retained here to support certain unit tests
        # (e.g., spec_decode/e2e/test_eagle_correctness.py).
        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
            weight=self.model.model.layers[0].input_layernorm.weight)
        self.model.model.norm = DummyOutputNorm()

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is an idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the EAGLE
        # checkpoint file has token_map tensor.
        self.token_map = None

    @property
    def sampler(self):
        # Delegate sampling to the wrapped base model's sampler.
        return self.model.sampler

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse token embeddings with the target model's hidden states and
        run the draft decoder."""

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        inputs_embeds = self.fc(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is not None:
            # Scatter the truncated-vocab logits back into a full-vocab
            # tensor; unmapped token ids get -inf so they are never sampled.
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            logits[..., self.token_map] = _logits

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting trained EAGLE
        # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                # Only honor the token map when the config actually
                # truncates the vocabulary.
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    weight_loader = getattr(self.fc.bias, "weight_loader",
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
                    raise ValueError("Found bias in the loaded weights "
                                     "but the model config doesn't have bias")
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                # Strip the outer "model." wrapper from nested names.
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        lm_head_weight = model_weights.pop("lm_head.weight")

        if self.token_map is not None and\
            lm_head_weight.shape[0] > self.token_map.shape[0]:

            # Reduce the LM head rows to the truncated vocabulary.
            lm_head_weight = lm_head_weight[self.token_map]

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())
.venv/lib/python3.11/site-packages/vllm/model_executor/models/falcon.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
7
+ # reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Falcon model."""
21
+
22
+ import math
23
+ from typing import Iterable, List, Optional, Set, Tuple, Union
24
+
25
+ import torch
26
+ from torch import nn
27
+ from torch.nn import LayerNorm
28
+ from transformers import FalconConfig as HF_FalconConfig
29
+
30
+ from vllm.attention import Attention, AttentionMetadata
31
+ from vllm.compilation.decorators import support_torch_compile
32
+ from vllm.config import CacheConfig, VllmConfig
33
+ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
34
+ get_tensor_model_parallel_world_size,
35
+ tensor_model_parallel_all_reduce)
36
+ from vllm.model_executor.layers.activation import get_act_fn
37
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
38
+ QKVParallelLinear,
39
+ RowParallelLinear)
40
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
+ from vllm.model_executor.layers.quantization import QuantizationConfig
42
+ from vllm.model_executor.layers.rotary_embedding import get_rope
43
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
44
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
45
+ ParallelLMHead, VocabParallelEmbedding)
46
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
48
+ from vllm.sequence import IntermediateTensors
49
+ from vllm.transformers_utils.configs import RWConfig
50
+
51
+ from .interfaces import SupportsPP
52
+ from .utils import (is_pp_missing_parameter,
53
+ make_empty_intermediate_tensors_factory, make_layers,
54
+ maybe_prefix)
55
+
56
+ FalconConfig = Union[HF_FalconConfig, RWConfig]
57
+
58
+
59
+ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
60
+ closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
61
+ base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
62
+ dtype=torch.float32)
63
+ powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
64
+ slopes = torch.pow(base, powers)
65
+
66
+ if closest_power_of_2 != total_num_heads:
67
+ extra_base = torch.tensor(
68
+ 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
69
+ dtype=torch.float32)
70
+ num_remaining_heads = min(closest_power_of_2,
71
+ total_num_heads - closest_power_of_2)
72
+ extra_powers = torch.arange(1,
73
+ 1 + 2 * num_remaining_heads,
74
+ 2,
75
+ dtype=torch.int32)
76
+ slopes = torch.cat(
77
+ [slopes, torch.pow(extra_base, extra_powers)], dim=0)
78
+
79
+ return slopes
80
+
81
+
82
class FalconAttention(nn.Module):
    """Falcon multi-head self-attention.

    Handles the three positional-encoding variants Falcon ships with
    (rotary, ALiBi, or neither) and the three KV layouts (new decoder
    architecture with grouped KV heads, multi-query with a single KV head,
    or classic multi-head), selected from the config flags.
    """

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are sharded evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        # Pick the number of KV heads according to the Falcon variant.
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Fused QKV projection; bias addition is deferred (skip_bias_add)
        # and done explicitly in forward().
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        # With parallel attention / the new architecture, the all-reduce of
        # the row-parallel output is deferred to the decoder layer so that
        # attention and MLP results can share a single all-reduce.
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            # NOTE(review): unlike the final branch below, cache_config is
            # not forwarded to Attention here (nor in the alibi branch) —
            # confirm this is intended.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        elif self.use_alibi:
            # Each TP rank only needs the ALiBi slopes of the heads it owns.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply attention, and project back.

        Returns the (possibly not-yet-reduced) attention output together
        with the dense-projection bias; the caller is responsible for
        adding the bias (skip_bias_add=True above).
        """
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias
205
+
206
+
207
class FalconMLP(nn.Module):
    """Falcon feed-forward block: hidden -> 4*hidden -> GELU -> hidden.

    Both linear layers defer their bias additions (skip_bias_add=True);
    the up-projection bias is applied in forward(), while the
    down-projection bias is returned for the decoder layer to fuse with
    its residual additions.
    """

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(d_model,
                                                  4 * d_model,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  quant_config=quant_config)
        self.act = get_act_fn("gelu")
        # With parallel attention / the new architecture, the row-parallel
        # all-reduce is deferred to the decoder layer.
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * d_model,
            d_model,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        hidden, up_bias = self.dense_h_to_4h(x)
        if up_bias is not None:
            hidden += up_bias
        hidden = self.act(hidden)
        hidden, down_bias = self.dense_4h_to_h(hidden)
        return hidden, down_bias
241
+
242
+
243
class FalconDecoderLayer(nn.Module):
    """One Falcon transformer block (self-attention + MLP).

    Depending on the config, attention and MLP run either sequentially
    (classic Falcon) or in parallel (parallel_attn / new decoder
    architecture), with one or two pre-layernorms.
    """

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attention")
        self.mlp = FalconMLP(config, quant_config)
        self.config = config

        # Older Falcon configs predate num_ln_in_parallel_attn; default it.
        if (not hasattr(config, "num_ln_in_parallel_attn")):
            config.num_ln_in_parallel_attn = None

        # The new decoder architecture uses two pre-layernorms by default.
        if (config.num_ln_in_parallel_attn is None
                and config.new_decoder_architecture):
            config.num_ln_in_parallel_attn = 2

        if not config.parallel_attn:
            # Sequential attention/MLP: one layernorm before each sublayer.
            self.post_attention_layernorm = LayerNorm(
                hidden_size, eps=config.layer_norm_epsilon)
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
        else:
            if config.num_ln_in_parallel_attn == 2:
                # The layer norm before self-attention
                self.ln_attn = LayerNorm(hidden_size,
                                         eps=config.layer_norm_epsilon)
                # The layer norm before the MLP
                self.ln_mlp = LayerNorm(hidden_size,
                                        eps=config.layer_norm_epsilon)
            else:
                self.input_layernorm = LayerNorm(hidden_size,
                                                 eps=config.layer_norm_epsilon)

        # When attention and MLP run in parallel, their row-parallel outputs
        # are reduced together once in forward() instead of inside each
        # sublayer.
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.num_ln_in_parallel_attn == 2:
            # Separate pre-norms for the attention and MLP branches.
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # In the sequential layout the deferred attention bias is folded in
        # here; in the parallel layout it is added after the all-reduce.
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                # Parallel attention: MLP sees the same normed input.
                mlp_layernorm_out = attention_layernorm_out
            else:
                # Sequential: fold attention into the residual first.
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        if (self.config.new_decoder_architecture and self.config.parallel_attn
                and self.config.num_ln_in_parallel_attn == 1):
            mlp_layernorm_out = attention_layernorm_out

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            # Biases are added after the reduce since they are replicated
            # (not sharded) across ranks.
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output
344
+
345
+
346
@support_torch_compile
class FalconModel(nn.Module):
    """Falcon transformer backbone: embeddings, decoder stack, final norm.

    Supports pipeline parallelism: non-first ranks consume
    ``intermediate_tensors`` instead of token embeddings, and non-last
    ranks return ``IntermediateTensors`` for the next stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        # make_layers only instantiates the layers belonging to this
        # pipeline-parallel stage.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: FalconDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First PP rank embeds tokens (or uses precomputed embeddings).
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later PP ranks receive activations from the previous stage.
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches is indexed per local layer (offset by start_layer).
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
411
+
412
+
413
class FalconForCausalLM(nn.Module, SupportsPP):
    """Falcon causal-LM: FalconModel backbone plus a (possibly tied) LM head."""

    # Falcon checkpoints already store Q, K and V fused in a single tensor,
    # so the fused projection maps onto itself.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = FalconModel(vllm_config=vllm_config,
                                       prefix=maybe_prefix(
                                           prefix, "transformer"))
        # only Falcon-11B doesn't share lm_head weight with word embeddings
        # and previous Falcon model doesn't have tie_word_embeddings config
        # so we set tie_word_embeddings to True by default
        self.tie_word_embeddings = (config.tie_word_embeddings
                                    if config.tie_word_embeddings is not None
                                    else True)
        if self.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone; returns hidden states (or IntermediateTensors
        on non-last pipeline ranks — see FalconModel.forward)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, reshuffling Falcon's fused QKV layout.

        Falcon checkpoints interleave the fused QKV tensor per KV head as
        groups of [q_0..q_{g-1}, k, v]; vLLM's QKVParallelLinear expects a
        flat [Q; K; V] layout, so the tensor is split with narrow() and
        re-concatenated before loading. Returns the set of loaded names.
        """
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        # g = query heads served by each KV head.
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name == "lm_head.weight" and self.tie_word_embeddings:
                # Falcon uses tied embeddings except Falcon-11b.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters owned by other pipeline-parallel ranks.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                loaded_weight_shape = loaded_weight.shape
                if output_dim is not None:
                    # View as (..., kv_heads, g + 2, head_dim, ...) then pull
                    # out the q, k and v groups along the interleaved axis.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] +
                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                         -1) + loaded_weight_shape[output_dim + 1:])
                    wq = loaded_weight.narrow(
                        output_dim + 1, 0,
                        num_query_heads_per_kv_head).reshape(
                            *loaded_weight_shape[:output_dim], -1,
                            *loaded_weight_shape[output_dim + 1:])
                    wk = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    wv = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head + 1,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/florence2.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import math
4
+ from typing import Iterable, List, Optional, Set, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from vllm.attention import AttentionMetadata
10
+ from vllm.config import VllmConfig
11
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
13
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
14
+ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
15
+ BartParallelLMHead,
16
+ BartScaledWordEmbedding)
17
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
18
+ from vllm.sequence import IntermediateTensors
19
+
20
+ from .utils import AutoWeightsLoader
21
+
22
+
23
class Florence2LanguageModel(nn.Module):
    """BART-style encoder-decoder backbone used by Florence-2's language head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        # Shared token-embedding table, plus encoder and decoder stacks.
        self.shared = BartScaledWordEmbedding(self.vocab_size,
                                              hf_config.d_model)
        self.encoder = BartEncoder(hf_config,
                                   cache_config=vllm_config.cache_config,
                                   quant_config=vllm_config.quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(hf_config,
                                   cache_config=vllm_config.cache_config,
                                   quant_config=vllm_config.quant_config,
                                   prefix=f"{prefix}.decoder")

        # Tie encoder/decoder token embeddings to the shared table.
        if self.config.tie_word_embeddings:
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the encoder (when encoder tokens are present), then decode.

        Args:
            input_ids: indices of *decoder* input sequence tokens.
            positions: positions of *decoder* input sequence tokens.
            encoder_input_ids: indices of *encoder* input sequence tokens.
            encoder_positions: positions of *encoder* input sequence tokens.
            kv_caches: layer-wise list of KV cache tensors.
            attn_metadata: vLLM attention metadata structure.
        Returns:
            Decoder output hidden states.
        """
        if encoder_input_ids.numel() > 0:
            # Encoder attention only runs when a non-zero number of encoder
            # tokens was provided.
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)
        else:
            encoder_hidden_states = None

        # Decoder cross-attends into the encoder output (when available).
        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
+
96
+
97
class Florence2LanguageForConditionalGeneration(nn.Module):
    """Florence-2 language head: BART encoder-decoder backbone plus LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config

        self.config = config
        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")
        # BART optionally scales embeddings (and hence the tied LM head)
        # by sqrt(d_model).
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights, fusing separate q/k/v projections into qkv_proj.

        Checkpoint names containing q_proj/k_proj/v_proj are redirected to
        the fused qkv_proj parameter with the matching shard id; all other
        weights are loaded directly. Returns the set of loaded names.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Runs only when no stacked mapping matched (for/else).
                # final_logits_bias is unused; tied embeddings are provided
                # by the shared table, so their checkpoint copies are skipped.
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
+ return loaded_params
194
+
195
+
196
class Florence2ForConditionalGeneration(nn.Module):
    """Florence-2 model wrapper.

    Only the BART-based language model is implemented so far; every call
    is delegated to it, and vision-backbone weights are skipped at load
    time.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=f"{prefix}.language_model",
        )

    @property
    def sampler(self):
        # Expose the language model's sampler as this module's sampler.
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the wrapped language model.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        output = self.language_model(input_ids, positions, encoder_input_ids,
                                     encoder_positions, kv_caches,
                                     attn_metadata)
        return output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Logits come straight from the language model's head.
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # Sampling is likewise delegated.
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Vision-tower weights are ignored until the visual backbone is
        # implemented (see TODO in __init__).
        unimplemented_vision_prefixes = [
            "image_projection", "vision_tower", "image_proj_norm",
            "image_pos_embed", "visual_temporal_embed"
        ]
        loader = AutoWeightsLoader(self,
                                   skip_prefixes=unimplemented_vision_prefixes)
        return loader.load_weights(weights)
+ return loader.load_weights(weights)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/fuyu.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
4
+ # Copyright 2023 The vLLM team.
5
+ # Copyright 2023 HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """ PyTorch Fuyu model."""
19
+ import math
20
+ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
21
+ TypedDict)
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
26
+ FuyuProcessor)
27
+
28
+ from vllm.attention import AttentionMetadata
29
+ from vllm.config import VllmConfig
30
+ from vllm.model_executor.layers.linear import ColumnParallelLinear
31
+ from vllm.model_executor.layers.sampler import SamplerOutput
32
+ from vllm.model_executor.models.persimmon import PersimmonForCausalLM
33
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
34
+ from vllm.multimodal import MULTIMODAL_REGISTRY
35
+ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
36
+ NestedTensors)
37
+ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
38
+ MultiModalDataItems)
39
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
40
+ BaseProcessingInfo, PromptReplacement,
41
+ PromptReplacementDetails)
42
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
43
+ from vllm.sequence import IntermediateTensors
44
+
45
+ from .interfaces import SupportsMultiModal, SupportsPP
46
+ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
47
+ merge_multimodal_embeddings)
48
+
49
+ # Cannot find the following 2 numbers from hf config.
50
+ _IMAGE_TOKEN_ID = 71011
51
+ _NEWLINE_TOKEN_ID = 71019
52
+
53
+
54
class FuyuImagePatchInputs(TypedDict):
    """Flattened image-patch inputs consumed by Fuyu's vision projection."""

    type: Literal["image_patches"]
    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    This is used to restore the first two dimensions of `flat_data`.
    """
67
+
68
+
69
class FuyuProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Fuyu: HF processor access and image-token
    grid geometry."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self):
        return self.ctx.get_hf_processor(FuyuProcessor)

    def get_image_processor(self) -> FuyuImageProcessor:
        return self.get_hf_processor().image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Fuyu supports at most one image per prompt.
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of placeholder tokens one image can
        occupy (patch grid plus one newline token per row)."""
        target_width, target_height = self.get_image_size_with_most_features()

        max_ncols, max_nrows = self.get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
        # Each of the `max_nrows` rows holds `max_ncols` patch tokens
        # followed by a newline token.
        max_image_tokens = (max_ncols + 1) * max_nrows

        return {"image": max_image_tokens}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Compute the (columns, rows) patch grid for an image after the
        HF processor's downscale-to-fit resizing."""
        image_processor = self.get_image_processor()
        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]
        # Derive the patch edge lengths from the processor instead of
        # hard-coding 30, so non-default processor configs are handled.
        patch_width = image_processor.patch_size["width"]
        patch_height = image_processor.patch_size["height"]

        if not (image_width <= target_width and image_height <= target_height):
            # Shrink (never enlarge) to fit inside the target canvas while
            # preserving aspect ratio.
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / patch_width)
        nrows = math.ceil(image_height / patch_height)
        return ncols, nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        # The processor's target canvas yields the largest patch grid.
        image_processor = self.get_image_processor()
        return ImageSize(width=image_processor.size["width"],
                         height=image_processor.size["height"])
124
+
125
+
126
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy inputs used for memory profiling of Fuyu."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return an empty prompt with maximally-sized dummy images."""
        max_size = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=max_size.width,
                                              height=max_size.height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text="",
            mm_data={"image": dummy_images},
        )
148
+
149
+
150
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multi-modal processor that wraps the HF ``FuyuProcessor`` and adapts
    its outputs/prompt tokens to vLLM's multi-modal pipeline."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # Run the HF processor and normalize its image-patch output layout.
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, Pn, Px * Py * C)
            # New output: (num_images, Pn, Px * Py * C)
            assert (isinstance(image_patches, list)
                    and len(image_patches) == 1)
            assert (isinstance(image_patches[0], torch.Tensor)
                    and len(image_patches[0]) == len(images))

            processed_outputs["image_patches"] = image_patches[0]

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # HF processor adds boa_token_id (begin-of-answer, vocab "<0x04>")
        # after the prompt; mirror that for the tokens-only path.
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        boa_token_id = vocab["<0x04>"]

        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # One "image_patches" entry per image in the batch.
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE(review): despite the name, this is the tokenizer's *BOS*
        # token id, used here as the per-image replacement target --
        # confirm against the HF Fuyu processor's placeholder convention.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            # Expand the placeholder into a patch-grid of image tokens:
            # `ncols` image tokens plus one newline token per row, then
            # a trailing BOS separating image features from the text.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptReplacementDetails(
                full=image_tokens + [bos_token_id],
                features=image_tokens,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]
242
+
243
+
244
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a Persimmon language model that consumes linearly-projected
    image patches directly, with no separate vision encoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Each patch is flattened to patch_size^2 * num_channels features.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Linear projection from flattened patch pixels into the LM's
        # hidden space; gather_output so every rank sees full embeddings.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        # Delegate sampling to the wrapped Persimmon model.
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        # Check the trailing patch dimension of every entry and cast to the
        # projection layer's parameter dtype.
        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w

        def _validate_shape(d: torch.Tensor):
            actual_dims = d.size(-1)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Extract and validate `image_patches` from processor kwargs;
        returns None for text-only inputs."""
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is not None:
            if not isinstance(image_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image patches. "
                                 f"Got type: {type(image_patches)}")

            image_patches_flat = flatten_bn(image_patches)

            return FuyuImagePatchInputs(
                type="image_patches",
                flat_data=self._validate_pixel_values(
                    flatten_bn(image_patches_flat, concat=True)),
                patches_per_image=[x.size(0) for x in image_patches_flat],
            )

        return None

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
        """Project flattened patches and split back into per-image tensors."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)
        # Undo the flattening using the recorded per-image patch counts.
        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        # Splice vision embeddings into the token embeddings at image-token
        # positions.
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        if intermediate_tensors is not None:
            # Non-first pipeline ranks receive hidden states, not embeddings.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Use the inner model's logits processor / LM head directly.
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Copyright 2023 The vLLM team.
4
+ # Copyright (c) Google Inc.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Inference-only Gemma model compatible with HuggingFace weights."""
18
+ from functools import cache
19
+ from typing import Iterable, List, Optional, Set, Tuple, Union
20
+
21
+ import torch
22
+ from torch import nn
23
+ from transformers import GemmaConfig
24
+
25
+ from vllm.attention import Attention, AttentionMetadata
26
+ from vllm.compilation.decorators import support_torch_compile
27
+ from vllm.config import CacheConfig, VllmConfig
28
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
29
+ from vllm.logger import init_logger
30
+ from vllm.model_executor.layers.activation import GeluAndMul
31
+ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
32
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
33
+ QKVParallelLinear,
34
+ RowParallelLinear)
35
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
+ from vllm.model_executor.layers.quantization import QuantizationConfig
37
+ from vllm.model_executor.layers.rotary_embedding import get_rope
38
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
40
+ VocabParallelEmbedding)
41
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
43
+ from vllm.sequence import IntermediateTensors
44
+
45
+ from .interfaces import SupportsLoRA, SupportsPP
46
+ from .utils import (is_pp_missing_parameter,
47
+ make_empty_intermediate_tensors_factory, make_layers,
48
+ maybe_prefix)
49
+
50
+ logger = init_logger(__name__)
51
+
52
+
53
+ @cache
54
+ def _get_gemma_act_fn(
55
+ hidden_act: Optional[str],
56
+ hidden_activation: Optional[str],
57
+ ) -> nn.Module:
58
+ if hidden_activation is None:
59
+ if hidden_act is not None:
60
+ logger.warning(
61
+ "Gemma's activation function was incorrectly set to exact GeLU "
62
+ "in the config JSON file when it was initially released. "
63
+ "Changing the activation function to approximate GeLU "
64
+ "(`gelu_pytorch_tanh`). If you want to use the legacy "
65
+ "`%s`, edit the config JSON to set "
66
+ "`hidden_activation=%s` instead of `hidden_act`. "
67
+ "See https://github.com/huggingface/transformers/pull/29402 "
68
+ "for more details.", hidden_act, hidden_act)
69
+ return GeluAndMul(approximate="tanh")
70
+ elif hidden_activation == "gelu_pytorch_tanh":
71
+ return GeluAndMul(approximate="tanh")
72
+ elif hidden_activation == "gelu":
73
+ return GeluAndMul(approximate="none")
74
+ else:
75
+ raise ValueError(f"Activation function {hidden_act} is not "
76
+ "supported for Gemma models.")
77
+
78
+
79
class GemmaMLP(nn.Module):
    """Gemma feed-forward block: merged gate/up projection, gated GeLU
    activation, then down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out
112
+
113
+
114
class GemmaAttention(nn.Module):
    """Gemma multi-head attention with rotary embeddings and tensor-parallel
    QKV/output projections."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        # Per-rank sizes of the Q and K/V slices within the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Fused QKV projection, then split into per-rank Q/K/V slices.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Apply rotary position embeddings before attention.
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
195
+
196
+
197
class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer layer: pre-norm attention followed by a
    pre-norm MLP, with fused residual handling via GemmaRMSNorm."""

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            # Older Gemma configs may lack this field; the activation helper
            # handles None.
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            # First layer: seed the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add-and-norm with the carried residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
259
+
260
+
261
@support_torch_compile
class GemmaModel(nn.Module):
    """Gemma transformer stack (embeddings + decoder layers + final norm),
    with pipeline-parallel support via make_layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GemmaDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # Gemma scales embeddings by sqrt(hidden_size); see buffer above.
            hidden_states *= self.normalizer
            residual = None
        else:
            # Later pipeline ranks resume from the previous rank's tensors.
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            # Hand off to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
332
+
333
+
334
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma causal LM head over GemmaModel; the LM head is tied to the
    input embeddings."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings`
        # enabled.
        # NOTE(review): `assert` is stripped under -O; consider raising
        # instead if untied checkpoints ever need a hard error.
        assert config.tie_word_embeddings
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = GemmaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Tied weights: project through the embedding matrix.
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping HF per-shard names onto vLLM's
        fused (stacked) parameters."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for/else: runs only when no stacked mapping matched.
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/glm4_vision_encoder.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/THUDM/GLM-4
5
+ """Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
6
+ from argparse import Namespace
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import LayerNorm
12
+
13
+ from vllm.attention.layer import MultiHeadAttention
14
+ from vllm.distributed import get_tensor_model_parallel_world_size
15
+ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
16
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
17
+ MergedColumnParallelLinear,
18
+ QKVParallelLinear,
19
+ ReplicatedLinear,
20
+ RowParallelLinear)
21
+ from vllm.model_executor.layers.quantization.base_config import (
22
+ QuantizationConfig)
23
+
24
+
25
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence for the vision encoder.

    A strided convolution slices the image into non-overlapping patches,
    a learnable [CLS] token is prepended, and learned absolute position
    embeddings are added.
    """

    def __init__(self, config):
        super().__init__()
        # kernel_size == stride: each conv step produces exactly one patch.
        self.proj = nn.Conv2d(config.in_channels,
                              config.hidden_size,
                              kernel_size=config.patch_size,
                              stride=config.patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions,
                                               config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
            images : torch.Tensor
                Input image tensor with shape (B, C, H, W)

        Returns:
            torch.Tensor
                Transformed tensor with shape (B, L, D)
        """
        conv_weight = self.proj.weight
        images = images.to(device=conv_weight.device,
                           dtype=conv_weight.dtype)
        # (B, D, H', W') -> (B, H'*W', D)
        patches = self.proj(images).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_embedding.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        # Learned absolute position embeddings, broadcast over the batch.
        tokens = tokens + self.position_embedding.weight.unsqueeze(0)
        return tokens
55
+
56
+
57
class Attention(nn.Module):
    """Multi-head self-attention for the GLM-4v vision encoder.

    The fused QKV projection and the output projection are sharded across
    tensor-parallel ranks; each rank runs attention over its own subset of
    heads.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        # Heads handled locally by this tensor-parallel rank.
        self.num_heads_per_rank = config.num_heads // self.tp_size
        self.head_dim = config.hidden_size // config.num_heads
        self.scale = self.head_dim**-0.5

        # Fused Q/K/V projection, column-parallel over heads.
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            config.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        # Output projection, row-parallel (all-reduces across ranks).
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
                                       self.scale)
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape (B, L, hidden_size)."""
        qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
        # Fused projection output splits evenly into q, k, v.
        q, k, v = qkv.chunk(3, dim=-1)

        out = self.attn(q, k, v)
        output, _ = self.dense(out)
        output = self.output_dropout(output)
        return output
98
+
99
+
100
class MLP(nn.Module):
    """Feed-forward block of the GLM-4v vision transformer.

    Expands to the intermediate width, applies the configured activation,
    then projects back to the hidden width.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # hidden -> intermediate, column-parallel across TP ranks.
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        # intermediate -> hidden, row-parallel across TP ranks.
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
129
+
130
+
131
class TransformerLayer(nn.Module):
    """One block of the GLM-4v vision transformer.

    NOTE(review): unlike the common pre-/post-norm layouts, the layernorms
    here are applied to the *output* of each sub-layer before the residual
    add (``residual + norm(sublayer(residual))``). This presumably mirrors
    the upstream THUDM implementation the file is adapted from — confirm
    against the reference weights before restructuring.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.input_layernorm = LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
        self.attention = Attention(config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.attention")
        self.mlp = MLP(config,
                       quant_config=quant_config,
                       prefix=f"{prefix}.mlp")
        self.post_attention_layernorm = LayerNorm(config.hidden_size,
                                                  eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # Attention sub-layer: norm applied to the attention output,
        # then residual add.
        attention_input = hidden_states
        attention_output = self.input_layernorm(
            self.attention(attention_input))
        hidden_states = attention_input + attention_output
        # MLP sub-layer: same output-norm-then-residual pattern.
        mlp_input = hidden_states
        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
        output = mlp_input + mlp_output
        return output
160
+
161
+
162
class Transformer(nn.Module):
    """Sequential stack of GLM-4v vision transformer layers."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        blocks = [
            TransformerLayer(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        # Feed the activations through every layer in order.
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states
182
+
183
+
184
class GLU(nn.Module):
    """SwiGLU projector mapping vision features into the language model."""

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """
        The original implementation is the same as:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )

        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
        ```
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        # Project incoming vision features to the LM hidden size.
        self.linear_proj = ReplicatedLinear(in_features,
                                            config.hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.linear_proj")
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.act1 = nn.GELU()
        # SiluAndMul consumes the concatenated [gate, up] halves below.
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj")

        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h")

    def forward(self, x):
        # Bring vision features into the LM width, then normalize + GELU.
        x, _ = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # SwiGLU: fused gate/up projection, then silu(gate) * up.
        x, _ = self.merged_proj(x)
        x = self.act2(x)
        x, _ = self.dense_4h_to_h(x)
        return x
259
+
260
+
261
class EVA2CLIPModel(nn.Module):
    """GLM-4v visual encoder: patch embed -> transformer -> projector.

    Produces a token sequence in the language-model width, bracketed by
    learnable begin-of-image / end-of-image tokens.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        # vision_config arrives as a plain dict on the HF config.
        vision_config = Namespace(**config.vision_config)
        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.transformer")
        self.linear_proj = GLU(config,
                               in_features=config.hidden_size,
                               quant_config=quant_config,
                               prefix=f"{prefix}.linear_proj")
        # 2x2 conv downsamples the patch grid by a factor of 2 per side
        # and maps vision width -> LM hidden size.
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=config.hidden_size,
                              kernel_size=2,
                              stride=2)
        # Learnable begin/end-of-image boundary tokens.
        self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
            images : torch.Tensor
                Input image tensor with shape (B, C, H, W)

        Returns:
            torch.Tensor
                Transformed tensor with shape (B, L, D)
        """
        x = self.patch_embedding(images)
        x = self.transformer(x)
        # Drop the [CLS] token; only patch tokens are spatial.
        x = x[:, 1:]

        # Restore the square patch grid (assumes s is a perfect square)
        # for the downsampling conv, then flatten back to a sequence.
        b, s, h = x.shape
        grid_size = int(s**0.5)
        x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
        x = self.conv(x)

        x = x.flatten(2).transpose(1, 2)
        x = self.linear_proj(x)
        # Bracket the image tokens with boundary markers.
        boi = self.boi.expand(x.shape[0], -1, -1)
        eoi = self.eoi.expand(x.shape[0], -1, -1)
        x = torch.cat((boi, x, eoi), dim=1)
        x = x / self.scaling_factor
        return x
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt2.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
7
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Inference-only GPT-2 model compatible with HuggingFace weights."""
21
+ from typing import Iterable, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from transformers import GPT2Config
26
+
27
+ from vllm.attention import Attention, AttentionMetadata
28
+ from vllm.compilation.decorators import support_torch_compile
29
+ from vllm.config import CacheConfig, VllmConfig
30
+ from vllm.distributed.parallel_state import (
31
+ get_pp_group, get_tensor_model_parallel_world_size)
32
+ from vllm.model_executor.layers.activation import get_act_fn
33
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
34
+ QKVParallelLinear,
35
+ RowParallelLinear)
36
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
+ from vllm.model_executor.layers.quantization import QuantizationConfig
38
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
40
+ ParallelLMHead, VocabParallelEmbedding)
41
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
43
+ from vllm.sequence import IntermediateTensors
44
+
45
+ from .interfaces import SupportsPP
46
+ from .utils import (is_pp_missing_parameter,
47
+ make_empty_intermediate_tensors_factory, make_layers,
48
+ maybe_prefix)
49
+
50
+
51
class GPT2Attention(nn.Module):
    """GPT-2 multi-head self-attention with tensor-parallel projections.

    Layer names (``c_attn``, ``c_proj``) match the HuggingFace checkpoint
    so weights load by name.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads must divide evenly across tensor-parallel ranks.
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        # Fused Q/K/V projection.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # Paged-attention backend (handles the KV cache).
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # Equal thirds: GPT-2 has as many KV heads as query heads.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
103
+
104
+
105
class GPT2MLP(nn.Module):
    """Two-layer feed-forward block of a GPT-2 transformer layer."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        # Expansion projection, column-parallel across TP ranks.
        self.c_fc = ColumnParallelLinear(
            width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # Contraction back to the model width, row-parallel.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            width,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted
137
+
138
+
139
class GPT2Block(nn.Module):
    """A single pre-norm GPT-2 transformer layer (attention + MLP)."""

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # GPT-2 uses a 4x MLP expansion unless n_inner overrides it.
        inner_dim = (4 * hidden_size
                     if config.n_inner is None else config.n_inner)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention sub-layer with pre-norm and residual connection.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out
        # MLP sub-layer with pre-norm and residual connection.
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
        return hidden_states
186
+
187
+
188
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer backbone: embeddings, blocks, and final norm.

    Supports pipeline parallelism: non-first ranks receive activations via
    ``intermediate_tensors`` and non-last ranks return them instead of the
    final hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Features not supported by this implementation.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.wte")
        # Learned absolute position embeddings.
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # Only the layers assigned to this pipeline rank are materialized.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            # Activations handed over from the previous pipeline rank.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # kv_caches is indexed relative to this rank's first layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
250
+
251
+
252
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 with a language-modeling head, for causal generation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.lm_head")
        # GPT-2 ties the LM head to the input embedding table.
        if self.config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; returns hidden states (or PP handoff tensors)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF GPT-2 checkpoint weights, transposing Conv1D matrices."""
        # remove_duplicate=False keeps tied lm_head/wte entries visible.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            # Normalize checkpoint names to this module hierarchy.
            if not name.startswith("transformer.") and not name.startswith(
                    "lm_head"):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_bigcode.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2023 CTranslate2, and Michael Feil
7
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
8
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """Inference-only GPTBigCode model compatible with HuggingFace weights."""
22
+ from typing import Iterable, List, Optional, Set, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from transformers import GPTBigCodeConfig
27
+
28
+ from vllm.attention import Attention, AttentionMetadata
29
+ from vllm.compilation.decorators import support_torch_compile
30
+ from vllm.config import CacheConfig, VllmConfig
31
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
+ from vllm.model_executor.layers.activation import get_act_fn
33
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
34
+ QKVParallelLinear,
35
+ RowParallelLinear)
36
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
+ from vllm.model_executor.layers.quantization import QuantizationConfig
38
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
40
+ ParallelLMHead, VocabParallelEmbedding)
41
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
43
+ from vllm.sequence import IntermediateTensors
44
+
45
+ from .interfaces import SupportsLoRA, SupportsPP
46
+ from .utils import (is_pp_missing_parameter,
47
+ make_empty_intermediate_tensors_factory, make_layers)
48
+
49
+
50
class GPTBigCodeAttention(nn.Module):
    """GPTBigCode self-attention, supporting multi-query attention (MQA).

    With ``config.multi_query`` there is a single shared K/V head;
    otherwise K/V head count equals the query head count.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads must divide evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # MQA: one K/V head shared by all query heads.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of each of the K and V slices in the fused QKV output.
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # Uneven split: per-rank query width, then the (possibly shared)
        # K and V slices.
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
118
+
119
+
120
class GPTBigMLP(nn.Module):
    """Feed-forward block of a GPTBigCode transformer layer."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        # Expansion projection, column-parallel across TP ranks.
        self.c_fc = ColumnParallelLinear(
            width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Contraction back to the model width, row-parallel.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            width,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted
149
+
150
+
151
class GPTBigCodeBlock(nn.Module):
    """A single pre-norm GPTBigCode transformer layer (attention + MLP)."""

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # 4x MLP expansion unless n_inner overrides it.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention sub-layer (pre-norm).
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        # MLP sub-layer (pre-norm).
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states
195
+
196
+
197
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode transformer backbone: embeddings, blocks, final norm.

    Supports pipeline parallelism (non-first ranks consume
    ``intermediate_tensors``; non-last ranks return them) and reserves
    extra vocabulary slots when LoRA is configured.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Cross-attention is not supported by this implementation.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        # Extra vocab slots reserved for LoRA adapters, if enabled.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        # Learned absolute position embeddings.
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # Only the layers assigned to this pipeline rank are materialized.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)
        else:
            # Fail fast with a clear message if the PP handoff is missing
            # (consistent with GPT2Model in this package).
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # kv_caches is indexed relative to this rank's first layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
259
+
260
+
261
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode with an LM head; supports LoRA and pipeline parallelism."""

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=prefix)
        # Share the input embedding as the output projection when tied.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # The backbone returns IntermediateTensors on non-final PP ranks.
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, replicating fused-QKV fp8 scales."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # The same scale tensor feeds each of the q/k/v shards.
                for shard_id in ('q', 'k', 'v'):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
.venv/lib/python3.11/site-packages/vllm/model_executor/models/granitemoe.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9
+ # and OPT implementations in this library. It has been modified from its
10
+ # original forms to accommodate minor architectural differences compared
11
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ """Inference-only GraniteMoe model."""
25
+ from typing import Iterable, List, Optional, Set, Tuple
26
+
27
+ import torch
28
+ from torch import nn
29
+ from transformers.models.granitemoe import GraniteMoeConfig
30
+
31
+ from vllm.attention import Attention, AttentionMetadata
32
+ from vllm.compilation.decorators import support_torch_compile
33
+ from vllm.config import CacheConfig, VllmConfig
34
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
+ from vllm.model_executor.layers.fused_moe import FusedMoE
36
+ from vllm.model_executor.layers.layernorm import RMSNorm
37
+ from vllm.model_executor.layers.linear import (QKVParallelLinear,
38
+ ReplicatedLinear,
39
+ RowParallelLinear)
40
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
+ from vllm.model_executor.layers.quantization.base_config import (
42
+ QuantizationConfig)
43
+ from vllm.model_executor.layers.rotary_embedding import get_rope
44
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
45
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
46
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
47
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
48
+ from vllm.sequence import IntermediateTensors
49
+
50
+ from . import mixtral
51
+ from .interfaces import SupportsLoRA, SupportsPP
52
+ from .utils import make_layers, maybe_prefix
53
+
54
+
55
class GraniteMoeMoE(nn.Module):
    """A tensor-parallel MoE implementation for GraniteMoe that shards each
    expert across all ranks.

    Each expert's weights are sharded across all ranks; a fused MoE kernel
    runs the forward pass and the outputs are reduced across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # The router (gate) stays unquantized and replicated on every rank.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Accept either a flat (tokens, hidden) or batched input shape.
        input_shape = hidden_states.shape
        flat = hidden_states.view(-1, self.hidden_size)
        # routing scores: (num_tokens, n_experts)
        routing_scores, _ = self.gate(flat)
        mixed = self.experts(flat, routing_scores)
        return mixed.view(input_shape)
102
+
103
+
104
class GraniteMoeAttention(nn.Module):
    """Multi-head attention with rotary position embeddings for GraniteMoe."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attention_multiplier: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # More KV heads than ranks: partition them across TP GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: replicate them across TP GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # GraniteMoe configs supply an explicit attention multiplier; fall
        # back to 1/head_dim when none is given.
        if attention_multiplier is not None:
            self.scaling = attention_multiplier
        else:
            self.scaling = self.head_dim**-1
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Project, split into per-shard q/k/v, rotate, attend, re-project.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(attn_out)
        return projected
185
+
186
+
187
class GraniteMoeDecoderLayer(nn.Module):
    """One GraniteMoe transformer layer: self-attention then a sparse-MoE FFN.

    Both residual additions are scaled by ``config.residual_multiplier``.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier)
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention sub-block (pre-norm, scaled residual).
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = skip + attn_out * self.residual_multiplier

        # MoE sub-block (pre-norm, scaled residual).
        skip = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        moe_out = self.block_sparse_moe(normed)
        return skip + moe_out * self.residual_multiplier
248
+
249
+
250
@support_torch_compile
class GraniteMoeModel(nn.Module):
    """GraniteMoe backbone: scaled token embeddings, decoder layers, RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.padding_idx = config.pad_token_id
        # Reserve extra vocab rows when LoRA adapters add tokens.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        # Embeddings are scaled before entering the first layer.
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            hidden_states *= self.embedding_multiplier
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[idx](positions, hidden_states,
                                             kv_caches[idx - self.start_layer],
                                             attn_metadata)

        if not get_pp_group().is_last_rank:
            # Hand off to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        return self.norm(hidden_states)
318
+
319
+
320
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoe with an LM head; weight loading is delegated to Mixtral
    after remapping checkpoint names."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
        "layer",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config  # Required by MixtralForCausalLM

        self.model = GraniteMoeModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            # LoRA kernels need a larger padding size.
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Logits are divided by the model's logits_scaling factor.
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=1 /
                                                self.config.logits_scaling)

        self.sampler = get_sampler()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        zeros = lambda: torch.zeros(  # noqa: E731
            (batch_size, self.config.hidden_size), dtype=dtype, device=device)
        return IntermediateTensors({
            "hidden_states": zeros(),
            "residual": zeros(),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Remap HF GraniteMoe tensors to Mixtral-style names, then load."""
        remapped = {}
        for name, tensor in weights:
            if name.endswith('.block_sparse_moe.input_linear.weight'):
                # Fused per-expert [w1; w3] -> separate w1/w3 tensors.
                for expert in range(tensor.size(0)):
                    w1_key = name.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{expert}.w1.weight")
                    w3_key = name.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{expert}.w3.weight")
                    w1_val, w3_val = tensor[expert].chunk(2, dim=0)
                    assert w1_key not in remapped
                    assert w3_key not in remapped
                    remapped[w1_key] = w1_val
                    remapped[w3_key] = w3_val
            elif name.endswith('.block_sparse_moe.output_linear.weight'):
                # Per-expert down-projection -> w2 tensors.
                for expert in range(tensor.size(0)):
                    w2_key = name.replace(
                        '.block_sparse_moe.output_linear.weight',
                        f".block_sparse_moe.experts.{expert}.w2.weight")
                    assert w2_key not in remapped
                    remapped[w2_key] = tensor[expert]
            elif name.endswith('.block_sparse_moe.router.layer.weight'):
                gate_key = name.replace(
                    '.block_sparse_moe.router.layer.weight',
                    ".block_sparse_moe.gate.weight")
                assert gate_key not in remapped
                remapped[gate_key] = tensor
            elif name == 'lm_head.weight' and self.config.tie_word_embeddings:
                # Tied embeddings: lm_head aliases embed_tokens.
                pass
            else:
                remapped[name] = tensor
        return mixtral.MixtralForCausalLM.load_weights(self,
                                                       remapped.items())
.venv/lib/python3.11/site-packages/vllm/model_executor/models/h2ovl.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
4
+ # https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
5
+ # --------------------------------------------------------
6
+ # H2OVL-Mississippi
7
+ # Copyright (c) 2024 H2O.AI
8
+ # Licensed under Apache 2.0 License [see LICENSE for details]
9
+ # --------------------------------------------------------
10
+ from typing import Mapping, Optional
11
+
12
+ import torch
13
+ from PIL import Image
14
+ from transformers import PretrainedConfig
15
+
16
+ from vllm.logger import init_logger
17
+ from vllm.model_executor.layers.quantization import QuantizationConfig
18
+ from vllm.multimodal import MULTIMODAL_REGISTRY
19
+ from vllm.multimodal.inputs import MultiModalKwargs
20
+ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
21
+ MultiModalDataItems)
22
+ from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
23
+ PromptReplacementDetails)
24
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder
25
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
26
+
27
+ from .intern_vit import InternVisionModel
28
+ from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
29
+ BaseInternVLProcessingInfo, BaseInternVLProcessor,
30
+ InternVLChatModel, InternVLDummyInputsBuilder,
31
+ InternVLMultiModalProcessor, build_transform,
32
+ find_closest_aspect_ratio, get_internvl_target_ratios)
33
+
34
+ logger = init_logger(__name__)
35
+
36
+
37
def resolve_h2ovl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective (min, max) dynamic patch count.

    Dynamic tiling disabled caps the maximum at a single patch; a thumbnail
    adds one extra patch whenever the image is actually split.
    """
    if not dynamic_image_size:
        max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch
50
+
51
+
52
def get_h2ovl_target_ratios(
    min_num: int,
    max_num: int,
    *,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> list[tuple[int, int]]:
    """Candidate tile grids, optionally filtered against a prior grid ratio."""
    ratios = get_internvl_target_ratios(min_num, max_num)

    if prior_aspect_ratio is None:
        return ratios

    # Keep only grids whose dimensions do not evenly divide the first-pass
    # grid, so the second MSAC pass produces genuinely different tilings.
    prior_w, prior_h = prior_aspect_ratio
    return [(w, h) for w, h in ratios
            if prior_w % w != 0 and prior_h % h != 0]
68
+
69
+
70
# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
    """Choose the best tile grid for an image and derive its resize targets.

    Returns ``(blocks, target_width, target_height, grid_ratio)``; ``blocks``
    already includes the thumbnail tile when one is requested.
    """
    target_aspect_ratio = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    grid_w, grid_h = target_aspect_ratio
    target_width = image_size * grid_w
    target_height = image_size * grid_h
    blocks = grid_w * grid_h

    # A thumbnail tile is appended whenever the image is split at all.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height, target_aspect_ratio
100
+
101
+
102
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[list[Image.Image], tuple[int, int]]:
    """Split *image* into a grid of square tiles plus an optional thumbnail.

    Returns the tile list and the chosen (cols, rows) grid aspect ratio.
    """
    orig_width, orig_height = image.size

    # Grid selection ignores the thumbnail; it is appended separately below.
    blocks, target_width, target_height, target_aspect_ratio = (
        calculate_h2ovl_targets(
            orig_width=orig_width,
            orig_height=orig_height,
            target_ratios=target_ratios,
            image_size=image_size,
            use_thumbnail=False,
        ))

    resized = image.resize((target_width, target_height))
    cols = target_width // image_size
    tiles = []
    for idx in range(blocks):
        left = (idx % cols) * image_size
        top = (idx // cols) * image_size
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles, target_aspect_ratio
148
+
149
+
150
def _preprocess_image(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Tile and transform one image into a stacked pixel-value tensor."""
    candidate_ratios = get_h2ovl_target_ratios(
        min_num,
        max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )

    transform = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess_h2ovl(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        target_ratios=candidate_ratios,
    )

    # One transformed tensor per tile, stacked along dim 0.
    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values, target_aspect_ratio
175
+
176
+
177
# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_msac: bool,
) -> torch.Tensor:
    """Convert an image into pixel values, optionally with two MSAC passes."""
    if not use_msac:
        pixel_values, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
            prior_aspect_ratio=None,
        )
        return pixel_values

    # MSAC: the image is processed twice; the second pass is constrained by
    # the grid ratio chosen in the first.
    first_pass, first_ratio = _preprocess_image(
        image,
        input_size=input_size,
        min_num=min_num,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=None,
    )
    second_pass, _ = _preprocess_image(
        image,
        input_size=input_size,
        min_num=3,  # Hardcoded value
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=first_ratio,
    )
    # Combine: second-pass tiles, first-pass tiles, then one thumbnail
    # (both passes end with a thumbnail; only the second's is kept).
    return torch.cat(
        [second_pass[:-1], first_pass[:-1], second_pass[-1:]], 0)
222
+
223
+
224
class H2OVLProcessor(BaseInternVLProcessor):
    """InternVL-style processor specialized for H2OVL.

    Adds multi-scale adaptive cropping (MSAC): when enabled, the image is
    tiled twice (a second pass constrained by the first pass's aspect
    ratio) and the resulting patches are combined.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_msac: Optional[bool] = None,
    ) -> None:
        super().__init__(
            config,
            tokenizer,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        # Fall back to the model config when the caller does not override.
        if use_msac is None:
            use_msac = config.use_msac
        assert isinstance(use_msac, bool)

        self.use_msac = use_msac

    @property
    def image_token_id(self) -> int:
        # ID of the IMG_CONTEXT placeholder token in the tokenizer vocab.
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        # One IMG_CONTEXT token per image feature.
        return IMG_CONTEXT * feature_size

    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        # Feature tokens wrapped in image start/end marker tokens.
        features = self.get_image_repl_features(feature_size, num_patches)
        return IMG_START + features + IMG_END

    def resolve_min_max_num(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve the (min, max) dynamic-patch counts, using instance
        defaults for any argument left as ``None``."""
        min_dynamic_patch = self.min_dynamic_patch
        max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
                             is None else max_dynamic_patch)
        dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
                              is None else dynamic_image_size)
        use_thumbnail = (self.use_thumbnail
                         if use_thumbnail is None else use_thumbnail)

        return resolve_h2ovl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
        prior_aspect_ratio: Optional[tuple[int, int]] = None,
    ) -> list[tuple[int, int]]:
        """Candidate (w, h) tiling ratios; when ``prior_aspect_ratio`` is
        given (MSAC second pass) the minimum patch count is pinned to 3."""
        min_num, max_num = self.resolve_min_max_num(
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        if prior_aspect_ratio:  # hardcoded value for second pass of use_msac
            min_num = 3

        return get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=prior_aspect_ratio,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        use_msac: Optional[bool] = None,
    ) -> int:
        """Number of IMG_CONTEXT tokens an image of this size expands to.

        With MSAC the two tiling passes each include a thumbnail, and one
        duplicate thumbnail is subtracted from the combined patch count.
        """
        use_msac = (self.use_msac if use_msac is None else use_msac)

        use_thumbnail = self.use_thumbnail

        if use_msac:
            # First pass: unconstrained aspect ratio.
            target_ratios_1 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_1,
                use_thumbnail=True,
            )

            # Second pass: constrained by the first pass's aspect ratio.
            target_ratios_2 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                prior_aspect_ratio=aspect_ratio_1,
            )
            num_patches_2, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_2,
                use_thumbnail=True,
            )

            # -1: the shared thumbnail is only counted once.
            num_patches = num_patches_1 + num_patches_2 - 1
        else:
            target_ratios = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios,
                use_thumbnail=use_thumbnail,
            )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Preprocess each image into a patch tensor.

        MSAC is only applied to single-image prompts.
        """
        use_msac = self.use_msac if len(images) == 1 else False

        min_num, max_num = self.resolve_min_max_num(
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_h2ovl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                use_msac=use_msac,
            ) for image in images
        ]
385
+
386
+
387
class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for H2OVL; token counts account for MSAC, which is
    only active for single-image prompts."""

    def get_hf_processor(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> H2OVLProcessor:
        """Build the H2OVL processor, forwarding any patch-count overrides."""
        return H2OVLProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Worst-case image token count per item.

        A single-image prompt may use MSAC (processor default), while
        multi-image prompts never do.
        """
        is_single_image = mm_counts.get("image", 0) <= 1
        max_tokens = self.get_max_image_tokens(
            use_msac=None if is_single_image else False)
        return {"image": max_tokens}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[H2OVLProcessor],
        use_msac: Optional[bool] = None,
    ) -> int:
        """Delegate the per-image token count to the processor, creating a
        default one when none is supplied."""
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
            use_msac=use_msac,
        )

    def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
        """Token count for the image size that yields the most features."""
        best_width, best_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=best_width,
            image_height=best_height,
            processor=None,
            use_msac=use_msac,
        )
441
+
442
+
443
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                               ):
    """Multi-modal processor for H2OVL.

    The HF output depends on the total number of images in the prompt
    (MSAC only runs for single-image prompts), so the processing cache —
    which assumes per-image invariance — must be disabled.
    """

    def __init__(self,
                 info: H2OVLProcessingInfo,
                 dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(
            info,
            dummy_inputs,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

        if self.cache is not None:
            # The processor output depends on the number of images passed,
            # making it incompatible with processing cache which is supposed
            # to be invariant of how many images are passed per prompt
            self.cache = None
            logger.warning_once(
                f"{type(self).__name__} does not support processing cache.")

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Build the replacement that expands each ``<image>`` placeholder
        into that image's feature-token sequence."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        # Per-image patch counts, if the HF output provides them.
        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        num_images = len(image_num_patches)

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                # MSAC only applies to single-image prompts; force it off
                # otherwise so token counts match the processed output.
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    use_msac=None if num_images == 1 else False,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return PromptReplacementDetails(
                full=hf_processor.get_image_repl_full(feature_size,
                                                      num_patches),
                features=hf_processor.get_image_repl_features(
                    feature_size, num_patches),
            )

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
521
+
522
+
523
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
    """H2OVL chat model; reuses InternVL, but monolith vision mode is not
    supported."""

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision tower truncated at the selected feature layer."""
        if is_mono:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

        feature_layer = config.select_layer
        # A negative index counts back from the last encoder layer.
        if feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 feature_layer + 1)
        else:
            num_hidden_layers = feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )
.venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics3.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Inference-only Idefics3 model compatible with HuggingFace weights."""
17
+
18
+ import math
19
+ from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
20
+ Tuple, TypedDict, Union)
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
26
+ Idefics3Processor)
27
+
28
+ from vllm.attention import AttentionMetadata
29
+ from vllm.config import VllmConfig
30
+ from vllm.logger import init_logger
31
+ from vllm.model_executor.layers.linear import ReplicatedLinear
32
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
+ from vllm.model_executor.layers.quantization import QuantizationConfig
34
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
35
+ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
36
+ from vllm.model_executor.models.module_mapping import MultiModelKeys
37
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
38
+ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
39
+ from vllm.multimodal.inputs import NestedTensors
40
+ from vllm.multimodal.parse import ImageProcessorItems
41
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
42
+ BaseProcessingInfo,
43
+ MultiModalDataItems,
44
+ MultiModalFieldConfig,
45
+ PromptReplacement)
46
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
47
+ from vllm.sequence import IntermediateTensors
48
+
49
+ # yapf: disable
50
+ from .idefics2_vision_model import (
51
+ Idefics2VisionTransformer as Idefics3VisionTransformer)
52
+ # yapf: enable
53
+ from .interfaces import SupportsLoRA, SupportsMultiModal
54
+ from .llama import LlamaModel
55
+ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
56
+ merge_multimodal_embeddings)
57
+
58
+ logger = init_logger(__name__)
59
+
60
+
61
class Idefics3ImagePixelInputs(TypedDict):
    # Raw pixel input variant for Idefics3 image processing.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_patches,
    num_channels, height, width)`
    """
    # Mask marking valid (non-padding) pixels; None means all pixels valid.
    pixel_attention_mask: Optional[torch.BoolTensor]
69
+
70
+
71
class Idefics3ImageEmbeddingInputs(TypedDict):
    # Precomputed image-embedding input variant (skips the vision tower).
    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    `hidden_size` must match the hidden size of language model backbone.
    """
78
+
79
+
80
# Union of the two accepted image input formats: raw pixels or
# precomputed embeddings.
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
81
+
82
+
83
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Exposes the Idefics3 HF processor and the image-token accounting used
    for profiling and placeholder expansion."""

    def get_hf_processor(
            self,
            *,
            size: Optional[Dict[str, int]] = None) -> Idefics3Processor:
        # `size` overrides the image processor's resize target when given.
        if size is not None:
            return self.ctx.get_hf_processor(Idefics3Processor, size=size)

        return self.ctx.get_hf_processor(Idefics3Processor)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No hard cap on the number of images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Worst-case token budget for one image: feature tokens for every
        tile plus the global image, plus the surrounding grid-position,
        linebreak and <fake_token_around_image> tokens."""
        hf_processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        # Largest allowed image -> largest tile grid.
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_processor.size['longest_edge'],
            image_height=image_processor.size['longest_edge'],
        )
        # +1 accounts for the global (downscaled) image.
        num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
        # Calculate Non-image-token length
        # NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
        # but not for Idefic3, so we need to tokenize them to get actual length.
        tokenizer = self.get_tokenizer()
        tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
        glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
        # linebreak and <fake_token_around_image> always cost 1 token
        fake_token_len = lb_len = 1
        non_image_token = (grid_w * grid_h) * (
            tile_token_len + fake_token_len) + glob_token_len + (
                grid_h + 1) * lb_len + fake_token_len
        return {"image": num_image_token + non_image_token}

    def _resize_output_size(self,
                            *,
                            height: int,
                            width: int,
                            max_len: Optional[int] = None,
                            min_len: Optional[int] = 1,
                            max_size: Optional[int] = None) -> tuple[int, int]:
        """Scale (height, width) so the longest edge equals ``max_len``
        while preserving aspect ratio; round dimensions up to even and
        clamp to ``min_len``."""
        # Set default value for max_len if not provided
        max_len = max(height, width) if max_len is None else max_len
        aspect_ratio = width / height

        # Handle the maximum size constraint
        if max_size is not None:
            max_len = min(max_len, max_size)

        # Adjust dimensions according to the aspect ratio
        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
        else:
            height = max_len
            width = int(height * aspect_ratio)

        # Ensure both width and height are even (if needed)
        height += height % 2
        width += width % 2

        # Ensure dimensions are not smaller than the minimum length
        height = max(height, min_len)
        width = max(width, min_len)

        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
    ) -> tuple[int, int]:
        """Resized (height, width) after capping the longest edge at
        ``resolution_max_side`` (which may not exceed the processor's
        configured maximum)."""
        hf_processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        max_image_size = image_processor.size['longest_edge']
        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`")

        height, width = image_height, image_width

        # Find the output size, when rescaling the longest edge to max_len and
        # preserving the aspect ratio
        height, width = self._resize_output_size(height=height,
                                                 width=width,
                                                 max_len=resolution_max_side)
        return height, width

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        size: Optional[dict[str, object]] = None,
    ) -> tuple[int, int]:
        """Tile layout (grid_w, grid_h) for an image after resizing.

        Returns (0, 0) when the resized image fits within a single tile,
        i.e. only the global image is used.
        """
        hf_processor = self.get_hf_processor(size=size)
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        max_image_size = image_processor.max_image_size['longest_edge']
        size = image_processor.size['longest_edge']
        assert size % max_image_size == 0, (
            "`longest_edge` in image_processor's `size` must be divisible by "
            "`longest_edge` in `max_image_size`, this may be caused by "
            "incorrect mm_kwargs override.")

        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=size,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            # ceil: a partial tile still counts as a full tile.
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h
205
+
206
+
207
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                 ):
    """Builds worst-case dummy inputs for Idefics3 memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """One maximum-size square image per requested image slot, with one
        image token per image in the prompt."""
        processor = self.info.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor
        # Largest tile edge -> worst-case (most tokens) image.
        edge = image_processor.max_image_size['longest_edge']

        num_images = mm_counts.get("image", 0)
        image_token: str = processor.image_token.content

        dummy_images = self._get_dummy_images(width=edge,
                                              height=edge,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )
232
+
233
+
234
class Idefics3MultimodalProcessor(
        BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Multi-modal processor that regroups HF outputs per image and expands
    image placeholders into the tiled Idefics3 prompt layout."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then re-split pixel outputs so each list
        element corresponds to one image (its tiles + global image)."""
        if mm_data:
            processed_outputs = super()._call_hf_processor(
                prompt, mm_data, mm_kwargs)
            image_grids = [
                self.info._get_image_feature_grid_size(
                    image_width=img.width,
                    image_height=img.height,
                    **mm_kwargs,
                ) for img in mm_data["images"]
            ]
            # +1 per image for the global (downscaled) image.
            image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
            for key in ("pixel_values", "pixel_attention_mask"):
                data = processed_outputs.pop(key)
                # Flatten (batch, num_images, ...) then split per image.
                data = data.flatten(0, 1).split(image_patches)
                processed_outputs[key] = data
        else:
            # Text-only prompt: plain tokenization, no image processing.
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=True,
                                          return_tensors="pt")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # All image fields are batched along the image dimension.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Expand each image token into the per-tile placeholder grid
        followed by the global image placeholder."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token = hf_processor.image_token.content
        fake_image_token = hf_processor.fake_image_token.content
        global_img_token = hf_processor.global_image_tag
        image_seq_len = hf_processor.image_seq_len
        grid_placeholder = "<row_{n_h}_col_{n_w}>"

        # `image_seq_len` image tokens per tile / global image.
        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img
        tile_img_placeholder = fake_image_token + grid_placeholder + p_img

        def get_replacement_idefics3(item_idx: int) -> str:
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            grid_w, grid_h = self.info._get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
                **hf_processor_mm_kwargs,
            )
            if grid_w == 0 and grid_h == 0:
                # Image fits in one tile: only the global image is used.
                image_placeholder = global_img_placeholder
            else:
                tiles_placeholder = list[str]()
                for i in range(grid_h):
                    for j in range(grid_w):
                        placeholder_per_tile = tile_img_placeholder.format(
                            n_h=i + 1, n_w=j + 1)
                        tiles_placeholder.append(placeholder_per_tile)
                        # Add line break if it is the last tile in the row
                        if j == grid_w - 1:
                            tiles_placeholder.append("\n")

                image_placeholder = "".join(
                    [*tiles_placeholder, "\n", global_img_placeholder])
            return image_placeholder + fake_image_token

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_idefics3,
            )
        ]
327
+
328
+
329
class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear projection from pixel-shuffled vision features to
    the text-model hidden size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Pixel shuffle packs scale_factor**2 patches into each position,
        # multiplying the feature dimension accordingly.
        in_features = config.vision_config.hidden_size * (
            config.scale_factor**2)
        out_features = config.text_config.hidden_size
        self.proj = ReplicatedLinear(
            in_features,
            out_features,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected, _ = self.proj(x)
        return projected
352
+
353
+
354
class Idefics3Connector(nn.Module):
    """Bridges vision features to the text embedding space: a pixel shuffle
    that trades sequence length for channel depth, followed by a linear
    modality projection."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self,
                      x: torch.Tensor,
                      scale_factor: int = 2) -> torch.Tensor:
        """Reduce sequence length by ``scale_factor**2`` while multiplying
        the embedding dim by the same factor.

        Assumes ``seq`` is a perfect square (square patch grid) divisible
        by ``scale_factor`` along each side — TODO confirm upstream
        guarantees this.
        """
        bsz, seq, embed_dim = x.size()
        # Recover the (height, width) patch grid from the flat sequence.
        height = width = int(seq**0.5)
        x = x.view(bsz, height, width, embed_dim)
        # Fold `scale_factor` columns into the channel dimension.
        x = x.view(bsz, height, int(width / scale_factor),
                   embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        # Fold `scale_factor` rows into the channel dimension as well.
        x = x.reshape(
            bsz,
            int(width / scale_factor),
            int(height / scale_factor),
            embed_dim * (scale_factor**2),
        )
        x = x.permute(0, 2, 1, 3)
        # Flatten back to (batch, reduced_seq, expanded_embed).
        x = x.reshape(bsz, int(seq / (scale_factor**2)),
                      embed_dim * (scale_factor**2))
        return x

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        image_hidden_states = self.pixel_shuffle(image_hidden_states,
                                                 self.scale_factor)
        image_hidden_states = self.modality_projection(image_hidden_states)
        return image_hidden_states
395
+
396
+
397
class Idefics3Model(nn.Module):
    """Core Idefics3 model: SigLIP-style vision tower + connector + Llama
    text backbone."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size
        self.vision_model = Idefics3VisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"))
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Tokens per image tile: (patches per side)^2 reduced by the
        # connector's pixel shuffle (scale_factor^2).
        self.image_seq_len = int(
            ((config.vision_config.image_size //
              config.vision_config.patch_size)**2) / (config.scale_factor**2))
        self.image_token_id = self.config.image_token_id

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every per-image tensor is (num_patches, 3, H, W) with
        H == W == the configured vision image size; raises ValueError
        otherwise."""

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ImageInputs]:
        """Extract image inputs from the forward kwargs.

        Returns None for text-only requests; prefers precomputed embeddings
        over raw pixels when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if isinstance(pixel_values, list):
                # Concatenate variable-length patch lists along dim 1.
                pixel_values = torch.cat(pixel_values, dim=1)
                pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
            else:
                pixel_values = flatten_bn(pixel_values)
                pixel_attention_mask = flatten_bn(pixel_attention_mask)

            return Idefics3ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                pixel_attention_mask=pixel_attention_mask)

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
    ) -> NestedTensors:
        """Encode pixel patches with the vision tower, dropping all-zero
        padding images, and return features re-split per image."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        num_patches = [x.size(0) for x in pixel_values]
        pixel_values = pixel_values.to(
            dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
        )  # fp16 compatibility

        # Remove padding images - padding images are full 0.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(
            dim=(-1, -2, -3)) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        # Handle the vision attention mask
        if pixel_attention_mask is None:
            pixel_attention_mask = torch.ones(
                size=(pixel_values.size(0), pixel_values.size(2),
                      pixel_values.size(3)),
                dtype=torch.bool,
                device=pixel_values.device,
            )
        else:
            # Remove padding images from the mask
            pixel_attention_mask = pixel_attention_mask[
                real_images_inds].contiguous()

        # Pool the pixel mask down to patch granularity: a patch is valid
        # if any of its pixels are.
        patch_size = self.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(dimension=1,
                                                      size=patch_size,
                                                      step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2,
                                                 size=patch_size,
                                                 step=patch_size)
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        image_hidden_states = self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

        return image_hidden_states.split(num_patches)

    def _process_image_pixels(
            self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
        assert self.vision_model is not None

        pixel_values = inputs["data"]
        pixel_attention_mask = inputs["pixel_attention_mask"]

        return self._image_pixels_to_features(pixel_values,
                                              pixel_attention_mask)

    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
        """Turn parsed image input into per-image embedding tensors; raw
        pixels go through the vision tower and connector, precomputed
        embeddings pass through unchanged."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)
        num_patches = [x.size(0) for x in image_features]
        image_features = torch.cat(image_features)
        return self.connector(image_features).split(num_patches)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.text_model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Vision features are merged into inputs_embeds by the caller; the
        # forward pass here is text-model only.
        hidden_states = self.text_model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states
579
+
580
+
581
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultimodalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA):
    """Idefics3 for conditional generation: wraps Idefics3Model with an LM
    head, logits processing and sampling; supports LoRA."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision_model
        "fc1",
        "fc2",
        "out_proj",
        # text_model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.model = Idefics3Model(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.image_token_id = self.config.image_token_id

        self.lm_head = ParallelLMHead(
            config.text_config.vocab_size,
            config.text_config.hidden_size,
            quant_config=quant_config,
        )
        # Share input/output embedding weights when the config asks for it.
        if self.config.text_config.tie_word_embeddings:
            self.lm_head.weight = self.model.text_model.wte.weight
        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
        self.sampler = get_sampler()

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        # Returns None for text-only requests.
        image_input = self.model._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self.model._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Text embeddings with vision embeddings spliced in at image-token
        positions."""
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model.text_model(input_ids,
                                              positions,
                                              kv_caches,
                                              attn_metadata,
                                              intermediate_tensors,
                                              inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model")
.venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import partial
4
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from transformers import PretrainedConfig
9
+
10
+ from vllm.attention import Attention, AttentionMetadata
11
+ from vllm.compilation.decorators import support_torch_compile
12
+ from vllm.config import CacheConfig, VllmConfig
13
+ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
14
+ get_tensor_model_parallel_world_size,
15
+ split_tensor_along_last_dim,
16
+ tensor_model_parallel_all_gather)
17
+ from vllm.model_executor.layers.activation import SiluAndMul
18
+ from vllm.model_executor.layers.layernorm import RMSNorm
19
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
20
+ QKVParallelLinear,
21
+ RowParallelLinear)
22
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
+ from vllm.model_executor.layers.pooler import Pooler, PoolingType
24
+ from vllm.model_executor.layers.quantization import QuantizationConfig
25
+ from vllm.model_executor.layers.rotary_embedding import get_rope
26
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
28
+ ParallelLMHead, VocabParallelEmbedding)
29
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30
+ from vllm.model_executor.pooling_metadata import PoolingMetadata
31
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
32
+ from vllm.sequence import IntermediateTensors, PoolerOutput
33
+
34
+ from .interfaces import SupportsLoRA, SupportsPP
35
+ from .utils import (is_pp_missing_parameter,
36
+ make_empty_intermediate_tensors_factory, make_layers,
37
+ maybe_prefix)
38
+
39
+
40
class InternLM2MLP(nn.Module):
    """Gated MLP block for InternLM2: fused gate/up projection, SiLU-and-mul
    activation, then down projection (``w2``)."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # w1 (gate) and w3 (up) are fused into a single column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.w2(activated)
        return out
75
+
76
+
77
class InternLM2Attention(nn.Module):
    """Self-attention for InternLM2 with a fused QKV projection (``wqkv``)
    and tensor-parallel support, including KV-head replication when the TP
    size exceeds the number of KV heads."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Number of query heads sharing each KV head (GQA group size).
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Fused Q/K/V projection; output is interleaved per KV-head group
        # (see split_qkv below for how it is unpacked).
        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Unpack the fused QKV output into separate q, k, v tensors.

        Under TP, the per-rank QKV shards are first all-gathered and
        regrouped so that q, k and v become contiguous, then re-split along
        the last dim so each rank keeps only its own partition.
        """
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            # Each rank contributes a [q | k | v] slice; gather them all,
            # then regroup so all q slices come first, then k, then v.
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        # Layout per KV head: key_value_groups query heads, then 1 key head,
        # then 1 value head — hence the (groups + 2) axis.
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            # Keep only this rank's partition of each tensor.
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project, apply rotary embeddings, attend, and project back."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output
187
+
188
+
189
class InternLMDecoderLayer(nn.Module):
    """A single InternLM2 transformer layer: pre-norm attention followed by
    a pre-norm feed-forward block, with fused residual handling via RMSNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one decoder layer; returns (hidden_states, residual).

        ``residual`` is None only for the first layer; afterwards RMSNorm is
        called in its fused add-and-norm form, which both normalizes and
        updates the residual stream.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
252
+
253
+
254
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 decoder stack: token embeddings, N decoder layers, and a
    final RMSNorm. Supports pipeline parallelism via make_layers /
    IntermediateTensors."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # Only the layers assigned to this pipeline stage are instantiated.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the (vocab-parallel) embedding table."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's layers; returns hidden states on the last PP
        rank and IntermediateTensors on any earlier rank."""
        if get_pp_group().is_first_rank:
            # First stage: start from embeddings (precomputed or from ids).
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later stages: resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed per-stage, hence the offset.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
324
+
325
+
326
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 causal LM head on top of InternLM2Model, with LoRA and
    pipeline-parallel support."""

    # Maps fused checkpoint modules to their constituent shard names.
    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "wqkv",
        "wo",
        "gate_up_proj",
        "w2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: Type[InternLM2Model] = InternLM2Model):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        # Share the LM-head weight with the input embeddings when configured.
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the underlying model's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward through the decoder stack; see InternLM2Model.forward."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping w1/w3 shards into the fused
        gate_up_proj parameter; returns the set of parameter names loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Rotary inv_freq buffers are recomputed, never loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip params that live on another pipeline stage.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked parameter: load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
440
+
441
+
442
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """Reward-model variant of InternLM2: replaces the LM head with a scalar
    value head (``v_head``) and a pooler, removing the sampling machinery."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # A reward model does not generate tokens; drop the LM-head stack.
        for attr in ("output", "logits_processor", "sampler"):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        # Scalar value head projecting each hidden state to a reward score.
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        pooler_config = vllm_config.model_config.pooler_config
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and map hidden states to per-token scores."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        logits, _ = self.v_head(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool per-token scores into the final reward output."""
        return self._pooler(hidden_states, pooling_metadata)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/internvl.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
4
+ # --------------------------------------------------------
5
+ # InternVL
6
+ # Copyright (c) 2023 OpenGVLab
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # --------------------------------------------------------
9
+ from abc import ABC, abstractmethod
10
+ from functools import cached_property
11
+ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
12
+ TypedDict, TypeVar, Union)
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torchvision.transforms as T
17
+ from PIL import Image
18
+ from transformers import BatchFeature, PretrainedConfig, TensorType
19
+
20
+ from vllm.attention import AttentionMetadata
21
+ from vllm.config import VllmConfig
22
+ from vllm.model_executor.layers.quantization import QuantizationConfig
23
+ from vllm.model_executor.layers.quantization.awq import AWQConfig
24
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
25
+ from vllm.model_executor.models.intern_vit import (InternVisionModel,
26
+ InternVisionPatchModel)
27
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
28
+ from vllm.multimodal import MULTIMODAL_REGISTRY
29
+ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
30
+ NestedTensors)
31
+ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
+ ImageSize, MultiModalDataItems)
33
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
+ BaseProcessingInfo, PromptReplacement,
35
+ PromptReplacementDetails)
36
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
+ from vllm.sequence import IntermediateTensors
38
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
39
+
40
+ from .interfaces import SupportsMultiModal, SupportsPP
41
+ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
42
+ maybe_prefix, merge_multimodal_embeddings)
43
+
44
+ IMG_START = '<img>'
45
+ IMG_END = '</img>'
46
+ IMG_CONTEXT = '<IMG_CONTEXT>'
47
+
48
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
49
+ IMAGENET_STD = (0.229, 0.224, 0.225)
50
+
51
+
52
+ class InternVLImagePixelInputs(TypedDict):
53
+ type: Literal["pixel_values"]
54
+ data: torch.Tensor
55
+ """
56
+ Shape:
57
+ `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
58
+ """
59
+ patches_per_image: List[int]
60
+ """
61
+ List of number of total patches for each image in the batch.
62
+ """
63
+
64
+
65
+ class InternVLImageEmbeddingInputs(TypedDict):
66
+ type: Literal["image_embeds"]
67
+ data: NestedTensors
68
+ """
69
+ A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
70
+ or a list of tensors of shape `(total_image_feature_size, hidden_size)`
71
+
72
+ `hidden_size` must match the hidden size of language model backbone.
73
+ """
74
+
75
+
76
+ InternVLImageInputs = Union[InternVLImagePixelInputs,
77
+ InternVLImageEmbeddingInputs]
78
+
79
+
80
+ # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
81
+ def build_transform(input_size: int):
82
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
83
+ return T.Compose([
84
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
85
+ T.Resize((input_size, input_size),
86
+ interpolation=T.InterpolationMode.BICUBIC),
87
+ T.ToTensor(),
88
+ T.Normalize(mean=MEAN, std=STD)
89
+ ])
90
+
91
+
92
+ # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
93
+ def find_closest_aspect_ratio(
94
+ aspect_ratio: float,
95
+ target_ratios: list[tuple[int, int]],
96
+ *,
97
+ width: int,
98
+ height: int,
99
+ image_size: int,
100
+ ) -> tuple[int, int]:
101
+ best_ratio_diff = float('inf')
102
+ best_ratio = (1, 1)
103
+ area = width * height
104
+ for ratio in target_ratios:
105
+ target_aspect_ratio = ratio[0] / ratio[1]
106
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
107
+ if ratio_diff < best_ratio_diff:
108
+ best_ratio_diff = ratio_diff
109
+ best_ratio = ratio
110
+ elif ratio_diff == best_ratio_diff:
111
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
112
+ best_ratio = ratio
113
+ return best_ratio
114
+
115
+
116
+ def resolve_internvl_min_max_num(
117
+ *,
118
+ min_dynamic_patch: int,
119
+ max_dynamic_patch: int,
120
+ dynamic_image_size: bool,
121
+ use_thumbnail: bool,
122
+ ) -> tuple[int, int]:
123
+ max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
124
+
125
+ if use_thumbnail and max_dynamic_patch != 1:
126
+ max_dynamic_patch += 1
127
+
128
+ return min_dynamic_patch, max_dynamic_patch
129
+
130
+
131
+ def get_internvl_target_ratios(
132
+ min_num: int,
133
+ max_num: int,
134
+ ) -> list[tuple[int, int]]:
135
+ target_ratios = {(i, j)
136
+ for n in range(min_num, max_num + 1)
137
+ for i in range(1, n + 1)
138
+ for j in range(1, n + 1) if min_num <= i * j <= max_num}
139
+ return sorted(target_ratios, key=lambda x: x[0] * x[1])
140
+
141
+
142
+ def calculate_internvl_targets(
143
+ *,
144
+ orig_width: int,
145
+ orig_height: int,
146
+ target_ratios: list[tuple[int, int]],
147
+ image_size: int,
148
+ use_thumbnail: bool,
149
+ ) -> tuple[int, int, int]:
150
+ aspect_ratio = orig_width / orig_height
151
+
152
+ # find the closest aspect ratio to the target
153
+ target_aspect_ratio = find_closest_aspect_ratio(
154
+ aspect_ratio,
155
+ target_ratios,
156
+ width=orig_width,
157
+ height=orig_height,
158
+ image_size=image_size,
159
+ )
160
+
161
+ # calculate the target width and height
162
+ target_width = image_size * target_aspect_ratio[0]
163
+ target_height = image_size * target_aspect_ratio[1]
164
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
165
+
166
+ # add thumbnail image if num_blocks != 1
167
+ if use_thumbnail and blocks != 1:
168
+ blocks += 1
169
+
170
+ return blocks, target_width, target_height
171
+
172
+
173
+ # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
174
+ def dynamic_preprocess_internvl(
175
+ image: Image.Image,
176
+ *,
177
+ target_ratios: list[tuple[int, int]],
178
+ image_size: int,
179
+ use_thumbnail: bool,
180
+ ) -> list[Image.Image]:
181
+ orig_width, orig_height = image.size
182
+
183
+ # calculate the number of blocks without thumbnail
184
+ blocks, target_width, target_height = calculate_internvl_targets(
185
+ orig_width=orig_width,
186
+ orig_height=orig_height,
187
+ target_ratios=target_ratios,
188
+ image_size=image_size,
189
+ use_thumbnail=False,
190
+ )
191
+
192
+ # resize the image
193
+ resized_img = image.resize((target_width, target_height))
194
+ processed_images = []
195
+ for i in range(blocks):
196
+ box = ((i % (target_width // image_size)) * image_size,
197
+ (i // (target_width // image_size)) * image_size,
198
+ ((i % (target_width // image_size)) + 1) * image_size,
199
+ ((i // (target_width // image_size)) + 1) * image_size)
200
+ # split the image
201
+ split_img = resized_img.crop(box)
202
+ processed_images.append(split_img)
203
+
204
+ assert len(processed_images) == blocks
205
+
206
+ if use_thumbnail and len(processed_images) != 1:
207
+ thumbnail_img = image.resize((image_size, image_size))
208
+ processed_images.append(thumbnail_img)
209
+
210
+ return processed_images
211
+
212
+
213
+ # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
214
+ def image_to_pixel_values_internvl(
215
+ image: Image.Image,
216
+ *,
217
+ input_size: int,
218
+ min_num: int,
219
+ max_num: int,
220
+ use_thumbnail: bool,
221
+ ) -> torch.Tensor:
222
+ target_ratios = get_internvl_target_ratios(min_num, max_num)
223
+
224
+ transform = build_transform(input_size=input_size)
225
+ images = dynamic_preprocess_internvl(
226
+ image,
227
+ target_ratios=target_ratios,
228
+ image_size=input_size,
229
+ use_thumbnail=use_thumbnail,
230
+ )
231
+
232
+ pixel_values = torch.stack([transform(image) for image in images])
233
+ return pixel_values
234
+
235
+
236
+ class BaseInternVLProcessor(ABC):
237
+ """
238
+ This model doesn't define its own HF processor,
239
+ so we implement our own one here.
240
+
241
+ The code to insert image tokens is based on:
242
+ https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
243
+ """
244
+
245
+ def __init__(
246
+ self,
247
+ config: PretrainedConfig,
248
+ tokenizer: AnyTokenizer,
249
+ *,
250
+ max_dynamic_patch: Optional[int] = None,
251
+ dynamic_image_size: Optional[bool] = None,
252
+ ) -> None:
253
+ super().__init__()
254
+
255
+ self.config = config
256
+ self.tokenizer = tokenizer
257
+
258
+ image_size: int = config.vision_config.image_size
259
+ patch_size: int = config.vision_config.patch_size
260
+
261
+ if dynamic_image_size is None:
262
+ dynamic_image_size = config.dynamic_image_size
263
+ assert isinstance(dynamic_image_size, bool)
264
+
265
+ if max_dynamic_patch is None:
266
+ max_dynamic_patch = config.max_dynamic_patch
267
+ assert isinstance(max_dynamic_patch, int)
268
+
269
+ self.num_image_token = int(
270
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
271
+ self.image_size = image_size
272
+ self.min_dynamic_patch: int = config.min_dynamic_patch
273
+ self.max_dynamic_patch = max_dynamic_patch
274
+ self.dynamic_image_size = dynamic_image_size
275
+ self.use_thumbnail: bool = config.use_thumbnail
276
+
277
+ @property
278
+ @abstractmethod
279
+ def image_token_id(self) -> int:
280
+ raise NotImplementedError
281
+
282
+ @abstractmethod
283
+ def get_image_repl_features(
284
+ self,
285
+ feature_size: int,
286
+ num_patches: Optional[int],
287
+ ) -> str:
288
+ raise NotImplementedError
289
+
290
+ @abstractmethod
291
+ def get_image_repl_full(
292
+ self,
293
+ feature_size: int,
294
+ num_patches: Optional[int],
295
+ ) -> str:
296
+ raise NotImplementedError
297
+
298
+ def resolve_min_max_num(
299
+ self,
300
+ *,
301
+ max_dynamic_patch: Optional[int] = None,
302
+ dynamic_image_size: Optional[bool] = None,
303
+ use_thumbnail: Optional[bool] = None,
304
+ ) -> tuple[int, int]:
305
+ min_dynamic_patch = self.min_dynamic_patch
306
+ max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
307
+ is None else max_dynamic_patch)
308
+ dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
309
+ is None else dynamic_image_size)
310
+ use_thumbnail = (self.use_thumbnail
311
+ if use_thumbnail is None else use_thumbnail)
312
+
313
+ return resolve_internvl_min_max_num(
314
+ min_dynamic_patch=min_dynamic_patch,
315
+ max_dynamic_patch=max_dynamic_patch,
316
+ dynamic_image_size=dynamic_image_size,
317
+ use_thumbnail=use_thumbnail,
318
+ )
319
+
320
+ def resolve_target_ratios(
321
+ self,
322
+ *,
323
+ max_dynamic_patch: Optional[int] = None,
324
+ dynamic_image_size: Optional[bool] = None,
325
+ use_thumbnail: Optional[bool] = None,
326
+ ) -> list[tuple[int, int]]:
327
+ min_num, max_num = self.resolve_min_max_num(
328
+ max_dynamic_patch=max_dynamic_patch,
329
+ dynamic_image_size=dynamic_image_size,
330
+ use_thumbnail=use_thumbnail,
331
+ )
332
+
333
+ return get_internvl_target_ratios(min_num, max_num)
334
+
335
+ def get_num_image_tokens(
336
+ self,
337
+ *,
338
+ image_width: int,
339
+ image_height: int,
340
+ ) -> int:
341
+ target_ratios = self.resolve_target_ratios(
342
+ use_thumbnail=False, # Applied in calculate_targets
343
+ )
344
+
345
+ num_patches, _, _ = calculate_internvl_targets(
346
+ orig_width=image_width,
347
+ orig_height=image_height,
348
+ image_size=self.image_size,
349
+ target_ratios=target_ratios,
350
+ use_thumbnail=self.use_thumbnail,
351
+ )
352
+
353
+ return num_patches * self.num_image_token
354
+
355
+ def _images_to_pixel_values_lst(
356
+ self,
357
+ images: list[Image.Image],
358
+ max_dynamic_patch: Optional[int] = None,
359
+ dynamic_image_size: Optional[bool] = None,
360
+ ) -> list[torch.Tensor]:
361
+ min_num, max_num = self.resolve_min_max_num(
362
+ max_dynamic_patch=max_dynamic_patch,
363
+ dynamic_image_size=dynamic_image_size,
364
+ use_thumbnail=False, # Applied in image_to_pixel_values
365
+ )
366
+
367
+ return [
368
+ image_to_pixel_values_internvl(
369
+ image,
370
+ input_size=self.image_size,
371
+ min_num=min_num,
372
+ max_num=max_num,
373
+ use_thumbnail=self.use_thumbnail,
374
+ ) for image in images
375
+ ]
376
+
377
+ def __call__(
378
+ self,
379
+ text: Optional[Union[str, list[str]]] = None,
380
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
381
+ max_dynamic_patch: Optional[int] = None,
382
+ dynamic_image_size: Optional[bool] = None,
383
+ return_tensors: Optional[Union[str, TensorType]] = None,
384
+ ) -> BatchFeature:
385
+ if text is None:
386
+ text = []
387
+ if not isinstance(text, list):
388
+ text = [text]
389
+ if images is None:
390
+ images = []
391
+ if not isinstance(images, list):
392
+ images = [images]
393
+
394
+ if len(images) == 0:
395
+ image_inputs = {}
396
+ else:
397
+ pixel_values_lst = self._images_to_pixel_values_lst(
398
+ images,
399
+ max_dynamic_patch=max_dynamic_patch,
400
+ dynamic_image_size=dynamic_image_size,
401
+ )
402
+ image_inputs = {
403
+ "pixel_values_flat": torch.cat(pixel_values_lst),
404
+ "image_num_patches": list(map(len, pixel_values_lst)),
405
+ }
406
+
407
+ for pixel_values in pixel_values_lst:
408
+ num_patches = pixel_values.shape[0]
409
+ feature_size = num_patches * self.num_image_token
410
+
411
+ image_repl = self.get_image_repl_full(feature_size,
412
+ num_patches)
413
+ text = [t.replace('<image>', image_repl, 1) for t in text]
414
+
415
+ text_inputs = self.tokenizer(text)
416
+
417
+ return BatchFeature(
418
+ {
419
+ **text_inputs,
420
+ **image_inputs,
421
+ },
422
+ tensor_type=return_tensors,
423
+ )
424
+
425
+
426
+ class InternVLProcessor(BaseInternVLProcessor):
427
+
428
+ @property
429
+ def image_token_id(self) -> int:
430
+ return self.tokenizer.get_vocab()[IMG_CONTEXT]
431
+
432
+ def get_image_repl_features(
433
+ self,
434
+ feature_size: int,
435
+ num_patches: Optional[int],
436
+ ) -> str:
437
+ return IMG_CONTEXT * feature_size
438
+
439
+ def get_image_repl_full(
440
+ self,
441
+ feature_size: int,
442
+ num_patches: Optional[int],
443
+ ) -> str:
444
+ features = self.get_image_repl_features(feature_size, num_patches)
445
+ return IMG_START + features + IMG_END
446
+
447
+
448
+ class BaseInternVLProcessingInfo(BaseProcessingInfo):
449
+
450
+ @abstractmethod
451
+ def get_hf_processor(
452
+ self,
453
+ *,
454
+ max_dynamic_patch: Optional[int] = None,
455
+ dynamic_image_size: Optional[bool] = None,
456
+ ) -> BaseInternVLProcessor:
457
+ raise NotImplementedError
458
+
459
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
460
+ return {"image": None}
461
+
462
+ def get_mm_max_tokens_per_item(
463
+ self,
464
+ seq_len: int,
465
+ mm_counts: Mapping[str, int],
466
+ ) -> Mapping[str, int]:
467
+ return {"image": self.get_max_image_tokens()}
468
+
469
+ def get_num_image_tokens(
470
+ self,
471
+ *,
472
+ image_width: int,
473
+ image_height: int,
474
+ processor: Optional[BaseInternVLProcessor],
475
+ ) -> int:
476
+ if processor is None:
477
+ processor = self.get_hf_processor()
478
+
479
+ return processor.get_num_image_tokens(
480
+ image_width=image_width,
481
+ image_height=image_height,
482
+ )
483
+
484
+ def get_max_image_tokens(self) -> int:
485
+ target_width, target_height = self.get_image_size_with_most_features()
486
+
487
+ return self.get_num_image_tokens(
488
+ image_width=target_width,
489
+ image_height=target_height,
490
+ processor=None,
491
+ )
492
+
493
+ def get_image_size_with_most_features(self) -> ImageSize:
494
+ processor = self.get_hf_processor()
495
+
496
+ base_size = processor.image_size
497
+ target_ratios = processor.resolve_target_ratios()
498
+
499
+ largest_feature_size, largest_feature_pinpoint = 0, None
500
+ for wr, hr in target_ratios:
501
+ width, height = base_size * wr, base_size * hr
502
+
503
+ feat_size = self.get_num_image_tokens(
504
+ image_width=width,
505
+ image_height=height,
506
+ processor=processor,
507
+ )
508
+ if feat_size > largest_feature_size:
509
+ largest_feature_size = feat_size
510
+ largest_feature_pinpoint = ImageSize(width=width,
511
+ height=height)
512
+
513
+ if largest_feature_size == 0 or largest_feature_pinpoint is None:
514
+ raise ValueError("Cannot have a largest feature size of 0!")
515
+
516
+ return largest_feature_pinpoint
517
+
518
+
519
+ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
520
+
521
+
522
+ class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
523
+
524
+ def get_dummy_processor_inputs(
525
+ self,
526
+ seq_len: int,
527
+ mm_counts: Mapping[str, int],
528
+ ) -> ProcessorInputs:
529
+ target_width, target_height = \
530
+ self.info.get_image_size_with_most_features()
531
+ num_images = mm_counts.get("image", 0)
532
+
533
+ mm_data = {
534
+ "image":
535
+ self._get_dummy_images(width=target_width,
536
+ height=target_height,
537
+ num_images=num_images)
538
+ }
539
+
540
+ return ProcessorInputs(
541
+ prompt_text="<image>" * num_images,
542
+ mm_data=mm_data,
543
+ )
544
+
545
+
546
+ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
547
+
548
+ def _call_hf_processor(
549
+ self,
550
+ prompt: str,
551
+ mm_data: Mapping[str, object],
552
+ mm_kwargs: Mapping[str, object],
553
+ ) -> BatchFeature:
554
+ processed_outputs = super()._call_hf_processor(
555
+ prompt=prompt,
556
+ mm_data=mm_data,
557
+ mm_kwargs=mm_kwargs,
558
+ )
559
+
560
+ image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
561
+ image_data = mm_data.get("images", [])
562
+ assert isinstance(image_data, list)
563
+
564
+ # Since there may be extra tokens in the feature placeholders,
565
+ # we need to pass the image token ID to the model to select the
566
+ # tokens to merge from the vision encoder outputs
567
+ processed_outputs["image_token_id"] = torch.tensor(image_token_id)
568
+
569
+ return processed_outputs
570
+
571
+ def _get_mm_fields_config(
572
+ self,
573
+ hf_inputs: BatchFeature,
574
+ hf_processor_mm_kwargs: Mapping[str, object],
575
+ ) -> Mapping[str, MultiModalFieldConfig]:
576
+ image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
577
+ num_images = len(image_num_patches)
578
+
579
+ return dict(
580
+ pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
581
+ "image", image_num_patches),
582
+ image_num_patches=MultiModalFieldConfig.batched("image"),
583
+ image_embeds=MultiModalFieldConfig.batched("image"),
584
+ image_token_id=MultiModalFieldConfig.shared("image", num_images),
585
+ )
586
+
587
+ def _get_prompt_replacements(
588
+ self,
589
+ mm_items: MultiModalDataItems,
590
+ hf_processor_mm_kwargs: Mapping[str, object],
591
+ out_mm_kwargs: MultiModalKwargs,
592
+ ) -> list[PromptReplacement]:
593
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
594
+
595
+ if "image_num_patches" in out_mm_kwargs:
596
+ image_num_patches = out_mm_kwargs["image_num_patches"]
597
+ assert isinstance(image_num_patches, torch.Tensor)
598
+ image_num_patches = image_num_patches.tolist()
599
+ elif "image_embeds" in out_mm_kwargs:
600
+ # TODO: Use image size information in dictionary embedding inputs
601
+ # to compute num_patches (similar to Qwen2-VL)
602
+ image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
603
+ else:
604
+ image_num_patches = []
605
+
606
+ def get_replacement_internvl(item_idx: int):
607
+ images = mm_items.get_items(
608
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
609
+
610
+ if isinstance(images, ImageEmbeddingItems):
611
+ feature_size = images.get_feature_size(item_idx)
612
+ else:
613
+ image_size = images.get_image_size(item_idx)
614
+ feature_size = self.info.get_num_image_tokens(
615
+ image_width=image_size.width,
616
+ image_height=image_size.height,
617
+ processor=hf_processor,
618
+ )
619
+
620
+ num_patches = image_num_patches[item_idx]
621
+ if num_patches is not None:
622
+ assert isinstance(num_patches, int)
623
+
624
+ return PromptReplacementDetails(
625
+ full=hf_processor.get_image_repl_full(feature_size,
626
+ num_patches),
627
+ features=hf_processor.get_image_repl_features(
628
+ feature_size, num_patches),
629
+ )
630
+
631
+ return [
632
+ PromptReplacement(
633
+ modality="image",
634
+ target="<image>",
635
+ replacement=get_replacement_internvl,
636
+ )
637
+ ]
638
+
639
+
640
+ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
641
+
642
+ def get_hf_processor(
643
+ self,
644
+ *,
645
+ max_dynamic_patch: Optional[int] = None,
646
+ dynamic_image_size: Optional[bool] = None,
647
+ ) -> InternVLProcessor:
648
+ return InternVLProcessor(
649
+ self.get_hf_config(),
650
+ self.get_tokenizer(),
651
+ max_dynamic_patch=max_dynamic_patch,
652
+ dynamic_image_size=dynamic_image_size,
653
+ )
654
+
655
+
656
+ @MULTIMODAL_REGISTRY.register_processor(
657
+ InternVLMultiModalProcessor,
658
+ info=InternVLProcessingInfo,
659
+ dummy_inputs=InternVLDummyInputsBuilder)
660
+ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
661
+
662
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
663
+ super().__init__()
664
+
665
+ config = vllm_config.model_config.hf_config
666
+ quant_config = vllm_config.quant_config
667
+ multimodal_config = vllm_config.model_config.multimodal_config
668
+
669
+ self.config = config
670
+ self.multimodal_config = multimodal_config
671
+ self._patch_quant_config(config, quant_config)
672
+
673
+ image_size = config.force_image_size or config.vision_config.image_size
674
+ patch_size = config.vision_config.patch_size
675
+ self.patch_size = patch_size
676
+ self.num_image_token = int(
677
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
678
+ self.downsample_ratio = config.downsample_ratio
679
+ self.ps_version = config.ps_version
680
+
681
+ self.llm_arch_name = config.text_config.architectures[0]
682
+ self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
683
+ self.vision_model = self._init_vision_model(
684
+ config,
685
+ quant_config=quant_config,
686
+ is_mono=self.is_mono,
687
+ prefix=maybe_prefix(prefix, "vision_model"),
688
+ )
689
+
690
+ self.language_model = init_vllm_registered_model(
691
+ vllm_config=vllm_config,
692
+ hf_config=config.text_config,
693
+ prefix=maybe_prefix(prefix, "language_model"),
694
+ )
695
+
696
+ self.mlp1 = self._init_mlp1(config)
697
+
698
+ self.img_context_token_id = None
699
+ self.visual_token_mask = None
700
+ self.make_empty_intermediate_tensors = (
701
+ self.language_model.make_empty_intermediate_tensors)
702
+
703
+ def _patch_quant_config(self, config: PretrainedConfig,
704
+ quant_config: QuantizationConfig):
705
+ # the awq models from OpenGVLab missing `modules_to_not_convert`
706
+ # patch the quant_config to add `modules_to_not_convert` back
707
+ if isinstance(quant_config, AWQConfig):
708
+ text_config = config.text_config
709
+ llm_quant_config = getattr(text_config, "quantization_config",
710
+ None)
711
+ if (not quant_config.modules_to_not_convert) and \
712
+ (llm_quant_config is not None):
713
+ quant_config.modules_to_not_convert.append("vision_model")
714
+
715
+ @cached_property
716
+ def sampler(self):
717
+ if hasattr(self.language_model, "sampler"):
718
+ return self.language_model.sampler
719
+
720
+ return get_sampler()
721
+
722
+ def _init_vision_model(
723
+ self,
724
+ config: PretrainedConfig,
725
+ quant_config: Optional[QuantizationConfig],
726
+ *,
727
+ is_mono: bool,
728
+ prefix: str,
729
+ ):
730
+ if not is_mono:
731
+ vision_feature_layer = config.select_layer
732
+ if vision_feature_layer < 0:
733
+ num_hidden_layers = config.vision_config.num_hidden_layers \
734
+ + vision_feature_layer + 1
735
+ else:
736
+ num_hidden_layers = vision_feature_layer + 1
737
+
738
+ return InternVisionModel(
739
+ config.vision_config,
740
+ quant_config=quant_config,
741
+ num_hidden_layers_override=num_hidden_layers,
742
+ prefix=prefix,
743
+ )
744
+ else:
745
+ return InternVisionPatchModel(config.vision_config)
746
+
747
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
748
+ vit_hidden_size = config.vision_config.hidden_size
749
+ llm_hidden_size = config.text_config.hidden_size
750
+
751
+ return nn.Sequential(
752
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
753
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
754
+ llm_hidden_size),
755
+ nn.GELU(),
756
+ nn.Linear(llm_hidden_size, llm_hidden_size),
757
+ )
758
+
759
+ def pixel_shuffle(self, x, scale_factor=0.5):
760
+ n, w, h, c = x.size()
761
+ # N, W, H, C --> N, W, H * scale, C // scale
762
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
763
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
764
+ x = x.permute(0, 2, 1, 3).contiguous()
765
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
766
+ int(c / (scale_factor * scale_factor)))
767
+ if self.ps_version == 'v1':
768
+ pass
769
+ else:
770
+ x = x.permute(0, 2, 1, 3).contiguous()
771
+ return x
772
+
773
+ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
774
+ vit_embeds = self.vision_model(pixel_values=pixel_values)
775
+ vit_embeds = vit_embeds[:, 1:, :]
776
+
777
+ h = w = int(vit_embeds.shape[1]**0.5)
778
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
779
+ vit_embeds = self.pixel_shuffle(vit_embeds,
780
+ scale_factor=self.downsample_ratio)
781
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
782
+ vit_embeds.shape[-1])
783
+ vit_embeds = self.mlp1(vit_embeds)
784
+ return vit_embeds
785
+
786
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
787
+
788
+ h = w = self.config.vision_config.image_size
789
+ expected_dims = (3, h, w)
790
+
791
+ def _validate_shape(d: torch.Tensor):
792
+ actual_dims = tuple(d.shape)
793
+
794
+ if actual_dims != expected_dims:
795
+ expected_expr = str(expected_dims)
796
+ raise ValueError(
797
+ "The expected shape of pixel values per image per batch "
798
+ f" per patch is {expected_expr}. "
799
+ f"You supplied {tuple(d.shape)}.")
800
+
801
+ for d in data:
802
+ _validate_shape(d)
803
+
804
+ return data
805
+
806
+ def _parse_and_validate_image_input(
807
+ self, **kwargs: object) -> Optional[InternVLImageInputs]:
808
+ pixel_values_flat = kwargs.pop("pixel_values_flat", None)
809
+ image_num_patches = kwargs.pop("image_num_patches", None)
810
+ image_embeds = kwargs.pop("image_embeds", None)
811
+
812
+ if pixel_values_flat is None and image_embeds is None:
813
+ return None
814
+
815
+ if image_embeds is not None:
816
+ if not isinstance(image_embeds, torch.Tensor):
817
+ raise ValueError("Incorrect type of image embeddings. "
818
+ f"Got type: {type(image_embeds)}")
819
+
820
+ return InternVLImageEmbeddingInputs(
821
+ type="image_embeds",
822
+ data=flatten_bn(image_embeds),
823
+ )
824
+
825
+ image_token_id = kwargs["image_token_id"]
826
+ assert isinstance(image_token_id, torch.Tensor)
827
+ self.img_context_token_id = image_token_id.flatten().unique().item()
828
+
829
+ if pixel_values_flat is not None:
830
+ if not isinstance(pixel_values_flat, (torch.Tensor, list)):
831
+ raise ValueError("Incorrect type of pixel values. "
832
+ f"Got type: {type(pixel_values_flat)}")
833
+
834
+ assert isinstance(image_num_patches, (torch.Tensor, list))
835
+
836
+ return InternVLImagePixelInputs(
837
+ type="pixel_values",
838
+ data=self._validate_pixel_values(
839
+ flatten_bn(pixel_values_flat, concat=True)),
840
+ patches_per_image=flatten_bn(image_num_patches,
841
+ concat=True).tolist())
842
+
843
+ raise AssertionError("This line should be unreachable.")
844
+
845
+ def _process_image_input(
846
+ self,
847
+ image_input: InternVLImageInputs,
848
+ ) -> tuple[torch.Tensor, ...]:
849
+ if image_input["type"] == "image_embeds":
850
+ return image_input["data"]
851
+
852
+ assert self.vision_model is not None
853
+
854
+ image_embeds = self.extract_feature(image_input["data"])
855
+
856
+ patches_per_image = image_input["patches_per_image"]
857
+
858
+ # Only one image in the current batch
859
+ if len(patches_per_image) == 1:
860
+ image_embeds = image_embeds.view(
861
+ -1, self.config.text_config.hidden_size).unsqueeze(0)
862
+ return image_embeds
863
+
864
+ # NOTE: Image embeddings are split into separate tensors for each image
865
+ # by the size of each embedding.
866
+ feature_size = image_embeds.shape[1]
867
+ image_embeds = image_embeds.view(-1,
868
+ self.config.text_config.hidden_size)
869
+ image_feature_sizes = [
870
+ num_patches * feature_size for num_patches in patches_per_image
871
+ ]
872
+ image_embeds = image_embeds.split(image_feature_sizes)
873
+ return image_embeds
874
+
875
+ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
876
+ if self.is_mono:
877
+ self.visual_token_mask = (
878
+ input_ids == self.img_context_token_id).reshape(-1, 1)
879
+ else:
880
+ self.visual_token_mask = None
881
+
882
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
883
+ image_input = self._parse_and_validate_image_input(**kwargs)
884
+ if image_input is None:
885
+ return None
886
+ vision_embeddings = self._process_image_input(image_input)
887
+ return vision_embeddings
888
+
889
+ def get_input_embeddings(
890
+ self,
891
+ input_ids: torch.Tensor,
892
+ multimodal_embeddings: Optional[NestedTensors] = None,
893
+ ) -> torch.Tensor:
894
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
895
+ if multimodal_embeddings is not None:
896
+ assert self.img_context_token_id is not None
897
+ self._set_visual_token_mask(input_ids)
898
+ inputs_embeds = merge_multimodal_embeddings(
899
+ input_ids, inputs_embeds, multimodal_embeddings,
900
+ self.img_context_token_id)
901
+ return inputs_embeds
902
+
903
+ def forward(
904
+ self,
905
+ input_ids: torch.Tensor,
906
+ positions: torch.Tensor,
907
+ kv_caches: List[torch.Tensor],
908
+ attn_metadata: AttentionMetadata,
909
+ intermediate_tensors: Optional[IntermediateTensors] = None,
910
+ inputs_embeds: Optional[torch.Tensor] = None,
911
+ **kwargs: object,
912
+ ) -> Union[SamplerOutput, IntermediateTensors]:
913
+
914
+ if intermediate_tensors is not None:
915
+ input_ids = None
916
+ inputs_embeds = None
917
+
918
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
919
+ # condition is for v0 compatibility.
920
+ elif inputs_embeds is None:
921
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
922
+ inputs_embeds = self.get_input_embeddings(input_ids,
923
+ vision_embeddings)
924
+ input_ids = None
925
+
926
+ forward_kwargs = {
927
+ "input_ids": input_ids,
928
+ "positions": positions,
929
+ "kv_caches": kv_caches,
930
+ "attn_metadata": attn_metadata,
931
+ "intermediate_tensors": intermediate_tensors,
932
+ "inputs_embeds": inputs_embeds,
933
+ }
934
+
935
+ # Only required if the model is mono-architecture
936
+ if self.visual_token_mask is not None:
937
+ forward_kwargs.update(
938
+ {"visual_token_mask": self.visual_token_mask})
939
+ self.visual_token_mask = None
940
+
941
+ hidden_states = self.language_model.model(**forward_kwargs)
942
+ return hidden_states
943
+
944
+ def compute_logits(
945
+ self,
946
+ hidden_states: torch.Tensor,
947
+ sampling_metadata: SamplingMetadata,
948
+ ) -> Optional[torch.Tensor]:
949
+ return self.language_model.compute_logits(hidden_states,
950
+ sampling_metadata)
951
+
952
+ def sample(
953
+ self,
954
+ logits: torch.Tensor,
955
+ sampling_metadata: SamplingMetadata,
956
+ ) -> Optional[SamplerOutput]:
957
+ return self.language_model.sample(logits, sampling_metadata)
958
+
959
+ def load_weights(self, weights: Iterable[Tuple[str,
960
+ torch.Tensor]]) -> Set[str]:
961
+ loader = AutoWeightsLoader(self)
962
+ return loader.load_weights(weights)
.venv/lib/python3.11/site-packages/vllm/model_executor/models/jamba.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Inference-only Jamba model."""
3
+ from typing import Iterable, List, Optional, Set, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ from transformers import JambaConfig
8
+
9
+ from vllm.attention.backends.abstract import AttentionMetadata
10
+ from vllm.attention.layer import Attention
11
+ from vllm.config import CacheConfig, VllmConfig
12
+ from vllm.distributed import get_tensor_model_parallel_world_size
13
+ from vllm.distributed.parallel_state import get_pp_group
14
+ from vllm.model_executor.layers.fused_moe import FusedMoE
15
+ from vllm.model_executor.layers.layernorm import RMSNorm
16
+ from vllm.model_executor.layers.linear import (QKVParallelLinear,
17
+ ReplicatedLinear,
18
+ RowParallelLinear)
19
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
+ from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
21
+ from vllm.model_executor.layers.pooler import Pooler, PoolingType
22
+ from vllm.model_executor.layers.quantization import QuantizationConfig
23
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
25
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
26
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
27
+ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
28
+ MambaCacheParams)
29
+ from vllm.model_executor.pooling_metadata import PoolingMetadata
30
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
31
+ from vllm.sequence import IntermediateTensors, PoolerOutput
32
+ from vllm.utils import LayerBlockType
33
+
34
+ from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
35
+ from .utils import (is_pp_missing_parameter,
36
+ make_empty_intermediate_tensors_factory, make_layers,
37
+ maybe_prefix)
38
+
39
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
40
+
41
+
42
+ class JambaMoE(nn.Module):
43
+
44
+ def __init__(self,
45
+ config: JambaConfig,
46
+ num_experts: Optional[int] = None,
47
+ top_k: Optional[int] = None,
48
+ params_dtype: Optional[torch.dtype] = None,
49
+ tp_size: Optional[int] = None,
50
+ quant_config: Optional[QuantizationConfig] = None):
51
+ super().__init__()
52
+ self.num_total_experts = num_experts or config.num_experts
53
+ self.top_k = top_k or config.num_experts_per_tok
54
+ self.hidden_size = config.hidden_size
55
+ self.intermediate_size = config.intermediate_size
56
+
57
+ if self.num_total_experts > 1:
58
+ self.router = ReplicatedLinear(self.hidden_size,
59
+ self.num_total_experts,
60
+ bias=False,
61
+ quant_config=None,
62
+ params_dtype=params_dtype)
63
+
64
+ self.experts = FusedMoE(self.num_total_experts,
65
+ self.top_k,
66
+ self.hidden_size,
67
+ self.intermediate_size,
68
+ tp_size=tp_size,
69
+ params_dtype=params_dtype,
70
+ reduce_results=True,
71
+ renormalize=False,
72
+ use_grouped_topk=False,
73
+ quant_config=quant_config)
74
+
75
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
76
+ orig_shape = hidden_states.shape
77
+ hidden_states = hidden_states.view(-1, self.hidden_size)
78
+ # router_logits: (batch * sequence_length, n_experts)
79
+ if self.num_total_experts > 1:
80
+ router_logits, _ = self.router(hidden_states)
81
+ else:
82
+ router_logits = torch.ones((hidden_states.shape[0], 1),
83
+ device=hidden_states.device,
84
+ dtype=hidden_states.dtype)
85
+ hidden_states = self.experts(hidden_states, router_logits)
86
+ return hidden_states.view(orig_shape)
87
+
88
+
89
+ class JambaMLP(JambaMoE):
90
+
91
+ def __init__(self,
92
+ config: JambaConfig,
93
+ params_dtype: Optional[torch.dtype] = None,
94
+ tp_size: Optional[int] = None,
95
+ quant_config: Optional[QuantizationConfig] = None):
96
+ super().__init__(config,
97
+ num_experts=1,
98
+ top_k=1,
99
+ params_dtype=params_dtype,
100
+ tp_size=tp_size,
101
+ quant_config=quant_config)
102
+
103
+
104
+ class JambaMambaDecoderLayer(nn.Module):
105
+
106
+ def __init__(self,
107
+ config: JambaConfig,
108
+ layer_idx: int,
109
+ cache_config: Optional[CacheConfig] = None,
110
+ quant_config: Optional[QuantizationConfig] = None,
111
+ is_lora_enabled: Optional[bool] = False,
112
+ **kwargs) -> None:
113
+ super().__init__()
114
+ self.config = config
115
+ self.is_lora_enabled = is_lora_enabled
116
+ self.mamba = MambaMixer(hidden_size= config.hidden_size,
117
+ ssm_state_size = config.mamba_d_state,
118
+ conv_kernel_size = config.mamba_d_conv,
119
+ intermediate_size = config.mamba_expand *\
120
+ config.hidden_size,
121
+ time_step_rank = config.mamba_dt_rank,
122
+ use_conv_bias = config.mamba_conv_bias,
123
+ use_bias = config.mamba_proj_bias,
124
+ use_rms_norm=True,
125
+ rms_norm_eps=config.rms_norm_eps,
126
+ activation=config.hidden_act,
127
+ is_lora_enabled = self.is_lora_enabled
128
+ )
129
+
130
+ num_experts = config.layers_num_experts[layer_idx]
131
+ ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
132
+ self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
133
+ self.input_layernorm = RMSNorm(config.hidden_size,
134
+ eps=config.rms_norm_eps)
135
+ self.pre_ff_layernorm = RMSNorm(config.hidden_size,
136
+ eps=config.rms_norm_eps)
137
+
138
+ def forward(
139
+ self,
140
+ hidden_states: torch.Tensor,
141
+ attn_metadata: AttentionMetadata,
142
+ residual: Optional[torch.Tensor],
143
+ mamba_cache_params: MambaCacheParams,
144
+ **kwargs,
145
+ ):
146
+ if residual is None:
147
+ residual = hidden_states
148
+ hidden_states = self.input_layernorm(hidden_states)
149
+ else:
150
+ hidden_states, residual = self.input_layernorm(
151
+ hidden_states, residual)
152
+
153
+ hidden_states = self.mamba(hidden_states, attn_metadata,
154
+ mamba_cache_params)
155
+ # Fully Connected
156
+ hidden_states, residual = self.pre_ff_layernorm(
157
+ hidden_states, residual)
158
+ hidden_states = self.feed_forward(hidden_states)
159
+ return hidden_states, residual
160
+
161
+
162
+ class JambaAttentionDecoderLayer(nn.Module):
163
+
164
+ def __init__(self,
165
+ config: JambaConfig,
166
+ layer_idx: int,
167
+ cache_config: Optional[CacheConfig] = None,
168
+ quant_config: Optional[QuantizationConfig] = None,
169
+ prefix: str = "",
170
+ **kwargs) -> None:
171
+ super().__init__()
172
+ self.hidden_size = config.hidden_size
173
+ tp_size = get_tensor_model_parallel_world_size()
174
+ self.total_num_heads = config.num_attention_heads
175
+ assert self.total_num_heads % tp_size == 0
176
+ self.num_heads = self.total_num_heads // tp_size
177
+ self.total_num_kv_heads = config.num_key_value_heads
178
+ if self.total_num_kv_heads >= tp_size:
179
+ # Number of KV heads is greater than TP size, so we partition
180
+ # the KV heads across multiple tensor parallel GPUs.
181
+ assert self.total_num_kv_heads % tp_size == 0
182
+ else:
183
+ # Number of KV heads is less than TP size, so we replicate
184
+ # the KV heads across multiple tensor parallel GPUs.
185
+ assert tp_size % self.total_num_kv_heads == 0
186
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
187
+ self.head_dim = config.hidden_size // self.total_num_heads
188
+ self.q_size = self.num_heads * self.head_dim
189
+ self.kv_size = self.num_kv_heads * self.head_dim
190
+ self.scaling = self.head_dim**-0.5
191
+
192
+ self.qkv_proj = QKVParallelLinear(
193
+ config.hidden_size,
194
+ self.head_dim,
195
+ self.total_num_heads,
196
+ self.total_num_kv_heads,
197
+ bias=False,
198
+ quant_config=quant_config,
199
+ )
200
+ self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
201
+ config.hidden_size,
202
+ bias=False,
203
+ quant_config=quant_config)
204
+
205
+ self.attn = Attention(
206
+ self.num_heads,
207
+ self.head_dim,
208
+ self.scaling,
209
+ num_kv_heads=self.num_kv_heads,
210
+ cache_config=cache_config,
211
+ prefix=f"{prefix}.attn",
212
+ )
213
+
214
+ num_experts = config.layers_num_experts[layer_idx]
215
+ ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
216
+ self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
217
+ self.input_layernorm = RMSNorm(config.hidden_size,
218
+ eps=config.rms_norm_eps)
219
+ self.pre_ff_layernorm = RMSNorm(config.hidden_size,
220
+ eps=config.rms_norm_eps)
221
+
222
+ def self_attention(
223
+ self,
224
+ positions: torch.Tensor,
225
+ hidden_states: torch.Tensor,
226
+ kv_cache: torch.Tensor,
227
+ attn_metadata: AttentionMetadata,
228
+ **kwargs,
229
+ ) -> torch.Tensor:
230
+ qkv, _ = self.qkv_proj(hidden_states)
231
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
232
+ attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
233
+ output, _ = self.o_proj(attn_output)
234
+ return output
235
+
236
+ def forward(
237
+ self,
238
+ positions: torch.Tensor,
239
+ hidden_states: torch.Tensor,
240
+ kv_cache: torch.Tensor,
241
+ attn_metadata: AttentionMetadata,
242
+ residual: Optional[torch.Tensor],
243
+ **kwargs,
244
+ ):
245
+ if residual is None:
246
+ residual = hidden_states
247
+ hidden_states = self.input_layernorm(hidden_states)
248
+ else:
249
+ hidden_states, residual = self.input_layernorm(
250
+ hidden_states, residual)
251
+
252
+ hidden_states = self.self_attention(
253
+ positions=positions,
254
+ hidden_states=hidden_states,
255
+ kv_cache=kv_cache,
256
+ attn_metadata=attn_metadata,
257
+ )
258
+ # Fully Connected
259
+ hidden_states, residual = self.pre_ff_layernorm(
260
+ hidden_states, residual)
261
+ hidden_states = self.feed_forward(hidden_states)
262
+ return hidden_states, residual
263
+
264
+
265
+ ALL_DECODER_LAYER_TYPES = {
266
+ "attention": JambaAttentionDecoderLayer,
267
+ "mamba": JambaMambaDecoderLayer
268
+ }
269
+
270
+
271
+ class JambaModel(nn.Module):
272
+
273
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
274
+ super().__init__()
275
+
276
+ config = vllm_config.model_config.hf_config
277
+ cache_config = vllm_config.cache_config
278
+ quant_config = vllm_config.quant_config
279
+ lora_config = vllm_config.lora_config
280
+
281
+ self.config = config
282
+ self.padding_idx = config.pad_token_id
283
+ lora_vocab = ((lora_config.lora_extra_vocab_size *
284
+ (lora_config.max_loras or 1)) if lora_config else 0)
285
+ self.vocab_size = config.vocab_size + lora_vocab
286
+ self.org_vocab_size = config.vocab_size
287
+
288
+ self.embed_tokens = VocabParallelEmbedding(
289
+ self.vocab_size,
290
+ config.hidden_size,
291
+ org_num_embeddings=config.vocab_size,
292
+ )
293
+
294
+ extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
295
+
296
+ def get_layer(prefix: str):
297
+ layer_idx = int(prefix.rsplit(".", 1)[1])
298
+ layer_class = ALL_DECODER_LAYER_TYPES[
299
+ config.layers_block_type[layer_idx]]
300
+ return layer_class(config,
301
+ layer_idx,
302
+ cache_config,
303
+ quant_config=quant_config,
304
+ prefix=prefix,
305
+ **extra_kwargs)
306
+
307
+ self.start_layer, self.end_layer, self.layers = make_layers(
308
+ config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
309
+ self.make_empty_intermediate_tensors = (
310
+ make_empty_intermediate_tensors_factory(
311
+ ["hidden_states", "residual"], config.hidden_size))
312
+
313
+ self.final_layernorm = RMSNorm(config.hidden_size,
314
+ eps=config.rms_norm_eps)
315
+
316
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
317
+ return self.embed_tokens(input_ids)
318
+
319
+ def forward(
320
+ self,
321
+ input_ids: torch.Tensor,
322
+ positions: torch.Tensor,
323
+ kv_caches: List[torch.Tensor],
324
+ attn_metadata: AttentionMetadata,
325
+ mamba_cache_params: MambaCacheParams,
326
+ intermediate_tensors: Optional[IntermediateTensors] = None,
327
+ inputs_embeds: Optional[torch.Tensor] = None,
328
+ ) -> torch.Tensor:
329
+ if get_pp_group().is_first_rank:
330
+ if inputs_embeds is not None:
331
+ hidden_states = inputs_embeds
332
+ else:
333
+ hidden_states = self.get_input_embeddings(input_ids)
334
+ residual = None
335
+ else:
336
+ assert intermediate_tensors is not None
337
+ hidden_states = intermediate_tensors["hidden_states"]
338
+ residual = intermediate_tensors["residual"]
339
+
340
+ kv_cache_index = 0
341
+ mamba_cache_index = 0
342
+ for i in range(self.start_layer, self.end_layer):
343
+ layer = self.layers[i]
344
+ kv_cache = None
345
+ layer_mamba_cache_params = None
346
+ if isinstance(layer, JambaAttentionDecoderLayer):
347
+ kv_cache = kv_caches[kv_cache_index]
348
+ kv_cache_index += 1
349
+ if isinstance(layer, JambaMambaDecoderLayer):
350
+ current_state_layer = mamba_cache_index
351
+ layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
352
+ current_state_layer)
353
+ mamba_cache_index += 1
354
+
355
+ hidden_states, residual = layer(
356
+ positions=positions,
357
+ hidden_states=hidden_states,
358
+ kv_cache=kv_cache,
359
+ attn_metadata=attn_metadata,
360
+ residual=residual,
361
+ mamba_cache_params=layer_mamba_cache_params)
362
+ if not get_pp_group().is_last_rank:
363
+ return IntermediateTensors({
364
+ "hidden_states": hidden_states,
365
+ "residual": residual
366
+ })
367
+ hidden_states, _ = self.final_layernorm(hidden_states, residual)
368
+ return hidden_states
369
+
370
+
371
+ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
372
+ IsHybrid):
373
+ packed_modules_mapping = {
374
+ "qkv_proj": [
375
+ "q_proj",
376
+ "k_proj",
377
+ "v_proj",
378
+ ],
379
+ "in_proj": ["in_proj"],
380
+ }
381
+
382
+ # LoRA specific attributes
383
+ supported_lora_modules = [
384
+ "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
385
+ "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
386
+ ]
387
+ embedding_modules = {
388
+ "embed_tokens": "input_embeddings",
389
+ "lm_head": "output_embeddings",
390
+ }
391
+ embedding_padding_modules = ["lm_head"]
392
+
393
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
394
+ config = vllm_config.model_config.hf_config
395
+ cache_config = vllm_config.cache_config
396
+ lora_config = vllm_config.lora_config
397
+ scheduler_config = vllm_config.scheduler_config
398
+ assert not cache_config.enable_prefix_caching, \
399
+ "Jamba currently does not support prefix caching"
400
+
401
+ super().__init__()
402
+ self.config = config
403
+ self.vllm_config = vllm_config
404
+ self.model_config = vllm_config.model_config
405
+ self.scheduler_config = scheduler_config
406
+ self.model = JambaModel(vllm_config=vllm_config,
407
+ prefix=maybe_prefix(prefix, "model"))
408
+ self.unpadded_vocab_size = config.vocab_size
409
+ if lora_config:
410
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
411
+ self.lm_head = ParallelLMHead(
412
+ self.unpadded_vocab_size,
413
+ config.hidden_size,
414
+ org_num_embeddings=config.vocab_size,
415
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE
416
+ # We need bigger padding if using lora for kernel
417
+ # compatibility
418
+ if not lora_config else lora_config.lora_vocab_padding_size,
419
+ )
420
+ # Used to track and store by the Mamba cache between steps.
421
+ self.mamba_cache: Optional[MambaCacheManager] = None
422
+
423
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
424
+ config.vocab_size)
425
+ self.sampler = get_sampler()
426
+
427
+ self.make_empty_intermediate_tensors = (
428
+ self.model.make_empty_intermediate_tensors)
429
+ if self.scheduler_config is not None and \
430
+ not self.model_config.enforce_eager:
431
+ if self.scheduler_config.max_num_seqs > \
432
+ vllm_config.compilation_config.max_capture_size:
433
+ self.max_batch_size = \
434
+ vllm_config.compilation_config.max_capture_size
435
+ else:
436
+ self.max_batch_size = vllm_config.pad_for_cudagraph(
437
+ self.scheduler_config.max_num_seqs)
438
+ else:
439
+ self.max_batch_size = 8192 + 2
440
+
441
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
442
+ return self.model.get_input_embeddings(input_ids)
443
+
444
+ def forward(self,
445
+ input_ids: torch.Tensor,
446
+ positions: torch.Tensor,
447
+ kv_caches: List[KVCache],
448
+ attn_metadata: AttentionMetadata,
449
+ intermediate_tensors: Optional[IntermediateTensors] = None,
450
+ inputs_embeds: Optional[torch.Tensor] = None,
451
+ **kwargs):
452
+ if self.mamba_cache is None:
453
+ num_mamba_layers = self.model_config.get_num_layers_by_block_type(
454
+ self.vllm_config.parallel_config, LayerBlockType.mamba)
455
+ self.mamba_cache = MambaCacheManager(
456
+ self.lm_head.weight.dtype, num_mamba_layers,
457
+ self.max_batch_size, *self._get_mamba_cache_shape())
458
+ (
459
+ mamba_cache_tensors,
460
+ state_indices_tensor,
461
+ ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
462
+ **kwargs)
463
+ mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
464
+ mamba_cache_tensors[1],
465
+ state_indices_tensor)
466
+ hidden_states = self.model(input_ids, positions, kv_caches,
467
+ attn_metadata, mamba_cache_params,
468
+ intermediate_tensors, inputs_embeds)
469
+ return hidden_states
470
+
471
+ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
472
+ return self.mamba_cache.copy_inputs_before_cuda_graphs(
473
+ input_buffers, **kwargs)
474
+
475
+ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
476
+ return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
477
+
478
+ def _get_mamba_cache_shape(
479
+ self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
480
+ world_size = get_tensor_model_parallel_world_size()
481
+ hidden_size = self.config.hidden_size
482
+ conv_state_shape = (
483
+ self.config.mamba_expand * hidden_size // world_size,
484
+ self.config.mamba_d_conv - 1,
485
+ )
486
+ temporal_state_shape = (
487
+ self.config.mamba_expand * hidden_size // world_size,
488
+ self.config.mamba_d_state,
489
+ )
490
+ return conv_state_shape, temporal_state_shape
491
+
492
+ def compute_logits(
493
+ self,
494
+ hidden_states: torch.Tensor,
495
+ sampling_metadata: SamplingMetadata,
496
+ ) -> Optional[torch.Tensor]:
497
+ logits = self.logits_processor(self.lm_head, hidden_states,
498
+ sampling_metadata)
499
+ return logits
500
+
501
+ def sample(
502
+ self,
503
+ logits: Optional[torch.Tensor],
504
+ sampling_metadata: SamplingMetadata,
505
+ ) -> Optional[SamplerOutput]:
506
+ next_tokens = self.sampler(logits, sampling_metadata)
507
+ return next_tokens
508
+
509
+ def load_weights(self, weights: Iterable[Tuple[str,
510
+ torch.Tensor]]) -> Set[str]:
511
+ stacked_params_mapping = [
512
+ # (param_name, shard_name, shard_id)
513
+ ("qkv_proj", "q_proj", "q"),
514
+ ("qkv_proj", "k_proj", "k"),
515
+ ("qkv_proj", "v_proj", "v"),
516
+ ]
517
+
518
+ # Params for weights, fp8 weight scales, fp8 activation scales
519
+ # (param_name, weight_name, expert_id, shard_id)
520
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
521
+ ckpt_gate_proj_name="gate_proj",
522
+ ckpt_down_proj_name="down_proj",
523
+ ckpt_up_proj_name="up_proj",
524
+ num_experts=self.config.num_experts)
525
+
526
+ params_dict = dict(self.named_parameters())
527
+ loaded_params: Set[str] = set()
528
+ for name, loaded_weight in weights:
529
+ if "rotary_emb.inv_freq" in name:
530
+ continue
531
+
532
+ if "A_log" in name:
533
+ name = name.replace("A_log", "A")
534
+
535
+ if ".self_attn." in name:
536
+ name = name.replace(".self_attn", "")
537
+
538
+ if "feed_forward" in name and not _is_moe_layer(name):
539
+ ## map MLP layers to expert with ID=0
540
+ name = name.replace("feed_forward", "feed_forward.experts.0")
541
+
542
+ for param_name, weight_name, shard_id in stacked_params_mapping:
543
+ if weight_name not in name:
544
+ continue
545
+ if 'experts' in name:
546
+ continue
547
+ name = name.replace(weight_name, param_name)
548
+ # Skip loading extra bias for GPTQ models.
549
+
550
+ if name.endswith(".bias") and name not in params_dict:
551
+ continue
552
+ # Skip layers on other devices.
553
+ if is_pp_missing_parameter(name, self):
554
+ continue
555
+ param = params_dict[name]
556
+ weight_loader = param.weight_loader
557
+ weight_loader(param, loaded_weight, shard_id)
558
+ break
559
+ else:
560
+ for (
561
+ param_name,
562
+ weight_name,
563
+ expert_id,
564
+ shard_id,
565
+ ) in expert_params_mapping:
566
+ if weight_name not in name:
567
+ continue
568
+
569
+ if is_pp_missing_parameter(name, self):
570
+ continue
571
+ name = name.replace(weight_name, param_name)
572
+ param = params_dict[name]
573
+ weight_loader = param.weight_loader
574
+ weight_loader(param,
575
+ loaded_weight,
576
+ name,
577
+ shard_id=shard_id,
578
+ expert_id=expert_id)
579
+ break
580
+ else:
581
+ # Skip loading extra bias for GPTQ models.
582
+ if name.endswith(".bias") and name not in params_dict:
583
+ continue
584
+ if is_pp_missing_parameter(name, self):
585
+ continue
586
+
587
+ param = params_dict[name]
588
+ weight_loader = getattr(param, "weight_loader",
589
+ default_weight_loader)
590
+ weight_loader(param, loaded_weight)
591
+ loaded_params.add(name)
592
+ return loaded_params
593
+
594
+
595
+ def _is_moe_layer(name: str):
596
+ return any(
597
+ [experts_name in name for experts_name in [
598
+ "experts",
599
+ "router",
600
+ ]])
601
+
602
+
603
+ class JambaForSequenceClassification(JambaForCausalLM):
604
+
605
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
606
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
607
+ config = vllm_config.model_config.hf_config
608
+ num_labels: int = config.num_labels
609
+ score_bias: bool = getattr(config, 'score_bias', False)
610
+ self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
611
+
612
+ pooler_config = vllm_config.model_config.pooler_config
613
+ self._pooler = Pooler.from_config_with_defaults(
614
+ pooler_config,
615
+ pooling_type=PoolingType.LAST,
616
+ normalize=False,
617
+ softmax=False)
618
+
619
+ def pooler(
620
+ self,
621
+ hidden_states: torch.Tensor,
622
+ pooling_metadata: PoolingMetadata,
623
+ ) -> Optional[PoolerOutput]:
624
+ hidden_states = hidden_states.float()
625
+ logits = self.score(hidden_states)
626
+ return self._pooler(logits, pooling_metadata)
627
+
628
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
629
+ # TODO: The reward weights themselves have float32 accuracy data, we
630
+ # would like to load them in fp32 to get that extra precision.
631
+ super().load_weights(weights)
632
+ self.score = self.score.float()
.venv/lib/python3.11/site-packages/vllm/model_executor/models/llama.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9
+ # and OPT implementations in this library. It has been modified from its
10
+ # original forms to accommodate minor architectural differences compared
11
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
25
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
26
+
27
+ import torch
28
+ from torch import nn
29
+ from transformers import LlamaConfig
30
+
31
+ from vllm.attention import Attention, AttentionMetadata
32
+ from vllm.compilation.decorators import support_torch_compile
33
+ from vllm.config import CacheConfig, VllmConfig
34
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
+ from vllm.model_executor.layers.activation import SiluAndMul
36
+ from vllm.model_executor.layers.layernorm import RMSNorm
37
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
38
+ QKVParallelLinear,
39
+ RowParallelLinear)
40
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
+ from vllm.model_executor.layers.quantization import QuantizationConfig
42
+ from vllm.model_executor.layers.rotary_embedding import get_rope
43
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
44
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
45
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
+ from vllm.model_executor.model_loader.weight_utils import (
47
+ default_weight_loader, maybe_remap_kv_scale_name)
48
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
49
+ from vllm.sequence import IntermediateTensors
50
+
51
+ from .interfaces import SupportsLoRA, SupportsPP
52
+ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
53
+ is_pp_missing_parameter,
54
+ make_empty_intermediate_tensors_factory, make_layers,
55
+ maybe_prefix)
56
+
57
+
58
+ class LlamaMLP(nn.Module):
59
+
60
+ def __init__(
61
+ self,
62
+ hidden_size: int,
63
+ intermediate_size: int,
64
+ hidden_act: str,
65
+ quant_config: Optional[QuantizationConfig] = None,
66
+ bias: bool = False,
67
+ prefix: str = "",
68
+ ) -> None:
69
+ super().__init__()
70
+ self.gate_up_proj = MergedColumnParallelLinear(
71
+ input_size=hidden_size,
72
+ output_sizes=[intermediate_size] * 2,
73
+ bias=bias,
74
+ quant_config=quant_config,
75
+ prefix=f"{prefix}.gate_up_proj",
76
+ )
77
+ self.down_proj = RowParallelLinear(
78
+ input_size=intermediate_size,
79
+ output_size=hidden_size,
80
+ bias=bias,
81
+ quant_config=quant_config,
82
+ prefix=f"{prefix}.down_proj",
83
+ )
84
+ if hidden_act != "silu":
85
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
86
+ "Only silu is supported for now.")
87
+ self.act_fn = SiluAndMul()
88
+
89
+ def forward(self, x):
90
+ x, _ = self.gate_up_proj(x)
91
+ x = self.act_fn(x)
92
+ x, _ = self.down_proj(x)
93
+ return x
94
+
95
+
96
+ class LlamaAttention(nn.Module):
97
+
98
+ def __init__(self,
99
+ config: LlamaConfig,
100
+ hidden_size: int,
101
+ num_heads: int,
102
+ num_kv_heads: int,
103
+ rope_theta: float = 10000,
104
+ rope_scaling: Optional[Dict[str, Any]] = None,
105
+ max_position_embeddings: int = 8192,
106
+ quant_config: Optional[QuantizationConfig] = None,
107
+ bias: bool = False,
108
+ bias_o_proj: bool = False,
109
+ cache_config: Optional[CacheConfig] = None,
110
+ prefix: str = "") -> None:
111
+ super().__init__()
112
+ layer_idx = extract_layer_index(prefix)
113
+ self.hidden_size = hidden_size
114
+ tp_size = get_tensor_model_parallel_world_size()
115
+ self.total_num_heads = num_heads
116
+ assert self.total_num_heads % tp_size == 0
117
+ self.num_heads = self.total_num_heads // tp_size
118
+ self.total_num_kv_heads = num_kv_heads
119
+ if self.total_num_kv_heads >= tp_size:
120
+ # Number of KV heads is greater than TP size, so we partition
121
+ # the KV heads across multiple tensor parallel GPUs.
122
+ assert self.total_num_kv_heads % tp_size == 0
123
+ else:
124
+ # Number of KV heads is less than TP size, so we replicate
125
+ # the KV heads across multiple tensor parallel GPUs.
126
+ assert tp_size % self.total_num_kv_heads == 0
127
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
128
+ # MistralConfig has an optional head_dim introduced by Mistral-Nemo
129
+ self.head_dim = getattr(config, "head_dim",
130
+ self.hidden_size // self.total_num_heads)
131
+ self.q_size = self.num_heads * self.head_dim
132
+ self.kv_size = self.num_kv_heads * self.head_dim
133
+ self.scaling = self.head_dim**-0.5
134
+ self.rope_theta = rope_theta
135
+ self.max_position_embeddings = max_position_embeddings
136
+
137
+ self.qkv_proj = QKVParallelLinear(
138
+ hidden_size=hidden_size,
139
+ head_size=self.head_dim,
140
+ total_num_heads=self.total_num_heads,
141
+ total_num_kv_heads=self.total_num_kv_heads,
142
+ bias=bias,
143
+ quant_config=quant_config,
144
+ prefix=f"{prefix}.qkv_proj",
145
+ )
146
+
147
+ self.o_proj = RowParallelLinear(
148
+ input_size=self.total_num_heads * self.head_dim,
149
+ output_size=hidden_size,
150
+ bias=bias_o_proj,
151
+ quant_config=quant_config,
152
+ prefix=f"{prefix}.o_proj",
153
+ )
154
+
155
+ is_neox_style = True
156
+ is_gguf = quant_config and quant_config.get_name() == "gguf"
157
+ if is_gguf and config.model_type == "llama":
158
+ is_neox_style = False
159
+
160
+ self.rotary_emb = get_rope(
161
+ self.head_dim,
162
+ rotary_dim=self.head_dim,
163
+ max_position=max_position_embeddings,
164
+ base=rope_theta,
165
+ rope_scaling=rope_scaling,
166
+ is_neox_style=is_neox_style,
167
+ )
168
+
169
+ if hasattr(config, "interleaved_sliding_window"):
170
+ interleaved_sliding_window = config.interleaved_sliding_window
171
+ if isinstance(interleaved_sliding_window, int):
172
+ sliding_window = interleaved_sliding_window
173
+ elif isinstance(interleaved_sliding_window, list):
174
+ sw_idx = layer_idx % len(interleaved_sliding_window)
175
+ sliding_window = interleaved_sliding_window[sw_idx]
176
+ else:
177
+ raise ValueError(
178
+ f"{type(interleaved_sliding_window)} is not supported.")
179
+ else:
180
+ sliding_window = None
181
+
182
+ self.attn = Attention(
183
+ self.num_heads,
184
+ self.head_dim,
185
+ self.scaling,
186
+ num_kv_heads=self.num_kv_heads,
187
+ cache_config=cache_config,
188
+ quant_config=quant_config,
189
+ per_layer_sliding_window=sliding_window,
190
+ prefix=f"{prefix}.attn",
191
+ )
192
+
193
+ def forward(
194
+ self,
195
+ positions: torch.Tensor,
196
+ hidden_states: torch.Tensor,
197
+ kv_cache: torch.Tensor,
198
+ attn_metadata: AttentionMetadata,
199
+ ) -> torch.Tensor:
200
+ qkv, _ = self.qkv_proj(hidden_states)
201
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
202
+ q, k = self.rotary_emb(positions, q, k)
203
+ attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
204
+ output, _ = self.o_proj(attn_output)
205
+ return output
206
+
207
+
208
+ class LlamaDecoderLayer(nn.Module):
209
+
210
+ def __init__(
211
+ self,
212
+ config: LlamaConfig,
213
+ cache_config: Optional[CacheConfig] = None,
214
+ quant_config: Optional[QuantizationConfig] = None,
215
+ prefix: str = "",
216
+ ) -> None:
217
+ super().__init__()
218
+ self.hidden_size = config.hidden_size
219
+ rope_theta = getattr(config, "rope_theta", 10000)
220
+ rope_scaling = getattr(config, "rope_scaling", None)
221
+ if rope_scaling is not None and getattr(
222
+ config, "original_max_position_embeddings", None):
223
+ rope_scaling["original_max_position_embeddings"] = (
224
+ config.original_max_position_embeddings)
225
+ max_position_embeddings = getattr(config, "max_position_embeddings",
226
+ 8192)
227
+ # Support abacusai/Smaug-72B-v0.1 with attention_bias
228
+ # Support internlm/internlm-7b with bias
229
+ attention_bias = getattr(config, "attention_bias", False) or getattr(
230
+ config, "bias", False)
231
+ bias_o_proj = attention_bias
232
+ # support internlm/internlm3-8b with qkv_bias
233
+ if hasattr(config, 'qkv_bias'):
234
+ attention_bias = config.qkv_bias
235
+
236
+ self.self_attn = LlamaAttention(
237
+ config=config,
238
+ hidden_size=self.hidden_size,
239
+ num_heads=config.num_attention_heads,
240
+ num_kv_heads=getattr(config, "num_key_value_heads",
241
+ config.num_attention_heads),
242
+ rope_theta=rope_theta,
243
+ rope_scaling=rope_scaling,
244
+ max_position_embeddings=max_position_embeddings,
245
+ quant_config=quant_config,
246
+ bias=attention_bias,
247
+ bias_o_proj=bias_o_proj,
248
+ cache_config=cache_config,
249
+ prefix=f"{prefix}.self_attn",
250
+ )
251
+ self.mlp = LlamaMLP(
252
+ hidden_size=self.hidden_size,
253
+ intermediate_size=config.intermediate_size,
254
+ hidden_act=config.hidden_act,
255
+ quant_config=quant_config,
256
+ bias=getattr(config, "mlp_bias", False),
257
+ prefix=f"{prefix}.mlp",
258
+ )
259
+ self.input_layernorm = RMSNorm(config.hidden_size,
260
+ eps=config.rms_norm_eps)
261
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
262
+ eps=config.rms_norm_eps)
263
+
264
+ def forward(
265
+ self,
266
+ positions: torch.Tensor,
267
+ hidden_states: torch.Tensor,
268
+ kv_cache: torch.Tensor,
269
+ attn_metadata: AttentionMetadata,
270
+ residual: Optional[torch.Tensor],
271
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
272
+ # Self Attention
273
+ if residual is None:
274
+ residual = hidden_states
275
+ hidden_states = self.input_layernorm(hidden_states)
276
+ else:
277
+ hidden_states, residual = self.input_layernorm(
278
+ hidden_states, residual)
279
+ hidden_states = self.self_attn(positions=positions,
280
+ hidden_states=hidden_states,
281
+ kv_cache=kv_cache,
282
+ attn_metadata=attn_metadata)
283
+
284
+ # Fully Connected
285
+ hidden_states, residual = self.post_attention_layernorm(
286
+ hidden_states, residual)
287
+ hidden_states = self.mlp(hidden_states)
288
+ return hidden_states, residual
289
+
290
+
291
+ @support_torch_compile
292
+ class LlamaModel(nn.Module):
293
+
294
+ def __init__(self,
295
+ *,
296
+ vllm_config: VllmConfig,
297
+ prefix: str = "",
298
+ layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
299
+ super().__init__()
300
+
301
+ config = vllm_config.model_config.hf_config
302
+ cache_config = vllm_config.cache_config
303
+ quant_config = vllm_config.quant_config
304
+ lora_config = vllm_config.lora_config
305
+
306
+ self.config = config
307
+ self.quant_config = quant_config
308
+ self.padding_idx = config.pad_token_id
309
+ lora_vocab = (lora_config.lora_extra_vocab_size *
310
+ (lora_config.max_loras or 1)) if lora_config else 0
311
+ self.vocab_size = config.vocab_size + lora_vocab
312
+ self.org_vocab_size = config.vocab_size
313
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
314
+ and get_pp_group().is_last_rank):
315
+ self.embed_tokens = VocabParallelEmbedding(
316
+ self.vocab_size,
317
+ config.hidden_size,
318
+ org_num_embeddings=config.vocab_size,
319
+ quant_config=quant_config,
320
+ )
321
+ else:
322
+ self.embed_tokens = PPMissingLayer()
323
+ self.start_layer, self.end_layer, self.layers = make_layers(
324
+ config.num_hidden_layers,
325
+ lambda prefix: layer_type(config=config,
326
+ cache_config=cache_config,
327
+ quant_config=quant_config,
328
+ prefix=prefix),
329
+ prefix=f"{prefix}.layers",
330
+ )
331
+ if get_pp_group().is_last_rank:
332
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
333
+ else:
334
+ self.norm = PPMissingLayer()
335
+
336
+ self.make_empty_intermediate_tensors = (
337
+ make_empty_intermediate_tensors_factory(
338
+ ["hidden_states", "residual"], config.hidden_size))
339
+
340
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
341
+ return self.embed_tokens(input_ids)
342
+
343
+ def forward(
344
+ self,
345
+ input_ids: Optional[torch.Tensor],
346
+ positions: torch.Tensor,
347
+ kv_caches: List[torch.Tensor],
348
+ attn_metadata: AttentionMetadata,
349
+ intermediate_tensors: Optional[IntermediateTensors],
350
+ inputs_embeds: Optional[torch.Tensor] = None,
351
+ ) -> Union[torch.Tensor, IntermediateTensors]:
352
+ if get_pp_group().is_first_rank:
353
+ if inputs_embeds is not None:
354
+ hidden_states = inputs_embeds
355
+ else:
356
+ hidden_states = self.get_input_embeddings(input_ids)
357
+ residual = None
358
+ else:
359
+ assert intermediate_tensors is not None
360
+ hidden_states = intermediate_tensors["hidden_states"]
361
+ residual = intermediate_tensors["residual"]
362
+
363
+ for i in range(self.start_layer, self.end_layer):
364
+ layer = self.layers[i]
365
+ hidden_states, residual = layer(positions, hidden_states,
366
+ kv_caches[i - self.start_layer],
367
+ attn_metadata, residual)
368
+
369
+ if not get_pp_group().is_last_rank:
370
+ return IntermediateTensors({
371
+ "hidden_states": hidden_states,
372
+ "residual": residual
373
+ })
374
+
375
+ hidden_states, _ = self.norm(hidden_states, residual)
376
+ return hidden_states
377
+
378
+ def load_weights(self, weights: Iterable[Tuple[str,
379
+ torch.Tensor]]) -> Set[str]:
380
+ stacked_params_mapping = [
381
+ # (param_name, shard_name, shard_id)
382
+ (".qkv_proj", ".q_proj", "q"),
383
+ (".qkv_proj", ".k_proj", "k"),
384
+ (".qkv_proj", ".v_proj", "v"),
385
+ (".gate_up_proj", ".gate_proj", 0),
386
+ (".gate_up_proj", ".up_proj", 1),
387
+ ]
388
+ params_dict = dict(self.named_parameters())
389
+ loaded_params: Set[str] = set()
390
+ for name, loaded_weight in weights:
391
+ if "rotary_emb.inv_freq" in name:
392
+ continue
393
+ if ("rotary_emb.cos_cached" in name
394
+ or "rotary_emb.sin_cached" in name):
395
+ # Models trained using ColossalAI may include these tensors in
396
+ # the checkpoint. Skip them.
397
+ continue
398
+ if (self.quant_config is not None and
399
+ (scale_name := self.quant_config.get_cache_scale(name))):
400
+ # Loading kv cache quantization scales
401
+ param = params_dict[scale_name]
402
+ weight_loader = getattr(param, "weight_loader",
403
+ default_weight_loader)
404
+ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
405
+ loaded_weight[0])
406
+ weight_loader(param, loaded_weight)
407
+ loaded_params.add(scale_name)
408
+ continue
409
+ if "scale" in name:
410
+ # Remapping the name of FP8 kv-scale.
411
+ name = maybe_remap_kv_scale_name(name, params_dict)
412
+ if name is None:
413
+ continue
414
+ for param_name, weight_name, shard_id in stacked_params_mapping:
415
+ if weight_name not in name:
416
+ continue
417
+ name = name.replace(weight_name, param_name)
418
+ # Skip loading extra bias for GPTQ models.
419
+ if name.endswith(".bias") and name not in params_dict:
420
+ continue
421
+
422
+ if is_pp_missing_parameter(name, self):
423
+ continue
424
+
425
+ param = params_dict[name]
426
+ weight_loader = param.weight_loader
427
+ weight_loader(param, loaded_weight, shard_id)
428
+ break
429
+ else:
430
+ # Skip loading extra bias for GPTQ models.
431
+ if name.endswith(".bias") and name not in params_dict:
432
+ continue
433
+
434
+ if is_pp_missing_parameter(name, self):
435
+ continue
436
+
437
+ param = params_dict[name]
438
+ weight_loader = getattr(param, "weight_loader",
439
+ default_weight_loader)
440
+ weight_loader(param, loaded_weight)
441
+ loaded_params.add(name)
442
+ return loaded_params
443
+
444
+
445
+ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
446
+ packed_modules_mapping = {
447
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
448
+ "gate_up_proj": ["gate_proj", "up_proj"]
449
+ }
450
+
451
+ # LoRA specific attributes
452
+ supported_lora_modules = [
453
+ "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
454
+ "lm_head"
455
+ ]
456
+ embedding_modules = {
457
+ "embed_tokens": "input_embeddings",
458
+ "lm_head": "output_embeddings"
459
+ }
460
+ embedding_padding_modules = ["lm_head"]
461
+
462
+ # Mistral/Llama models can also be loaded with --load-format mistral
463
+ # from consolidated.safetensors checkpoints
464
+ mistral_mapping = {
465
+ "layers": "model.layers",
466
+ "attention": "self_attn",
467
+ "wq": "q_proj",
468
+ "wk": "k_proj",
469
+ "wv": "v_proj",
470
+ "wo": "o_proj",
471
+ "attention_norm": "input_layernorm",
472
+ "feed_forward": "mlp",
473
+ "w1": "gate_proj",
474
+ "w2": "down_proj",
475
+ "w3": "up_proj",
476
+ "ffn_norm": "post_attention_layernorm",
477
+ "tok_embeddings": "model.embed_tokens",
478
+ "output": "lm_head",
479
+ "norm": "model.norm"
480
+ }
481
+
482
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
483
+ super().__init__()
484
+ config = vllm_config.model_config.hf_config
485
+ quant_config = vllm_config.quant_config
486
+ lora_config = vllm_config.lora_config
487
+ self.config = config
488
+ self.lora_config = lora_config
489
+
490
+ self.model = self._init_model(vllm_config=vllm_config,
491
+ prefix=maybe_prefix(prefix, "model"))
492
+
493
+ if get_pp_group().is_last_rank:
494
+ self.unpadded_vocab_size = config.vocab_size
495
+ if lora_config:
496
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
497
+ self.lm_head = ParallelLMHead(
498
+ self.unpadded_vocab_size,
499
+ config.hidden_size,
500
+ org_num_embeddings=config.vocab_size,
501
+ padding_size=(
502
+ DEFAULT_VOCAB_PADDING_SIZE
503
+ # We need bigger padding if using lora for kernel
504
+ # compatibility
505
+ if not lora_config else
506
+ lora_config.lora_vocab_padding_size),
507
+ quant_config=quant_config,
508
+ prefix=maybe_prefix(prefix, "lm_head"),
509
+ )
510
+ if config.tie_word_embeddings:
511
+ self.lm_head = self.lm_head.tie_weights(
512
+ self.model.embed_tokens)
513
+
514
+ logit_scale = getattr(config, "logit_scale", 1.0)
515
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
516
+ config.vocab_size,
517
+ logit_scale)
518
+ else:
519
+ self.lm_head = PPMissingLayer()
520
+
521
+ self.sampler = get_sampler()
522
+
523
+ self.make_empty_intermediate_tensors = (
524
+ self.model.make_empty_intermediate_tensors)
525
+
526
+ def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
527
+ return LlamaModel(vllm_config=vllm_config, prefix=prefix)
528
+
529
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
530
+ return self.model.get_input_embeddings(input_ids)
531
+
532
+ def forward(
533
+ self,
534
+ input_ids: torch.Tensor,
535
+ positions: torch.Tensor,
536
+ kv_caches: List[torch.Tensor],
537
+ attn_metadata: AttentionMetadata,
538
+ intermediate_tensors: Optional[IntermediateTensors] = None,
539
+ inputs_embeds: Optional[torch.Tensor] = None,
540
+ ) -> Union[torch.Tensor, IntermediateTensors]:
541
+ model_output = self.model(input_ids, positions, kv_caches,
542
+ attn_metadata, intermediate_tensors,
543
+ inputs_embeds)
544
+ return model_output
545
+
546
+ def compute_logits(
547
+ self,
548
+ hidden_states: torch.Tensor,
549
+ sampling_metadata: SamplingMetadata,
550
+ ) -> Optional[torch.Tensor]:
551
+ logits = self.logits_processor(self.lm_head, hidden_states,
552
+ sampling_metadata)
553
+ return logits
554
+
555
+ def sample(self, logits: torch.Tensor,
556
+ sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
557
+ next_tokens = self.sampler(logits, sampling_metadata)
558
+ return next_tokens
559
+
560
+ def load_weights(self, weights: Iterable[Tuple[str,
561
+ torch.Tensor]]) -> Set[str]:
562
+ loader = AutoWeightsLoader(
563
+ self,
564
+ skip_prefixes=(["lm_head."]
565
+ if self.config.tie_word_embeddings else None),
566
+ )
567
+ return loader.load_weights(
568
+ self.maybe_remap_mistral(name, loaded_weight)
569
+ for name, loaded_weight in weights)
570
+
571
+ # This function is used to remap the mistral format as
572
+ # used by Mistral and Llama <=2
573
+ def maybe_remap_mistral(
574
+ self,
575
+ name: str,
576
+ loaded_weight: torch.Tensor,
577
+ ) -> Tuple[str, torch.Tensor]:
578
+
579
+ def permute(w: torch.Tensor, n_heads: int):
580
+ attn_in = self.config.head_dim * n_heads
581
+ attn_out = self.config.hidden_size
582
+
583
+ return w.view(n_heads, attn_in // n_heads // 2, 2,
584
+ attn_out).transpose(1, 2).reshape(attn_in, attn_out)
585
+
586
+ mapping = self.mistral_mapping
587
+ modules = name.split(".")
588
+
589
+ # rotary embeds should be sliced
590
+ if "wk" in modules:
591
+ loaded_weight = permute(loaded_weight,
592
+ self.config.num_key_value_heads)
593
+ elif "wq" in modules:
594
+ loaded_weight = permute(loaded_weight,
595
+ self.config.num_attention_heads)
596
+
597
+ for item in modules:
598
+ if item in mapping and mapping[item] not in name:
599
+ name = name.replace(item, mapping[item])
600
+
601
+ return name, loaded_weight
.venv/lib/python3.11/site-packages/vllm/model_executor/models/llava.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import abstractmethod
4
+ from functools import cached_property
5
+ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
6
+ Protocol, Set, Tuple, TypedDict, TypeVar, Union)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from packaging.version import Version
11
+ from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
12
+ PixtralVisionConfig, PretrainedConfig,
13
+ SiglipVisionConfig)
14
+ from transformers import __version__ as TRANSFORMERS_VERSION
15
+ from transformers.models.llava import LlavaProcessor
16
+ from transformers.models.pixtral import PixtralProcessor
17
+
18
+ from vllm.attention import AttentionMetadata
19
+ from vllm.config import VllmConfig
20
+ from vllm.inputs import InputProcessingContext
21
+ from vllm.model_executor.layers.activation import get_act_fn
22
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
23
+ RowParallelLinear)
24
+ from vllm.model_executor.layers.quantization import QuantizationConfig
25
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
27
+ from vllm.multimodal import MULTIMODAL_REGISTRY
28
+ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
+ MultiModalInputs, MultiModalKwargs,
30
+ NestedTensors)
31
+ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
+ ImageSize, MultiModalDataItems)
33
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
+ BaseProcessingInfo, ProcessingCache,
35
+ PromptReplacement)
36
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
+ from vllm.sequence import IntermediateTensors
38
+
39
+ from .clip import CLIPVisionModel
40
+ from .interfaces import SupportsMultiModal, SupportsPP
41
+ from .pixtral import (PixtralHFVisionModel,
42
+ get_pixtral_hf_image_feature_grid_size)
43
+ from .siglip import SiglipVisionModel
44
+ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
+ maybe_prefix, merge_multimodal_embeddings)
46
+ from .vision import get_vision_encoder_info
47
+
48
+
49
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels to be fed to the vision tower."""
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
58
+
59
+
60
class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
67
+
68
+
69
+ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
70
+
71
+
72
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that projects vision-tower features into the text
    embedding space of the language model backbone."""

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Column/Row parallel pair keeps the projection sharded consistently
        # under tensor parallelism.
        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` to text hidden size (linear-act-linear)."""
        # Parallel linear layers return (output, optional bias); bias is
        # discarded here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states
100
+
101
+
102
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs that look like a LLaVA config."""
    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, list[int]]]
107
+
108
+
109
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors exposing an image token string."""
    image_token: Final[str]
111
+
112
+
113
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models (token counting,
    supported modalities, image-size limits)."""

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None means no limit on the number of images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature-select strategy.

        "default" drops one token (the CLS token per the HF convention —
        see `_select_image_features`); "full" keeps all tokens.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens one image of this size expands to."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image at the encoder's native size maximizes token count."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )
177
+
178
+
179
+ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
180
+
181
+
182
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy (profiling) inputs: one image token per image plus
    max-size dummy images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        # Use the largest supported image so profiling covers the worst case.
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )
207
+
208
+
209
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard (CLIP/SigLIP-based) LLaVA models."""

    def get_hf_processor(self):
        return self.ctx.get_hf_processor(LlavaProcessor)
213
+
214
+
215
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Common prompt-replacement logic for LLaVA-style processors: expand
    each image token into as many copies as the encoder will emit."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Embedding inputs carry their own feature size; pixel inputs
            # derive it from the image dimensions.
            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]
257
+
258
+
259
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multi-modal processor for standard LLaVA models."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both field types are batched per image item.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )
271
+
272
+
273
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral models in HF (LLaVA-like) format."""

    def get_hf_processor(self):
        return self.ctx.get_hf_processor(PixtralProcessor)
277
+
278
+
279
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multi-modal processor for Pixtral-HF; handles variable-size images
    and Pixtral's row-wise image token layout."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then normalize `pixel_values` to one
        un-padded tensor per image across transformers versions."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Before/after https://github.com/huggingface/transformers/pull/35122
            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
                images = mm_data["images"]
                assert isinstance(images, list)

                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1)
                assert (isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]
            else:
                # Avoid padding since we need the output for each image to be
                # independent of other images for the cache to work correctly
                image_sizes = processed_outputs["image_sizes"]
                assert len(pixel_values) == len(image_sizes)

                processed_outputs["pixel_values"] = [
                    p[:, :h, :w]
                    for p, (h, w) in zip(pixel_values, image_sizes)
                ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Expand each image token into Pixtral's grid layout:
        `ncols` image tokens + a break token per row, ending with an
        image-end token."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # The final row's break token is replaced by the end token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]
372
+
373
+
374
def _build_llava_or_pixtral_hf_info(
        ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class matching the config's vision backbone."""
    vision_cfg = ctx.get_hf_config(LlavaConfig).vision_config

    if isinstance(vision_cfg, PixtralVisionConfig):
        return PixtralHFProcessingInfo(ctx)

    return LlavaProcessingInfo(ctx)
382
+
383
+
384
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multi-modal processor matching *info*'s concrete type."""
    # Pixtral must be checked first: both infos derive from the same base.
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
408
+
409
+
410
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        Count of encoder layers required to serve the configured
        feature layer(s).
    """
    layers = hf_config.vision_feature_layer
    total = hf_config.vision_config.num_hidden_layers

    # Single feature layer: build the tower only up to that layer.
    if isinstance(layers, int):
        return _get_layer_index(layers, total)

    # Multiple feature layers: build up to the deepest one.
    if isinstance(layers, (list, tuple)):
        return max(_get_layer_index(idx, total) for idx in layers)

    raise TypeError(f"vision_layer_feature type: {type(layers)}"
                    " is not supported")
428
+
429
+
430
+ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
431
+ """Given an signed vision feature layer, get the number of hidden layers
432
+ needed to leverage it.
433
+
434
+ Args:
435
+ feature_layer_index: Index of a required layer in the visual encoder.
436
+ num_hidden_layers: The total number of hidden layers in the visual
437
+ encoder.
438
+ """
439
+ if feature_layer_index < 0:
440
+ return num_hidden_layers + feature_layer_index + 1
441
+ return feature_layer_index + 1
442
+
443
+
444
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision tower (CLIP, SigLIP, or Pixtral-HF) for a
    LLaVA-like config, truncated to the deepest required feature layer.

    Raises:
        NotImplementedError: for unsupported vision config types.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
483
+
484
+
485
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style conditional generation model: a vision tower, a projector
    into the text embedding space, and a language model backbone."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the backbone's sampler when it provides one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` is (batch, 3, image_size, image_size);
        raise ValueError otherwise."""
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract pixel values or precomputed embeddings from kwargs;
        return None when neither is present."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral images are variable-size, so shape validation is
            # skipped for them.
            if self.config.vision_config.model_type == "pixtral":
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
        if strategy == "default":
            # Drop the first token (CLS per the HF reference implementation).
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run pixels through the vision tower and apply feature selection."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn parsed image input into projected embeddings; precomputed
        embeddings bypass the vision tower and projector."""

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return vision embeddings for the images in kwargs, or None."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, splicing vision embeddings into the positions
        of the image placeholder tokens when provided."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model backbone."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model backbone."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load all weights (vision tower, projector, and backbone)."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
741
+
742
+
743
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, working around a pre-4.48 transformers
    bug in how the feature-select strategy offsets the token count."""

    def get_hf_processor(self):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            vision_feature_select_strategy = None
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            vision_feature_select_strategy = hf_config.vision_feature_select_strategy  # noqa: E501

        return self.ctx.get_hf_processor(
            LlavaProcessor,
            patch_size=vision_info.get_patch_size(),
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
762
+
763
+
764
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor variant that re-applies Mantis's prompt format
    (`(image N: <Image>...</Image>)`) on top of the standard expansion."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_mm_repls = self._bind_and_group_repls([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        # Rewrite the already-expanded prompt into the Mantis wrapper format.
        prompt_ids, prompt, _ = self._apply_prompt_replacements(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Re-locate the placeholder ranges in the rewritten prompt using the
        # original (non-Mantis) replacement definitions.
        unbound_orig_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_repls(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholder_ranges,
        )
837
+
838
+
839
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: identical to LLaVA except for its registered
    multi-modal processor and processing info."""
    pass