Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/custom_op.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/parameter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/pooling_metadata.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/sampling_metadata.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__init__.py +141 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/xgrammar_decoding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py +229 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/utils.py +237 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__init__.py +20 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/loader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/openvino.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py +1441 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/neuron.py +212 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/utils.py +162 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/arctic.py +582 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/bart.py +1000 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/bert.py +534 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/blip2.py +736 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/bloom.py +385 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/chameleon.py +1161 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/chatglm.py +801 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek.py +503 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/eagle.py +214 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/falcon.py +529 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/florence2.py +266 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/fuyu.py +399 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma.py +458 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/glm4_vision_encoder.py +312 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt2.py +339 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_bigcode.py +359 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/granitemoe.py +461 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/h2ovl.py +553 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics3.py +713 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2.py +495 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/internvl.py +962 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/jamba.py +632 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/llama.py +601 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/llava.py +845 -0
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (633 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/custom_op.cpython-311.pyc
ADDED
|
Binary file (6.88 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/parameter.cpython-311.pyc
ADDED
|
Binary file (21.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/pooling_metadata.cpython-311.pyc
ADDED
|
Binary file (3.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/sampling_metadata.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.32 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__init__.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from vllm.logger import init_logger
|
| 8 |
+
from vllm.model_executor.guided_decoding.utils import (
|
| 9 |
+
convert_lark_to_gbnf, grammar_is_likely_lark,
|
| 10 |
+
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
| 11 |
+
from vllm.platforms import CpuArchEnum
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from transformers import PreTrainedTokenizer
|
| 15 |
+
|
| 16 |
+
from vllm.config import ModelConfig
|
| 17 |
+
from vllm.logits_process import LogitsProcessor
|
| 18 |
+
from vllm.sampling_params import GuidedDecodingParams
|
| 19 |
+
|
| 20 |
+
logger = init_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def maybe_backend_fallback(
|
| 24 |
+
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
| 25 |
+
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
| 26 |
+
if guided_params.backend == "lm-format-enforcer":
|
| 27 |
+
if guided_params.grammar is not None:
|
| 28 |
+
logger.warning(
|
| 29 |
+
"lm-format-enforcer does not support grammar guided decoding. "
|
| 30 |
+
"Falling back to use xgrammar instead.")
|
| 31 |
+
guided_params.backend = "xgrammar"
|
| 32 |
+
|
| 33 |
+
# lm-format-enforcer doesn't support some JSON schema features
|
| 34 |
+
elif (guided_params.json is not None
|
| 35 |
+
and has_lmf_unsupported_json_features(guided_params.json)):
|
| 36 |
+
logger.warning(
|
| 37 |
+
"lm-format-enforcer does not support advanced JSON schema "
|
| 38 |
+
"features like patterns or numeric ranges. "
|
| 39 |
+
"Falling back to use outlines instead.")
|
| 40 |
+
guided_params.backend = "outlines"
|
| 41 |
+
|
| 42 |
+
if guided_params.backend == "xgrammar":
|
| 43 |
+
# xgrammar only has x86 wheels for linux, fallback to outlines
|
| 44 |
+
from vllm.platforms import current_platform
|
| 45 |
+
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
|
| 46 |
+
logger.warning("xgrammar is only supported on x86 CPUs. "
|
| 47 |
+
"Falling back to use outlines instead.")
|
| 48 |
+
guided_params.backend = "outlines"
|
| 49 |
+
|
| 50 |
+
# xgrammar doesn't support regex or choice, fallback to outlines
|
| 51 |
+
if guided_params.regex is not None or guided_params.choice is not None:
|
| 52 |
+
logger.warning(
|
| 53 |
+
"xgrammar only supports json or grammar guided decoding. "
|
| 54 |
+
"Falling back to use outlines instead.")
|
| 55 |
+
guided_params.backend = "outlines"
|
| 56 |
+
|
| 57 |
+
# xgrammar doesn't support some JSON schema features
|
| 58 |
+
elif (guided_params.json is not None
|
| 59 |
+
and has_xgrammar_unsupported_json_features(guided_params.json)):
|
| 60 |
+
logger.warning(
|
| 61 |
+
"xgrammar does not support advanced JSON schema features like "
|
| 62 |
+
"patterns or numeric ranges. "
|
| 63 |
+
"Falling back to use outlines instead.")
|
| 64 |
+
guided_params.backend = "outlines"
|
| 65 |
+
|
| 66 |
+
# xgrammar only supports GBNF grammars, so we must convert Lark.
|
| 67 |
+
# We must check if the grammar is likely Lark and if that
|
| 68 |
+
# grammar is convertible to GBNF
|
| 69 |
+
elif (guided_params.grammar is not None
|
| 70 |
+
and grammar_is_likely_lark(guided_params.grammar)):
|
| 71 |
+
try:
|
| 72 |
+
convert_lark_to_gbnf(guided_params.grammar)
|
| 73 |
+
except Exception:
|
| 74 |
+
logger.warning(
|
| 75 |
+
"xgrammar does not support Lark grammars and the "
|
| 76 |
+
"grammar failed to convert to GBNF. "
|
| 77 |
+
"Falling back to use outlines instead.")
|
| 78 |
+
guided_params.backend = "outlines"
|
| 79 |
+
|
| 80 |
+
if (guided_params.backend == "outlines"
|
| 81 |
+
and guided_params.json_object is not None):
|
| 82 |
+
# outlines doesn't support json_object, fallback to xgrammar
|
| 83 |
+
logger.warning("outlines does not support json_object. "
|
| 84 |
+
"Falling back to use xgrammar instead.")
|
| 85 |
+
guided_params.backend = "xgrammar"
|
| 86 |
+
|
| 87 |
+
return guided_params
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
async def get_guided_decoding_logits_processor(
|
| 91 |
+
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
|
| 92 |
+
model_config: ModelConfig) -> LogitsProcessor | None:
|
| 93 |
+
guided_params = maybe_backend_fallback(guided_params)
|
| 94 |
+
# CFG grammar not supported by LMFE, so we use outlines instead
|
| 95 |
+
if guided_params.backend == 'outlines':
|
| 96 |
+
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
| 97 |
+
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
| 98 |
+
get_outlines_guided_decoding_logits_processor)
|
| 99 |
+
return await get_outlines_guided_decoding_logits_processor(
|
| 100 |
+
guided_params, tokenizer)
|
| 101 |
+
if guided_params.backend == 'lm-format-enforcer':
|
| 102 |
+
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
| 103 |
+
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
| 104 |
+
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
| 105 |
+
guided_params, tokenizer)
|
| 106 |
+
if guided_params.backend == 'xgrammar':
|
| 107 |
+
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
| 108 |
+
get_local_xgrammar_guided_decoding_logits_processor)
|
| 109 |
+
return get_local_xgrammar_guided_decoding_logits_processor(
|
| 110 |
+
guided_params, tokenizer, model_config)
|
| 111 |
+
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
| 114 |
+
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_local_guided_decoding_logits_processor(
|
| 118 |
+
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
|
| 119 |
+
model_config: ModelConfig) -> LogitsProcessor | None:
|
| 120 |
+
guided_params = maybe_backend_fallback(guided_params)
|
| 121 |
+
# CFG grammar not supported by LMFE, so we use outlines instead
|
| 122 |
+
if guided_params.backend == 'outlines':
|
| 123 |
+
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
| 124 |
+
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
| 125 |
+
get_local_outlines_guided_decoding_logits_processor)
|
| 126 |
+
return get_local_outlines_guided_decoding_logits_processor(
|
| 127 |
+
guided_params, tokenizer)
|
| 128 |
+
if guided_params.backend == 'lm-format-enforcer':
|
| 129 |
+
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
| 130 |
+
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
| 131 |
+
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
| 132 |
+
guided_params, tokenizer)
|
| 133 |
+
if guided_params.backend == 'xgrammar':
|
| 134 |
+
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
| 135 |
+
get_local_xgrammar_guided_decoding_logits_processor)
|
| 136 |
+
return get_local_xgrammar_guided_decoding_logits_processor(
|
| 137 |
+
guided_params, tokenizer, model_config)
|
| 138 |
+
|
| 139 |
+
raise ValueError(
|
| 140 |
+
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
| 141 |
+
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (5.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-311.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-311.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-311.pyc
ADDED
|
Binary file (5.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/__pycache__/xgrammar_decoding.cpython-311.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2024- the Outlines developers
|
| 4 |
+
# This file is adapted from
|
| 5 |
+
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
import copy
|
| 19 |
+
import json
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from functools import lru_cache
|
| 22 |
+
from typing import Callable, DefaultDict, Dict, List, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from outlines import grammars
|
| 27 |
+
from outlines.caching import cache
|
| 28 |
+
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
|
| 29 |
+
RegexGuide, Write)
|
| 30 |
+
from outlines.fsm.parsing import PartialLark
|
| 31 |
+
from outlines_core.fsm.json_schema import build_regex_from_schema
|
| 32 |
+
from pydantic import BaseModel
|
| 33 |
+
from transformers import PreTrainedTokenizerBase
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BaseLogitsProcessor:
|
| 37 |
+
|
| 38 |
+
def __init__(self, guide: Guide):
|
| 39 |
+
self._guide: Guide = guide
|
| 40 |
+
# CFGState is used for the FSM state for CFGGuide
|
| 41 |
+
self._fsm_state: DefaultDict[int, Union[int,
|
| 42 |
+
CFGState]] = defaultdict(int)
|
| 43 |
+
|
| 44 |
+
def __call__(self, input_ids: List[int],
|
| 45 |
+
scores: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
"""Use the FSM to bias the logits before sampling the next token."""
|
| 47 |
+
seq_id = hash(tuple(input_ids))
|
| 48 |
+
|
| 49 |
+
if len(input_ids) > 0:
|
| 50 |
+
last_token = input_ids[-1]
|
| 51 |
+
last_seq_id = hash(tuple(input_ids[:-1]))
|
| 52 |
+
self._fsm_state[seq_id] = self._guide.get_next_state(
|
| 53 |
+
state=self._fsm_state[last_seq_id], token_id=last_token)
|
| 54 |
+
else:
|
| 55 |
+
# Note: this is a hack.
|
| 56 |
+
# Lark pickling does not work properly (silent failure),
|
| 57 |
+
# which breaks the RPC (which uses python pickleing).
|
| 58 |
+
# We need to find a better solution.
|
| 59 |
+
# On the first time this is called, we simply re-create
|
| 60 |
+
# the Lark object.
|
| 61 |
+
if isinstance(self._guide, CFGGuide):
|
| 62 |
+
self._guide.parser = PartialLark(
|
| 63 |
+
self._guide.cfg_string,
|
| 64 |
+
parser="lalr",
|
| 65 |
+
import_paths=[grammars.GRAMMAR_PATH],
|
| 66 |
+
)
|
| 67 |
+
self._fsm_state[seq_id] = CFGState(
|
| 68 |
+
parser_state=self._guide.parser.parse(""), prev_token=None)
|
| 69 |
+
|
| 70 |
+
instruction = self._guide.get_next_instruction(
|
| 71 |
+
state=self._fsm_state[seq_id])
|
| 72 |
+
|
| 73 |
+
if type(instruction) == Generate: # noqa: E721
|
| 74 |
+
allowed_tokens = instruction.tokens
|
| 75 |
+
elif type(instruction) == Write: # noqa: E721
|
| 76 |
+
# TODO: support fast forward tokens
|
| 77 |
+
allowed_tokens = [instruction.tokens[0]]
|
| 78 |
+
else:
|
| 79 |
+
raise TypeError(
|
| 80 |
+
f"Unsupported instruction type {type(instruction)}")
|
| 81 |
+
|
| 82 |
+
mask = torch.full((scores.shape[-1], ),
|
| 83 |
+
-torch.inf,
|
| 84 |
+
device=scores.device)
|
| 85 |
+
# The tokenizer may support more token ids than the model can generate,
|
| 86 |
+
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
|
| 87 |
+
# but scores.shape == torch.Size([128256])
|
| 88 |
+
# Using NumPy is faster for filtering token ids
|
| 89 |
+
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
|
| 90 |
+
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
|
| 91 |
+
allowed_tokens = allowed_tokens.masked_select(
|
| 92 |
+
allowed_tokens < scores.shape[-1])
|
| 93 |
+
mask.index_fill_(0, allowed_tokens, 0)
|
| 94 |
+
scores.add_(mask)
|
| 95 |
+
return scores
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RegexLogitsProcessor(BaseLogitsProcessor):
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
@cache()
|
| 102 |
+
def _get_guide(cls, regex_string: str,
|
| 103 |
+
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
| 104 |
+
tokenizer = _adapt_tokenizer(tokenizer)
|
| 105 |
+
return RegexGuide.from_regex(regex_string, tokenizer)
|
| 106 |
+
|
| 107 |
+
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
|
| 108 |
+
"""Compile the FSM that drives the regex-structured generation.
|
| 109 |
+
|
| 110 |
+
Parameters
|
| 111 |
+
----------
|
| 112 |
+
regex_string
|
| 113 |
+
A string that represents a regular expression
|
| 114 |
+
tokenizer
|
| 115 |
+
The model's tokenizer
|
| 116 |
+
|
| 117 |
+
"""
|
| 118 |
+
super().__init__(
|
| 119 |
+
RegexLogitsProcessor._get_guide(regex_string, tokenizer))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class JSONLogitsProcessor(RegexLogitsProcessor):
|
| 123 |
+
|
| 124 |
+
def __init__(self, schema: Union[str, Dict, BaseModel],
|
| 125 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 126 |
+
whitespace_pattern: Union[str, None]):
|
| 127 |
+
"""Compile the FSM that drives the JSON-guided generation.
|
| 128 |
+
|
| 129 |
+
Parameters
|
| 130 |
+
----------
|
| 131 |
+
schema
|
| 132 |
+
A JSON schema that encodes the structure we want the model to
|
| 133 |
+
generate
|
| 134 |
+
tokenizer
|
| 135 |
+
The model's tokenizer
|
| 136 |
+
whitespace_pattern
|
| 137 |
+
Pattern to use for JSON syntactic whitespace (doesn't impact
|
| 138 |
+
string literals)
|
| 139 |
+
Example: allow only a single space or newline with
|
| 140 |
+
`whitespace_pattern=r"[\n ]?"`
|
| 141 |
+
"""
|
| 142 |
+
if isinstance(schema, type(BaseModel)):
|
| 143 |
+
schema_str = json.dumps(schema.model_json_schema())
|
| 144 |
+
elif isinstance(schema, Dict):
|
| 145 |
+
schema_str = json.dumps(schema)
|
| 146 |
+
elif isinstance(schema, str):
|
| 147 |
+
schema_str = schema
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(
|
| 150 |
+
f"Cannot parse schema {schema}. The schema must be either "
|
| 151 |
+
f"a Pydantic object, a dictionary or a string that contains "
|
| 152 |
+
f"the JSON Schema specification")
|
| 153 |
+
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
| 154 |
+
super().__init__(regex_string, tokenizer)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class CFGLogitsProcessor(BaseLogitsProcessor):
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
@cache()
|
| 161 |
+
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
|
| 162 |
+
tokenizer = _adapt_tokenizer(tokenizer)
|
| 163 |
+
return CFGGuide(cfg, tokenizer)
|
| 164 |
+
|
| 165 |
+
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
|
| 166 |
+
"""Compile the FSM that drives the context free grammar generation.
|
| 167 |
+
|
| 168 |
+
Parameters
|
| 169 |
+
----------
|
| 170 |
+
cfg
|
| 171 |
+
A string that represents a context-free grammar
|
| 172 |
+
tokenizer
|
| 173 |
+
The model's tokenizer
|
| 174 |
+
|
| 175 |
+
"""
|
| 176 |
+
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
|
| 177 |
+
self._guide = self._guide.copy()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@lru_cache(maxsize=32)
|
| 181 |
+
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
| 182 |
+
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
| 183 |
+
|
| 184 |
+
The API of Outlines tokenizers is slightly different to that of
|
| 185 |
+
`transformers`. The decoder of outlines, returns a list whereas
|
| 186 |
+
the decode of vLLM returns an str. To sync the vLLM decoder with
|
| 187 |
+
outlines internal api, the decoder should be adapted. In addition
|
| 188 |
+
we need to handle the missing spaces to Llama's tokenizer to be
|
| 189 |
+
able to compile FSMs for this model.
|
| 190 |
+
|
| 191 |
+
"""
|
| 192 |
+
if getattr(tokenizer, "_outlines_adapted", False):
|
| 193 |
+
return tokenizer
|
| 194 |
+
|
| 195 |
+
tokenizer = copy.deepcopy(tokenizer)
|
| 196 |
+
|
| 197 |
+
tokenizer.vocabulary = tokenizer.get_vocab()
|
| 198 |
+
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
| 199 |
+
|
| 200 |
+
def convert_token_to_string(token: str) -> str:
|
| 201 |
+
from transformers.file_utils import SPIECE_UNDERLINE
|
| 202 |
+
|
| 203 |
+
string = tokenizer.convert_tokens_to_string([token])
|
| 204 |
+
|
| 205 |
+
# A hack to handle missing spaces to HF's Llama tokenizers
|
| 206 |
+
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
|
| 207 |
+
or token == "<0x20>"):
|
| 208 |
+
return " " + string
|
| 209 |
+
|
| 210 |
+
return string
|
| 211 |
+
|
| 212 |
+
def change_decoder(
|
| 213 |
+
decoder: Callable[[List[int]],
|
| 214 |
+
str]) -> Callable[[List[int]], List[str]]:
|
| 215 |
+
"""Sync vLLM's decoder with the outlines by returning list."""
|
| 216 |
+
|
| 217 |
+
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
| 218 |
+
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
| 219 |
+
and isinstance(inp_tokens[0], list)):
|
| 220 |
+
inp_tokens = inp_tokens[0]
|
| 221 |
+
return [decoder(inp_tokens)]
|
| 222 |
+
|
| 223 |
+
return new_decoder
|
| 224 |
+
|
| 225 |
+
tokenizer.convert_token_to_string = convert_token_to_string
|
| 226 |
+
tokenizer.decode = change_decoder(tokenizer.decode)
|
| 227 |
+
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
| 228 |
+
|
| 229 |
+
return tokenizer
|
.venv/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/utils.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
| 7 |
+
"""Check if JSON schema contains features unsupported by xgrammar."""
|
| 8 |
+
|
| 9 |
+
def check_object(obj: dict) -> bool:
|
| 10 |
+
if not isinstance(obj, dict):
|
| 11 |
+
return False
|
| 12 |
+
|
| 13 |
+
# Check for pattern restrictions
|
| 14 |
+
if "pattern" in obj:
|
| 15 |
+
return True
|
| 16 |
+
|
| 17 |
+
# Check for numeric ranges
|
| 18 |
+
if obj.get("type") in ("integer", "number") and any(
|
| 19 |
+
key in obj for key in [
|
| 20 |
+
"minimum", "maximum", "exclusiveMinimum",
|
| 21 |
+
"exclusiveMaximum", "multipleOf"
|
| 22 |
+
]):
|
| 23 |
+
return True
|
| 24 |
+
|
| 25 |
+
# Check for array unsupported keywords
|
| 26 |
+
if obj.get("type") == "array" and any(key in obj for key in [
|
| 27 |
+
"uniqueItems", "contains", "minContains", "maxContains",
|
| 28 |
+
"minItems", "maxItems"
|
| 29 |
+
]):
|
| 30 |
+
return True
|
| 31 |
+
|
| 32 |
+
# Recursively check all nested objects and arrays
|
| 33 |
+
for value in obj.values():
|
| 34 |
+
if isinstance(value, dict):
|
| 35 |
+
if check_object(value):
|
| 36 |
+
return True
|
| 37 |
+
elif isinstance(value, list):
|
| 38 |
+
for item in value:
|
| 39 |
+
if isinstance(item, dict) and check_object(item):
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
return False
|
| 43 |
+
|
| 44 |
+
return check_object(schema)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def has_lmf_unsupported_json_features(schema: dict) -> bool:
|
| 48 |
+
"""
|
| 49 |
+
Check if JSON schema contains features unsupported
|
| 50 |
+
by lm_format_enforcer.
|
| 51 |
+
|
| 52 |
+
Known issues:
|
| 53 |
+
- Regex patterns:
|
| 54 |
+
"grade": {
|
| 55 |
+
"type": "string",
|
| 56 |
+
"pattern": "^[A-D]$" # Regex pattern
|
| 57 |
+
},
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def check_object(obj: dict) -> bool:
|
| 61 |
+
if not isinstance(obj, dict):
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
# Check for pattern restrictions
|
| 65 |
+
if "pattern" in obj:
|
| 66 |
+
return True
|
| 67 |
+
|
| 68 |
+
# Recursively check all nested objects and arrays
|
| 69 |
+
for value in obj.values():
|
| 70 |
+
if isinstance(value, dict):
|
| 71 |
+
if check_object(value):
|
| 72 |
+
return True
|
| 73 |
+
elif isinstance(value, list):
|
| 74 |
+
for item in value:
|
| 75 |
+
if isinstance(item, dict) and check_object(item):
|
| 76 |
+
return True
|
| 77 |
+
|
| 78 |
+
return False
|
| 79 |
+
|
| 80 |
+
return check_object(schema)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
| 84 |
+
"""
|
| 85 |
+
Check if grammar appears to use Lark syntax.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
grammar_str: Input grammar string
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
bool: True if grammar appears to be in Lark format, False otherwise
|
| 92 |
+
|
| 93 |
+
Examples:
|
| 94 |
+
>>> grammar_is_likely_lark("rule: 'abc'")
|
| 95 |
+
True
|
| 96 |
+
>>> grammar_is_likely_lark("rule ::= 'abc'")
|
| 97 |
+
False
|
| 98 |
+
"""
|
| 99 |
+
if not grammar_str or not isinstance(grammar_str, str):
|
| 100 |
+
return False
|
| 101 |
+
|
| 102 |
+
for line in grammar_str.split('\n'):
|
| 103 |
+
# Remove both comment styles
|
| 104 |
+
line = re.sub(r'(#|//).*$', '', line).strip()
|
| 105 |
+
if not line:
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
# Look for GBNF rule definition
|
| 109 |
+
if '::=' in line:
|
| 110 |
+
return False
|
| 111 |
+
|
| 112 |
+
return True
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def convert_lark_to_gbnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to GBNF format.

    GBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in GBNF format

    Examples:
        >>> print(convert_lark_to_gbnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    gbnf_lines = []

    def strip_comments(raw: str) -> str:
        """Drop '#' / '//' comments and surrounding whitespace."""
        return re.sub(r'(#|//).*$', '', raw).strip()

    def validate_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Raise if single or double quotes are unbalanced in *text*."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def find_rule_refs(text: str) -> set:
        """Collect identifiers that look like rule references."""
        # Remove quoted terminals, then operator characters, leaving
        # only bare identifiers.
        without_strings = re.sub(r'"[^"]*"', '', text)
        bare = re.sub(r'[+*?()|\[\]{}]', ' ', without_strings)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', bare))

    # First pass: record every defined rule and pick the root.
    # The root is the first rule encountered, unless a rule literally
    # named 'start' exists, which always wins.
    lines = [strip_comments(raw) for raw in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue
        if ':' not in line:
            continue
        try:
            name = line.split(':', 1)[0].strip().strip('?')
        except IndexError as e:
            raise ValueError(f"Invalid rule format on line {line_num}. "
                             "Expected 'rule_name: definition'") from e
        defined_rules.add(name)
        if first_rule is None:
            first_rule = name
        if name == 'start':
            first_rule = 'start'

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # GBNF requires an explicit root rule.
    gbnf_lines.append(f"root ::= {first_rule}")

    # Second pass: emit each rule, folding '|' continuation lines into
    # the rule that precedes them.
    active_rule = None
    alternatives = []

    def flush_active_rule() -> None:
        """Emit the accumulated alternatives for the rule being built."""
        if active_rule:
            gbnf_lines.append(
                f"{active_rule} ::= {' | '.join(alternatives)}")

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue
        try:
            if ':' in line and not line.startswith('|'):
                # A new rule starts; finish the previous one first.
                flush_active_rule()
                name, body = line.split(':', 1)
                active_rule = name.strip().strip('?')
                validate_quotes(body, f"rule '{active_rule}'", line_num)
                # Lark uses single-quoted terminals; GBNF wants double.
                body = re.sub(r"'([^']*)'", r'"\1"', body)
                referenced_rules.update(find_rule_refs(body))
                alternatives = [body.strip()]
            elif line.startswith('|'):
                if not active_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")
                alt = line[1:].strip()
                validate_quotes(alt, f"alternative for rule '{active_rule}'",
                                line_num)
                alt = re.sub(r"'([^']*)'", r'"\1"', alt)
                referenced_rules.update(find_rule_refs(alt))
                alternatives.append(alt)
        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Emit the last rule, if any.
    flush_active_rule()

    # Every referenced identifier must be defined (the synthetic 'root'
    # is exempt).
    undefined = referenced_rules - defined_rules - {'root'}
    if undefined:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined))}")

    return '\n'.join(gbnf_lines)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from vllm.config import VllmConfig
|
| 6 |
+
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
| 7 |
+
get_model_loader)
|
| 8 |
+
from vllm.model_executor.model_loader.utils import (
|
| 9 |
+
get_architecture_class_name, get_model_architecture)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
    """Build the model described by ``vllm_config``.

    Resolves the loader matching the configured load format, then
    delegates construction and weight loading to it.
    """
    model_loader = get_model_loader(vllm_config.load_config)
    return model_loader.load_model(vllm_config=vllm_config)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Public API of the model_loader package.
__all__ = [
    "get_model",
    "get_model_loader",
    "BaseModelLoader",
    "get_architecture_class_name",
    "get_model_architecture",
]
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/loader.cpython-311.pyc
ADDED
|
Binary file (73.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/openvino.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-311.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (8.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-311.pyc
ADDED
|
Binary file (35.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py
ADDED
|
@@ -0,0 +1,1441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# ruff: noqa: SIM117
|
| 4 |
+
import collections
|
| 5 |
+
import copy
|
| 6 |
+
import dataclasses
|
| 7 |
+
import fnmatch
|
| 8 |
+
import glob
|
| 9 |
+
import inspect
|
| 10 |
+
import itertools
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import warnings
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from contextlib import contextmanager
|
| 16 |
+
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
|
| 17 |
+
Tuple, cast)
|
| 18 |
+
|
| 19 |
+
import gguf
|
| 20 |
+
import huggingface_hub
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from huggingface_hub import HfApi
|
| 24 |
+
from torch import nn
|
| 25 |
+
from transformers import AutoModelForCausalLM
|
| 26 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
| 27 |
+
|
| 28 |
+
from vllm.attention import Attention
|
| 29 |
+
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
|
| 30 |
+
VllmConfig, set_current_vllm_config)
|
| 31 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 32 |
+
get_tensor_model_parallel_world_size)
|
| 33 |
+
from vllm.envs import VLLM_USE_MODELSCOPE
|
| 34 |
+
from vllm.logger import init_logger
|
| 35 |
+
from vllm.model_executor.layers.linear import (LinearBase,
|
| 36 |
+
MergedColumnParallelLinear,
|
| 37 |
+
QKVParallelLinear,
|
| 38 |
+
ReplicatedLinear,
|
| 39 |
+
RowParallelLinear)
|
| 40 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 41 |
+
QuantizeMethodBase)
|
| 42 |
+
from vllm.model_executor.model_loader.tensorizer import (
|
| 43 |
+
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
|
| 44 |
+
serialize_vllm_model, tensorizer_weights_iterator)
|
| 45 |
+
from vllm.model_executor.model_loader.utils import (ParamMapping,
|
| 46 |
+
configure_quant_config,
|
| 47 |
+
get_model_architecture,
|
| 48 |
+
set_default_torch_dtype)
|
| 49 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 50 |
+
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
| 51 |
+
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
| 52 |
+
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
|
| 53 |
+
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
| 54 |
+
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
| 55 |
+
from vllm.model_executor.utils import set_weight_attrs
|
| 56 |
+
from vllm.platforms import current_platform
|
| 57 |
+
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
| 58 |
+
from vllm.transformers_utils.utils import is_s3
|
| 59 |
+
from vllm.utils import is_pin_memory_available
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU parameters to ``target_device``.

    On exit, every parameter that was on CPU on entry is copied back to
    CPU memory (pinned, when available); parameters that already lived on
    the target device — and any parameters created while inside the
    context — are left untouched.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Maps parameter name -> the device it lived on before we moved it.
    original_device_states: Dict[str, torch.device] = {}

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        # Restore parameters to their original devices, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                if original_device.type == "cpu":
                    # `torch.empty_like` does not support `pin_memory` argument
                    cpu_data = torch.empty_strided(
                        size=p.data.size(),
                        stride=p.data.stride(),
                        dtype=p.data.dtype,
                        layout=p.data.layout,
                        device="cpu",
                        pin_memory=pin_memory,
                    )
                    cpu_data.copy_(p.data)
                    p.data = cpu_data
                else:
                    p.data = p.data.to(original_device)
            # New parameters or parameters already on target device are untouched
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
logger = init_logger(__name__)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
) -> nn.Module:
    """Initialize a model with the given configurations.

    New-style model classes (those accepting ``vllm_config`` and
    ``prefix``) are constructed directly. For old-style classes, the
    constructor arguments are guessed from the signature and a
    DeprecationWarning is emitted.
    """
    model_config = vllm_config.model_config
    model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        # Let the quantization config adjust itself to the resolved class.
        configure_quant_config(vllm_config.quant_config, model_class)

    # Inspect the constructor signature to decide which calling
    # convention this model class follows.
    signatures = inspect.signature(model_class.__init__)
    all_params = [param.name for param in signatures.parameters.values()]
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class
    # Only pass the arguments the constructor actually declares.
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**kwargs)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class BaseModelLoader(ABC):
    """Base class for model loaders.

    A loader knows how to fetch model weights (``download_model``) and
    how to construct a fully initialized module (``load_model``) for a
    given configuration.
    """

    def __init__(self, load_config: LoadConfig):
        # Loading options (load format, download dir, ignore patterns, ...).
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class DefaultModelLoader(BaseModelLoader):
|
| 174 |
+
"""Model loader that can load different file types from disk."""
|
| 175 |
+
|
| 176 |
+
    @dataclasses.dataclass
    class Source:
        """A source for weights.

        Describes one checkpoint location (primary or secondary) and how
        its weights should be selected and renamed during loading.
        """

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""
|
| 194 |
+
|
| 195 |
+
def __init__(self, load_config: LoadConfig):
|
| 196 |
+
super().__init__(load_config)
|
| 197 |
+
if load_config.model_loader_extra_config:
|
| 198 |
+
raise ValueError(f"Model loader extra config is not supported for "
|
| 199 |
+
f"load format {load_config.load_format}")
|
| 200 |
+
|
| 201 |
+
def _maybe_download_from_modelscope(
|
| 202 |
+
self, model: str, revision: Optional[str]) -> Optional[str]:
|
| 203 |
+
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
|
| 204 |
+
|
| 205 |
+
Returns the path to the downloaded model, or None if the model is not
|
| 206 |
+
downloaded from ModelScope."""
|
| 207 |
+
if VLLM_USE_MODELSCOPE:
|
| 208 |
+
# download model from ModelScope hub,
|
| 209 |
+
# lazy import so that modelscope is not required for normal use.
|
| 210 |
+
# pylint: disable=C.
|
| 211 |
+
from modelscope.hub.snapshot_download import snapshot_download
|
| 212 |
+
|
| 213 |
+
if not os.path.exists(model):
|
| 214 |
+
model_path = snapshot_download(
|
| 215 |
+
model_id=model,
|
| 216 |
+
cache_dir=self.load_config.download_dir,
|
| 217 |
+
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
| 218 |
+
revision=revision,
|
| 219 |
+
ignore_file_pattern=self.load_config.ignore_patterns,
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
model_path = model
|
| 223 |
+
return model_path
|
| 224 |
+
return None
|
| 225 |
+
|
| 226 |
+
    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            A tuple of (checkpoint folder, list of weight file paths,
            whether the files are safetensors).
        """
        # Prefer a ModelScope copy when that integration is enabled.
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            # The model class dictates exactly which files to load.
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in priority order; stop at the first one
        # that matches any files.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors
|
| 309 |
+
|
| 310 |
+
    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Locates (and, if needed, downloads) the checkpoint files, picks the
        matching iterator (np-cache, safetensors, or torch .pt/.bin), and
        applies the source's name prefix to every yielded weight.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                # Wrap the iterator so a mark_step happens after each weight.
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)
|
| 346 |
+
|
| 347 |
+
def _get_all_weights(
|
| 348 |
+
self,
|
| 349 |
+
model_config: ModelConfig,
|
| 350 |
+
model: nn.Module,
|
| 351 |
+
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
| 352 |
+
primary_weights = DefaultModelLoader.Source(
|
| 353 |
+
model_config.model,
|
| 354 |
+
model_config.revision,
|
| 355 |
+
prefix="",
|
| 356 |
+
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
|
| 357 |
+
True),
|
| 358 |
+
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
|
| 359 |
+
None),
|
| 360 |
+
)
|
| 361 |
+
yield from self._get_weights_iterator(primary_weights)
|
| 362 |
+
|
| 363 |
+
secondary_weights = cast(
|
| 364 |
+
Iterable[DefaultModelLoader.Source],
|
| 365 |
+
getattr(model, "secondary_weights", ()),
|
| 366 |
+
)
|
| 367 |
+
for source in secondary_weights:
|
| 368 |
+
yield from self._get_weights_iterator(source)
|
| 369 |
+
|
| 370 |
+
def download_model(self, model_config: ModelConfig) -> None:
|
| 371 |
+
self._prepare_weights(model_config.model,
|
| 372 |
+
model_config.revision,
|
| 373 |
+
fall_back_to_pt=True,
|
| 374 |
+
allow_patterns_overrides=None)
|
| 375 |
+
|
| 376 |
+
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
    """Build the model on the target device, stream checkpoint weights
    into it, run post-load weight processing, and return it in eval mode.

    Raises:
        ValueError: for non-quantized models whose `load_weights` reports
            which parameters it filled, if any named parameter was left
            uninitialized by the checkpoint.
    """
    device_config = vllm_config.device_config
    model_config = vllm_config.model_config

    target_device = torch.device(device_config.device)
    with set_default_torch_dtype(model_config.dtype):
        # Only model construction happens under the device context;
        # weight streaming below runs outside it.
        with target_device:
            model = _initialize_model(vllm_config=vllm_config)

        weights_to_load = {name for name, _ in model.named_parameters()}
        # `load_weights` may return the set of parameter names it actually
        # loaded (or None if the model does not track this).
        loaded_weights = model.load_weights(
            self._get_all_weights(model_config, model))
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}")

        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if isinstance(quant_method, QuantizeMethodBase):
                # When quant methods need to process weights after loading
                # (for repacking, quantizing, etc), they expect parameters
                # to be on the global target device. This scope is for the
                # case where cpu offloading is used, where we will move the
                # parameters onto device for processing and back off after.
                with device_loading_context(module, target_device):
                    quant_method.process_weights_after_loading(module)
            if isinstance(module, Attention) and \
                hasattr(module, "process_weights_after_loading"):
                # When attention modules need to process weights after
                # currently only used by MLA
                # TODO(lucas): see if there is a way to unify the signatures
                # of process_weights_after_loading
                module.process_weights_after_loading(model_config.dtype)
    return model.eval()
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values.

    Intended for benchmarking and testing: it skips checkpoint download
    entirely and fills parameters with random data instead.
    """

    def __init__(self, load_config: LoadConfig):
        # This loader takes no extra configuration; reject any to surface
        # misconfiguration early.
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Construct the model on the target device with randomly
        initialized weights and return it in eval mode."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # NOTE(woosuk): For accurate performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(model)

                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        # When quant methods need to process weights after
                        # loading (for repacking, quantizing, etc), they expect
                        # parameters to be on the global target device. This
                        # scope is for the case where cpu offloading is used,
                        # where we will move the parameters onto device for
                        # processing and back off after.
                        with device_loading_context(
                                module, torch.device(device_config.device)):
                            quant_method.process_weights_after_loading(module)
                    if isinstance(module, Attention) and \
                        hasattr(module, "process_weights_after_loading"):
                        # When attention modules need to process weights after
                        # currently only used by MLA
                        module.process_weights_after_loading(model_config.dtype)
        return model.eval()
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Accept either a ready-made TensorizerConfig or a plain mapping of
        # its constructor kwargs.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        # Delegate validation to the tensorizer config itself.
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Build tensorizer stream arguments from our config and hand back a
        # (name, tensor) iterator over the serialized checkpoint.
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

            # Weight streaming happens outside the device context.
            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models."""

        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]

                # Copy so we do not mutate the shared loader config.
                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening the stream is enough to materialize/cache the artifact.
        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Dispatch to the serialized fast path for vLLM-tensorized models,
        otherwise fall back to the CPU deserialization path."""
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # NOTE(review): this %-formats the rank into the URI in place,
            # so the config is permanently specialized to this rank.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized(vllm_config=vllm_config)
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize *model* with tensorizer according to the config."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    # Per-rank shard filename template; "part" splits one rank's state
    # across multiple files.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the leftover (actually unexpected) keys; the
            # original full mapping would misleadingly include the consumed
            # "pattern" key.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group tensors by (device, storage pointer): only tensors on the
        # same storage can alias each other.
        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # One-past-the-end address of the tensor's last element.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        # t2 does not cover t's memory range; not a superset.
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group supersedes t: keep it.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local directory containing the sharded safetensors,
        downloading from the HF hub if needed."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and load this TP rank's shard files into its
        (post-processed) state dict.

        Raises:
            ValueError: if no shard files match this rank's pattern, or if
                some state-dict keys were never found in the shards.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # NOTE: sharded checkpoints were saved AFTER weight
                # post-processing, so process first and load into the
                # processed parameters.
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
                    if isinstance(module, Attention) and \
                        hasattr(module, "process_weights_after_loading"):
                        # When attention modules need to process weights after
                        # currently only used by MLA
                        module.process_weights_after_loading(
                            model_config.dtype)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s",
                                tensor.shape,
                                key,
                                param_shape,
                            )
                        param_data.copy_(tensor)
                        # Track completion by removing loaded keys.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this rank's deduplicated state dict under *path*, splitting
        into multiple parts of at most *max_size* bytes each (if given)."""
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                # Flush the current part before it would exceed max_size.
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
class BitsAndBytesModelLoader(BaseModelLoader):
|
| 734 |
+
"""Model loader to load model weights with BitAndBytes quantization."""
|
| 735 |
+
|
| 736 |
+
possible_config_file_names = ["adapter_config.json"]
|
| 737 |
+
|
| 738 |
+
def __init__(self, load_config: LoadConfig):
    """Initialize bookkeeping used while streaming BNB-quantized weights."""
    super().__init__(load_config)

    # Module names whose weights are replicated (never sharded across TP).
    self.unsharded_weights_modules: List[str] = []
    # Module names whose weights are sharded along the column dimension.
    self.column_sharded_weights_modules: List[str] = []
    # All transformers-side module names eligible for BNB quantization.
    self.target_modules: List[str] = []
    # Translates transformers weight names to vLLM names; identity by
    # default, replaced when the model provides a mapper.
    self.weight_mapper: Callable = lambda name: name
|
| 750 |
+
|
| 751 |
+
def _get_weight_files(
    self,
    model_name_or_path: str,
    allowed_patterns: List[str],
    revision: Optional[str] = None,
) -> Tuple[List[str], str]:
    """Retrieve weight files. Download the files if necessary.

    Patterns are tried in order; the first one with matches wins.
    Return the weight files and the file pattern."""
    is_local = os.path.isdir(model_name_or_path)

    if is_local:
        for pattern in allowed_patterns:
            weight_files = glob.glob(
                os.path.join(model_name_or_path, pattern))
            if weight_files:
                return weight_files, pattern
    else:
        # List the remote repo first so we only download the single
        # pattern that actually has matches.
        hf_api = HfApi()
        repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
        for pattern in allowed_patterns:
            matching_files = fnmatch.filter(repo_files, pattern)
            if matching_files:
                hf_folder = download_weights_from_hf(
                    model_name_or_path,
                    self.load_config.download_dir,
                    [pattern],
                    revision,
                    ignore_patterns=self.load_config.ignore_patterns,
                )
                return glob.glob(os.path.join(hf_folder, pattern)), pattern

    raise RuntimeError(
        f"No model weights found in: `{model_name_or_path}`")
|
| 785 |
+
|
| 786 |
+
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]) -> Tuple[List[str], bool]:
    """Locate the model's weight files.

    Returns the file list and whether they are safetensors.
    """
    patterns = ["*.safetensors", "*.bin", "*.pt"]

    weight_files, matched = self._get_weight_files(model_name_or_path,
                                                   patterns, revision)
    is_safetensors = matched == "*.safetensors"

    # Torch checkpoints may include training-only files; drop them.
    if not is_safetensors:
        weight_files = filter_files_not_needed_for_inference(weight_files)

    if not weight_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return weight_files, is_safetensors
|
| 804 |
+
|
| 805 |
+
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Iterate checkpoint weights as (original_name, mapped_name, tensor).

    The original (transformers-side) name is preserved alongside the
    vLLM-side name produced by `self.weight_mapper`.
    """
    make_iterator = (safetensors_weights_iterator
                     if use_safetensors else pt_weights_iterator)
    for original_name, tensor in make_iterator(hf_weights_files):
        yield original_name, self.weight_mapper(original_name), tensor
|
| 815 |
+
|
| 816 |
+
def _get_quantized_weights_iterator(
    self,
    model_name_or_path: str,
    revision: Optional[str],
    pre_quant: bool,
    load_8bit: bool,
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                 Any]]:
    """Get an iterator to the model weights with bitsandbytes quantization,
    as well as the quantization state dictionary.

    Args:
        model_name_or_path: local dir or HF repo id of the checkpoint.
        revision: optional HF revision.
        pre_quant: checkpoint is already BNB-quantized on disk.
        load_8bit: for pre-quantized checkpoints, use the 8-bit layout
            instead of 4-bit.

    Raises:
        ImportError: when bitsandbytes is missing or too old.
    """

    # only load the bitsandbytes module when needed
    try:
        import bitsandbytes

        def _version_tuple(version: str) -> Tuple[int, ...]:
            # Compare versions numerically: plain string comparison would
            # order e.g. "0.100.0" before "0.45.0".
            parts = []
            for piece in version.split(".")[:3]:
                digits = "".join(ch for ch in piece if ch.isdigit())
                parts.append(int(digits) if digits else 0)
            return tuple(parts)

        if _version_tuple(bitsandbytes.__version__) < (0, 45, 0):
            raise ImportError("bitsandbytes version is wrong. Please "
                              "install bitsandbytes>=0.45.0.")
    except ImportError as err:
        raise ImportError("Please install bitsandbytes>=0.45.0 via "
                          "`pip install bitsandbytes>=0.45.0` to use "
                          "bitsandbytes quantizer.") from err

    hf_weights_files, use_safetensors = self._prepare_weights(
        model_name_or_path, revision)

    # Filled lazily by the chosen generator as weights stream through.
    quant_state_dict: Dict[str, Any] = {}

    if pre_quant:
        if load_8bit:
            return self._quantized_8bit_generator(
                hf_weights_files, use_safetensors,
                quant_state_dict), quant_state_dict
        else:
            return self._quantized_4bit_generator(
                hf_weights_files, use_safetensors,
                quant_state_dict), quant_state_dict

    # Unquantized checkpoint: quantize to NF4 on the fly.
    return self._unquantized_generator(hf_weights_files, use_safetensors,
                                       quant_state_dict), quant_state_dict
|
| 856 |
+
|
| 857 |
+
def _is_8bit_weight_name(self, weight_name: str):
|
| 858 |
+
quantized_suffix = {".scb", ".weight_format"}
|
| 859 |
+
return any(weight_name.lower().endswith(suffix)
|
| 860 |
+
for suffix in quantized_suffix)
|
| 861 |
+
|
| 862 |
+
def _is_4bit_weight_name(self, weight_name: str):
|
| 863 |
+
quantized_suffix = {
|
| 864 |
+
"absmax",
|
| 865 |
+
"quant_map",
|
| 866 |
+
"nested_absmax",
|
| 867 |
+
"nested_quant_map",
|
| 868 |
+
"bitsandbytes",
|
| 869 |
+
}
|
| 870 |
+
suffix = weight_name.split(".")[-1]
|
| 871 |
+
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
| 872 |
+
|
| 873 |
+
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                              quant_state_dict) -> Generator:
    """Stream a pre-quantized 8-bit BNB checkpoint.

    First pass collects the `.SCB` scale tensors into *quant_state_dict*
    (keyed by the matching `.weight` name); second pass yields the actual
    weights, tagging quantized ones with `load_in_8bit`.
    """
    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if not mapped_weight_name.lower().endswith(".scb"):
            continue

        # Register the scale under the corresponding weight's name.
        weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
        quant_state_dict[weight_key] = weight_tensor

    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        # Skip the quant-state entries themselves; only real weights are
        # yielded to the model.
        if self._is_8bit_weight_name(mapped_weight_name):
            continue

        if mapped_weight_name in quant_state_dict:
            set_weight_attrs(weight_tensor, {"load_in_8bit": True})
            yield org_weight_name, weight_tensor
        else:
            yield org_weight_name, weight_tensor
|
| 899 |
+
|
| 900 |
+
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                              quant_state_dict) -> Generator:
    """Stream a pre-quantized 4-bit BNB checkpoint.

    First pass gathers all quant-state tensors; second pass yields real
    weights, attaching a reconstructed `QuantState` for each pre-quantized
    one to *quant_state_dict*.
    """
    from bitsandbytes.functional import QuantState

    # First iterate over all quant state weights
    weight_iterator = self._hf_weight_iter(hf_weights_files,
                                           use_safetensors)
    temp_state_dict = {}
    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in weight_iterator:
        if not self._is_4bit_weight_name(mapped_weight_name):
            continue
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__* in CPU
        if "quant_state.bitsandbytes" in mapped_weight_name:
            temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
        else:
            temp_state_dict[mapped_weight_name] = weight_tensor

    # Closure to parse quant_state for each prequant weight
    def _parse_quant_state(param_name: str,
                           temp_state_dict: Dict) -> QuantState:
        # Collect every state tensor whose name is prefixed by this param.
        quant_state = {}
        for k in temp_state_dict:
            if param_name + "." in k:
                quant_state[k] = temp_state_dict[k]

        return QuantState.from_dict(quant_state, device="cuda")

    # Second iterate over all prequant and normal weights
    # pre quantized weights would have a quant_state
    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if self._is_4bit_weight_name(mapped_weight_name):
            continue

        # NF4 or FP4 state present => this weight was pre-quantized.
        if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
                in temp_state_dict) or (
                    f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
                    in temp_state_dict):
            quant_state = _parse_quant_state(mapped_weight_name,
                                             temp_state_dict)
            quant_state_dict[mapped_weight_name] = quant_state
            yield org_weight_name, weight_tensor
        else:
            yield org_weight_name, weight_tensor
|
| 952 |
+
|
| 953 |
+
def _unquantized_generator(self, hf_weights_files, use_safetensors,
                           quant_state_dict) -> Generator:
    """Stream an unquantized checkpoint, quantizing target weights to NF4
    on the fly.

    Target weights are sliced to this TP rank's shard first (whole,
    column-wise, fused-then-concatenated, or row-wise depending on the
    owning module), then quantized; quant states are recorded into
    *quant_state_dict* under the mapped weight name.
    """
    from bitsandbytes.functional import quantize_4bit

    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if any(target_module in mapped_weight_name
               for target_module in self.target_modules
               ) and mapped_weight_name.endswith(".weight"):
            # Without sharding
            if any(
                    mapped_weight_name.startswith(module)
                    for module in self.unsharded_weights_modules):
                weight_sub_tensor = weight_tensor
            # Shard by column
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.column_sharded_weights_modules):
                total_size = weight_tensor.size(-1)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[...,
                                                  start_index:end_index]
            # Weights have fused on disk. In this case, we assume that the
            # weight and module use same name.
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.maybe_fused_weights_modules):
                # special case for fused weights
                # get the size of each shard weight tensor
                total_shard_sizes = next(
                    (sizes for module, sizes in
                     self.maybe_fused_weights_modules.items()
                     if mapped_weight_name.startswith(module)))
                total_size = weight_tensor.size(0)
                assert total_size == sum(total_shard_sizes)
                # get the start/end index of each shard weight tensor
                total_start_index = list(
                    itertools.accumulate([0] + total_shard_sizes))[:-1]
                shard_weights_index = [(
                    idx + size // tp_size * tp_rank,
                    idx + size // tp_size * (tp_rank + 1),
                ) for idx, size in zip(total_start_index,
                                       total_shard_sizes)]
                # slice and reorder the weight tensor
                weight_tensor = [
                    weight_tensor[start_index:end_index, ...]
                    for start_index, end_index in shard_weights_index
                ]
                weight_sub_tensor = torch.cat(weight_tensor, dim=0)
            # Shard by row
            else:
                total_size = weight_tensor.size(0)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[start_index:end_index,
                                                  ...]

            # bitsandbytes requires data in GPU
            if weight_sub_tensor.is_cuda:
                loaded_weight = weight_sub_tensor
            else:
                loaded_weight = weight_sub_tensor.cuda()

            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if loaded_weight.is_contiguous() is False:
                loaded_weight = loaded_weight.contiguous()

            # Quantize in fp32 for accurate absmax statistics.
            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4",
                )

            quant_state_dict[mapped_weight_name] = quant_state
        else:
            # Non-target weights pass through untouched.
            processed_weight = weight_tensor
        yield org_weight_name, processed_weight
|
| 1039 |
+
|
| 1040 |
+
def _get_bnb_target_modules(self, model: nn.Module) -> None:
    """Populate `self.target_modules` with the (transformers-side) names of
    all linear modules eligible for BNB quantization.

    Raises:
        AssertionError: if the model exposes no eligible linear modules.
    """
    for name, module in model.named_modules():
        if isinstance(module, (LinearBase, )):
            if modules_info := self.modules_mapping.get_sub_modules(name):
                # Map vllm's names to transformers's names.
                rep_name, sub_modules = modules_info
                for sub_name in sub_modules:
                    self.target_modules.append(
                        name.replace(rep_name, sub_name))
            # Add original module name even if the module has stacked map,
            # in case model has a mixture of disk-merged and disk-splitted
            # weights with same last name.
            self.target_modules.append(name)

    # BUGFIX: the model-name suffix used to be a free-standing f-string
    # after the assert, so it was silently dropped from the message.
    assert self.target_modules, (
        "vllm currently does not support BNB quantization for"
        f" {type(model).__name__}")
|
| 1058 |
+
|
| 1059 |
+
def _load_weights(self, model_config: ModelConfig,
|
| 1060 |
+
model: nn.Module) -> None:
|
| 1061 |
+
if not hasattr(model, "load_weights"):
|
| 1062 |
+
raise AttributeError(
|
| 1063 |
+
"The required method 'load_weights' is not defined in class"
|
| 1064 |
+
f" {type(model).__name__}.")
|
| 1065 |
+
|
| 1066 |
+
if not hasattr(model, "packed_modules_mapping"):
|
| 1067 |
+
raise AttributeError(
|
| 1068 |
+
f"Model {type(model).__name__} does not support BitsAndBytes "
|
| 1069 |
+
"quantization yet. No 'packed_modules_mapping' found.")
|
| 1070 |
+
|
| 1071 |
+
self.modules_mapping = ParamMapping(
|
| 1072 |
+
copy.deepcopy(model.packed_modules_mapping))
|
| 1073 |
+
|
| 1074 |
+
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
| 1075 |
+
# to ensure correct loading of weights.
|
| 1076 |
+
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
| 1077 |
+
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
|
| 1078 |
+
|
| 1079 |
+
# Modules whose weights might have fused on disk
|
| 1080 |
+
# we need their output_sizes to make shard in flight correctly with TP
|
| 1081 |
+
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
|
| 1082 |
+
self._get_bnb_target_modules(model)
|
| 1083 |
+
for name, module in model.named_modules():
|
| 1084 |
+
# Some modules like `ReplicatedLinear` should not have their weights
|
| 1085 |
+
# sharded. The reason for implementing it this way is to avoid new
|
| 1086 |
+
# static variable in the model implementation.
|
| 1087 |
+
if isinstance(module, (ReplicatedLinear, )):
|
| 1088 |
+
self.unsharded_weights_modules.append(name)
|
| 1089 |
+
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
|
| 1090 |
+
# fused weights on disk. We need to use the output sizes of these
|
| 1091 |
+
# modules to shard the weights correctly.
|
| 1092 |
+
elif isinstance(module,
|
| 1093 |
+
(QKVParallelLinear, MergedColumnParallelLinear)):
|
| 1094 |
+
self.maybe_fused_weights_modules[name] = module.output_sizes
|
| 1095 |
+
# In TP, these weights are partitioned along the column
|
| 1096 |
+
# dimension (dim=-1)
|
| 1097 |
+
elif isinstance(module, (RowParallelLinear, )):
|
| 1098 |
+
self.column_sharded_weights_modules.append(name)
|
| 1099 |
+
|
| 1100 |
+
self.model_type = type(model).__name__
|
| 1101 |
+
|
| 1102 |
+
logger.info("Loading weights with BitsAndBytes quantization. "
|
| 1103 |
+
" May take a while ...")
|
| 1104 |
+
|
| 1105 |
+
quant_config = getattr(model_config.hf_config, "quantization_config",
|
| 1106 |
+
None)
|
| 1107 |
+
|
| 1108 |
+
pre_quant = False
|
| 1109 |
+
if quant_config is not None:
|
| 1110 |
+
quant_method = quant_config.get("quant_method")
|
| 1111 |
+
if quant_method == "bitsandbytes":
|
| 1112 |
+
pre_quant = True
|
| 1113 |
+
else:
|
| 1114 |
+
raise ValueError(
|
| 1115 |
+
f"BitsAndBytes loader does not support {quant_method} "
|
| 1116 |
+
"quantization")
|
| 1117 |
+
|
| 1118 |
+
# The quant_states in pre_quantized models cannot work with a split
|
| 1119 |
+
# weight tensor. So TP does not work with pre_quantized bnb models.
|
| 1120 |
+
if pre_quant and get_tensor_model_parallel_world_size() > 1:
|
| 1121 |
+
raise ValueError(
|
| 1122 |
+
"Prequant BitsAndBytes models with tensor parallelism is not "
|
| 1123 |
+
"supported. Please try with pipeline parallelism.")
|
| 1124 |
+
|
| 1125 |
+
load_8bit = False
|
| 1126 |
+
if pre_quant:
|
| 1127 |
+
load_8bit = quant_config.get("load_in_8bit", False)
|
| 1128 |
+
|
| 1129 |
+
qweight_iterator, quant_state_dict = (
|
| 1130 |
+
self._get_quantized_weights_iterator(model_config.model,
|
| 1131 |
+
model_config.revision,
|
| 1132 |
+
pre_quant, load_8bit))
|
| 1133 |
+
|
| 1134 |
+
weights_to_load = {name for name, _ in model.named_parameters()}
|
| 1135 |
+
loaded_weights = model.load_weights(qweight_iterator)
|
| 1136 |
+
# Some models may have weights loading tracker unimplemented.
|
| 1137 |
+
if loaded_weights is not None:
|
| 1138 |
+
weights_not_loaded = weights_to_load - loaded_weights
|
| 1139 |
+
if weights_not_loaded:
|
| 1140 |
+
raise ValueError("Following weights were not initialized from "
|
| 1141 |
+
f"checkpoint: {weights_not_loaded}")
|
| 1142 |
+
|
| 1143 |
+
torch.cuda.empty_cache()
|
| 1144 |
+
|
| 1145 |
+
param_dict = dict(model.named_parameters())
|
| 1146 |
+
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
| 1147 |
+
# TODO: Change this lazy import to normal import
|
| 1148 |
+
# after the checks are updated to run on a new version
|
| 1149 |
+
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
| 1150 |
+
|
| 1151 |
+
for quant_param_name in quant_state_dict:
|
| 1152 |
+
if is_pp_missing_parameter(quant_param_name, model):
|
| 1153 |
+
continue
|
| 1154 |
+
|
| 1155 |
+
non_stacked_param_name = quant_param_name
|
| 1156 |
+
|
| 1157 |
+
shard_index = 0
|
| 1158 |
+
for shard_name, (
|
| 1159 |
+
weight_name,
|
| 1160 |
+
index,
|
| 1161 |
+
) in self.modules_mapping.inverse_packed_mapping.items():
|
| 1162 |
+
# Some models, such as MiniCPM V2.5/2.6, contain both
|
| 1163 |
+
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
| 1164 |
+
# from being incorrectly identified as being present in
|
| 1165 |
+
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
| 1166 |
+
shard_pos = quant_param_name.find(shard_name)
|
| 1167 |
+
can_correct_rename = (shard_pos
|
| 1168 |
+
> 0) and (quant_param_name[shard_pos - 1]
|
| 1169 |
+
== ".")
|
| 1170 |
+
# If the quant_param_name is packed, it won't occur in the
|
| 1171 |
+
# param_dict before renaming.
|
| 1172 |
+
new_quant_param_name = quant_param_name.replace(
|
| 1173 |
+
shard_name, weight_name)
|
| 1174 |
+
need_rename = (quant_param_name not in param_dict) \
|
| 1175 |
+
and (new_quant_param_name in param_dict)
|
| 1176 |
+
if can_correct_rename and need_rename:
|
| 1177 |
+
shard_index = index
|
| 1178 |
+
quant_param_name = new_quant_param_name
|
| 1179 |
+
break
|
| 1180 |
+
|
| 1181 |
+
# Models like Clip/Siglip may skip some layers in initialization,
|
| 1182 |
+
# causing unused quant_param_name in state_dict.
|
| 1183 |
+
if quant_param_name not in param_dict:
|
| 1184 |
+
continue
|
| 1185 |
+
|
| 1186 |
+
if quant_param_name not in stacked_quant_state_dict:
|
| 1187 |
+
stacked_quant_state_dict[quant_param_name] = {}
|
| 1188 |
+
|
| 1189 |
+
stacked_quant_state_dict[quant_param_name][shard_index] = (
|
| 1190 |
+
quant_state_dict[non_stacked_param_name])
|
| 1191 |
+
|
| 1192 |
+
# save quant_states and offsets as the attributes of the parameters
|
| 1193 |
+
for param_name, param in param_dict.items():
|
| 1194 |
+
if param_name in stacked_quant_state_dict:
|
| 1195 |
+
quant_states = stacked_quant_state_dict[param_name]
|
| 1196 |
+
set_weight_attrs(param, {"bnb_quant_state": quant_states})
|
| 1197 |
+
|
| 1198 |
+
pack_ratio = getattr(param, "pack_factor", -1)
|
| 1199 |
+
if pack_ratio == -1:
|
| 1200 |
+
raise ValueError(
|
| 1201 |
+
f"pack_factor not set for parameter {param_name}.")
|
| 1202 |
+
|
| 1203 |
+
num_elements = [0] * len(quant_states)
|
| 1204 |
+
for seq, quant_state in quant_states.items():
|
| 1205 |
+
num_elements[seq] = (math.prod(quant_state.shape) //
|
| 1206 |
+
pack_ratio)
|
| 1207 |
+
|
| 1208 |
+
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
| 1209 |
+
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
| 1210 |
+
|
| 1211 |
+
if load_8bit:
|
| 1212 |
+
set_weight_attrs(
|
| 1213 |
+
param, {"matmul_state": [None] * len(quant_states)})
|
| 1214 |
+
|
| 1215 |
+
def download_model(self, model_config: ModelConfig) -> None:
|
| 1216 |
+
self._prepare_weights(model_config.model, model_config.revision)
|
| 1217 |
+
|
| 1218 |
+
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
| 1219 |
+
device_config = vllm_config.device_config
|
| 1220 |
+
model_config = vllm_config.model_config
|
| 1221 |
+
with set_default_torch_dtype(model_config.dtype):
|
| 1222 |
+
with torch.device(device_config.device):
|
| 1223 |
+
model = _initialize_model(vllm_config=vllm_config)
|
| 1224 |
+
|
| 1225 |
+
self._load_weights(model_config, model)
|
| 1226 |
+
|
| 1227 |
+
return model.eval()
|
| 1228 |
+
|
| 1229 |
+
|
| 1230 |
+
class GGUFModelLoader(BaseModelLoader):
|
| 1231 |
+
"""
|
| 1232 |
+
Model loader that can load GGUF files. This is useful for loading models
|
| 1233 |
+
that are quantized with GGUF and saved in the GGUF format. This loader
|
| 1234 |
+
supports loading both full models and sharded models.
|
| 1235 |
+
"""
|
| 1236 |
+
|
| 1237 |
+
def __init__(self, load_config: LoadConfig):
|
| 1238 |
+
super().__init__(load_config)
|
| 1239 |
+
if load_config.model_loader_extra_config:
|
| 1240 |
+
raise ValueError(f"Model loader extra config is not supported for "
|
| 1241 |
+
f"load format {load_config.load_format}")
|
| 1242 |
+
|
| 1243 |
+
def _prepare_weights(self, model_name_or_path: str):
|
| 1244 |
+
if os.path.isfile(model_name_or_path):
|
| 1245 |
+
return model_name_or_path
|
| 1246 |
+
else:
|
| 1247 |
+
raise ValueError(f"{model_name_or_path} is not a file.")
|
| 1248 |
+
|
| 1249 |
+
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
| 1250 |
+
"""
|
| 1251 |
+
GGUF uses this naming convention for their tensors from HF checkpoint:
|
| 1252 |
+
`blk.N.BB.weight` and `blk.N.BB.bias`
|
| 1253 |
+
where N signifies the block number of a layer, and BB signifies the
|
| 1254 |
+
attention/mlp layer components.
|
| 1255 |
+
See "Standardized tensor names" in
|
| 1256 |
+
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
|
| 1257 |
+
"""
|
| 1258 |
+
config = model_config.hf_config
|
| 1259 |
+
model_type = config.model_type
|
| 1260 |
+
# hack: ggufs have a different name than transformers
|
| 1261 |
+
if model_type == "cohere":
|
| 1262 |
+
model_type = "command-r"
|
| 1263 |
+
arch = None
|
| 1264 |
+
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
| 1265 |
+
if value == model_type:
|
| 1266 |
+
arch = key
|
| 1267 |
+
break
|
| 1268 |
+
if arch is None:
|
| 1269 |
+
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
|
| 1270 |
+
num_layers = config.num_hidden_layers
|
| 1271 |
+
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
| 1272 |
+
with torch.device("meta"):
|
| 1273 |
+
dummy_model = AutoModelForCausalLM.from_config(config)
|
| 1274 |
+
state_dict = dummy_model.state_dict()
|
| 1275 |
+
|
| 1276 |
+
gguf_to_hf_name_map = {}
|
| 1277 |
+
for hf_name in state_dict:
|
| 1278 |
+
name, suffix = hf_name.rsplit(".", 1)
|
| 1279 |
+
gguf_name = name_map.get_name(name)
|
| 1280 |
+
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
|
| 1281 |
+
return gguf_to_hf_name_map
|
| 1282 |
+
|
| 1283 |
+
def _get_weights_iterator(
|
| 1284 |
+
self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
|
| 1285 |
+
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
| 1286 |
+
return gguf_quant_weights_iterator(model_name_or_path,
|
| 1287 |
+
gguf_to_hf_name_map)
|
| 1288 |
+
|
| 1289 |
+
def download_model(self, model_config: ModelConfig) -> None:
|
| 1290 |
+
self._prepare_weights(model_config.model)
|
| 1291 |
+
|
| 1292 |
+
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
| 1293 |
+
device_config = vllm_config.device_config
|
| 1294 |
+
model_config = vllm_config.model_config
|
| 1295 |
+
local_model_path = self._prepare_weights(model_config.model)
|
| 1296 |
+
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
| 1297 |
+
# we can only know if tie word embeddings after mapping weights
|
| 1298 |
+
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
| 1299 |
+
local_model_path, gguf_weights_map):
|
| 1300 |
+
model_config.hf_config.update({"tie_word_embeddings": True})
|
| 1301 |
+
|
| 1302 |
+
with set_default_torch_dtype(model_config.dtype):
|
| 1303 |
+
with torch.device(device_config.device):
|
| 1304 |
+
model = _initialize_model(vllm_config=vllm_config)
|
| 1305 |
+
model.load_weights(
|
| 1306 |
+
self._get_weights_iterator(local_model_path, gguf_weights_map))
|
| 1307 |
+
return model
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
class RunaiModelStreamerLoader(BaseModelLoader):
|
| 1311 |
+
"""
|
| 1312 |
+
Model loader that can load safetensors
|
| 1313 |
+
files from local FS or S3 bucket.
|
| 1314 |
+
"""
|
| 1315 |
+
|
| 1316 |
+
def __init__(self, load_config: LoadConfig):
|
| 1317 |
+
super().__init__(load_config)
|
| 1318 |
+
if load_config.model_loader_extra_config:
|
| 1319 |
+
extra_config = load_config.model_loader_extra_config
|
| 1320 |
+
|
| 1321 |
+
if ("concurrency" in extra_config
|
| 1322 |
+
and isinstance(extra_config.get("concurrency"), int)):
|
| 1323 |
+
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
|
| 1324 |
+
extra_config.get("concurrency"))
|
| 1325 |
+
|
| 1326 |
+
if ("memory_limit" in extra_config
|
| 1327 |
+
and isinstance(extra_config.get("memory_limit"), int)):
|
| 1328 |
+
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
|
| 1329 |
+
extra_config.get("memory_limit"))
|
| 1330 |
+
|
| 1331 |
+
runai_streamer_s3_endpoint = os.getenv(
|
| 1332 |
+
'RUNAI_STREAMER_S3_ENDPOINT')
|
| 1333 |
+
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
|
| 1334 |
+
if (runai_streamer_s3_endpoint is None
|
| 1335 |
+
and aws_endpoint_url is not None):
|
| 1336 |
+
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
| 1337 |
+
|
| 1338 |
+
def _prepare_weights(self, model_name_or_path: str,
|
| 1339 |
+
revision: Optional[str]) -> List[str]:
|
| 1340 |
+
"""Prepare weights for the model.
|
| 1341 |
+
|
| 1342 |
+
If the model is not local, it will be downloaded."""
|
| 1343 |
+
is_s3_path = is_s3(model_name_or_path)
|
| 1344 |
+
is_local = os.path.isdir(model_name_or_path)
|
| 1345 |
+
safetensors_pattern = "*.safetensors"
|
| 1346 |
+
index_file = SAFE_WEIGHTS_INDEX_NAME
|
| 1347 |
+
|
| 1348 |
+
hf_folder = (model_name_or_path if
|
| 1349 |
+
(is_local or is_s3_path) else download_weights_from_hf(
|
| 1350 |
+
model_name_or_path,
|
| 1351 |
+
self.load_config.download_dir,
|
| 1352 |
+
[safetensors_pattern],
|
| 1353 |
+
revision,
|
| 1354 |
+
ignore_patterns=self.load_config.ignore_patterns,
|
| 1355 |
+
))
|
| 1356 |
+
|
| 1357 |
+
if is_s3_path:
|
| 1358 |
+
hf_weights_files = s3_glob(path=hf_folder,
|
| 1359 |
+
allow_pattern=[safetensors_pattern])
|
| 1360 |
+
else:
|
| 1361 |
+
hf_weights_files = glob.glob(
|
| 1362 |
+
os.path.join(hf_folder, safetensors_pattern))
|
| 1363 |
+
|
| 1364 |
+
if not is_local and not is_s3_path:
|
| 1365 |
+
download_safetensors_index_file_from_hf(
|
| 1366 |
+
model_name_or_path, index_file, self.load_config.download_dir,
|
| 1367 |
+
revision)
|
| 1368 |
+
|
| 1369 |
+
if not hf_weights_files:
|
| 1370 |
+
raise RuntimeError(
|
| 1371 |
+
f"Cannot find any safetensors model weights with "
|
| 1372 |
+
f"`{model_name_or_path}`")
|
| 1373 |
+
|
| 1374 |
+
return hf_weights_files
|
| 1375 |
+
|
| 1376 |
+
def _get_weights_iterator(
|
| 1377 |
+
self, model_or_path: str,
|
| 1378 |
+
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
| 1379 |
+
"""Get an iterator for the model weights based on the load format."""
|
| 1380 |
+
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
| 1381 |
+
return runai_safetensors_weights_iterator(hf_weights_files)
|
| 1382 |
+
|
| 1383 |
+
def download_model(self, model_config: ModelConfig) -> None:
|
| 1384 |
+
"""Download model if necessary"""
|
| 1385 |
+
self._prepare_weights(model_config.model, model_config.revision)
|
| 1386 |
+
|
| 1387 |
+
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
| 1388 |
+
"""Perform streaming of the model to destination"""
|
| 1389 |
+
device_config = vllm_config.device_config
|
| 1390 |
+
model_config = vllm_config.model_config
|
| 1391 |
+
|
| 1392 |
+
target_device = torch.device(device_config.device)
|
| 1393 |
+
with set_default_torch_dtype(model_config.dtype):
|
| 1394 |
+
with target_device:
|
| 1395 |
+
model = _initialize_model(vllm_config=vllm_config)
|
| 1396 |
+
|
| 1397 |
+
model_weights = model_config.model
|
| 1398 |
+
if hasattr(model_config, "model_weights"):
|
| 1399 |
+
model_weights = model_config.model_weights
|
| 1400 |
+
model.load_weights(
|
| 1401 |
+
self._get_weights_iterator(model_weights,
|
| 1402 |
+
model_config.revision))
|
| 1403 |
+
|
| 1404 |
+
for _, module in model.named_modules():
|
| 1405 |
+
quant_method = getattr(module, "quant_method", None)
|
| 1406 |
+
if quant_method is not None:
|
| 1407 |
+
with device_loading_context(module, target_device):
|
| 1408 |
+
quant_method.process_weights_after_loading(module)
|
| 1409 |
+
if isinstance(module, Attention) and \
|
| 1410 |
+
hasattr(module, "process_weights_after_loading"):
|
| 1411 |
+
# When attention modules need to process weights after
|
| 1412 |
+
# currently only used by MLA
|
| 1413 |
+
module.process_weights_after_loading(model_config.dtype)
|
| 1414 |
+
return model.eval()
|
| 1415 |
+
|
| 1416 |
+
|
| 1417 |
+
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
| 1418 |
+
"""Get a model loader based on the load format."""
|
| 1419 |
+
|
| 1420 |
+
if isinstance(load_config.load_format, type):
|
| 1421 |
+
return load_config.load_format(load_config)
|
| 1422 |
+
|
| 1423 |
+
if load_config.load_format == LoadFormat.DUMMY:
|
| 1424 |
+
return DummyModelLoader(load_config)
|
| 1425 |
+
|
| 1426 |
+
if load_config.load_format == LoadFormat.TENSORIZER:
|
| 1427 |
+
return TensorizerLoader(load_config)
|
| 1428 |
+
|
| 1429 |
+
if load_config.load_format == LoadFormat.SHARDED_STATE:
|
| 1430 |
+
return ShardedStateLoader(load_config)
|
| 1431 |
+
|
| 1432 |
+
if load_config.load_format == LoadFormat.BITSANDBYTES:
|
| 1433 |
+
return BitsAndBytesModelLoader(load_config)
|
| 1434 |
+
|
| 1435 |
+
if load_config.load_format == LoadFormat.GGUF:
|
| 1436 |
+
return GGUFModelLoader(load_config)
|
| 1437 |
+
|
| 1438 |
+
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
| 1439 |
+
return RunaiModelStreamerLoader(load_config)
|
| 1440 |
+
|
| 1441 |
+
return DefaultModelLoader(load_config)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/neuron.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Utilities for selecting and loading neuron models."""
|
| 3 |
+
import copy
|
| 4 |
+
import importlib
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from transformers import PretrainedConfig
|
| 11 |
+
|
| 12 |
+
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
| 13 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 14 |
+
from vllm.model_executor.layers.quantization import get_quantization_config
|
| 15 |
+
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
| 16 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 17 |
+
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
| 18 |
+
SequenceOutput)
|
| 19 |
+
|
| 20 |
+
TORCH_DTYPE_TO_NEURON_AMP = {
|
| 21 |
+
"auto": "f32",
|
| 22 |
+
"half": "f16",
|
| 23 |
+
"float16": "f16",
|
| 24 |
+
"bfloat16": "bf16",
|
| 25 |
+
"float": "f32",
|
| 26 |
+
"float32": "f32",
|
| 27 |
+
torch.float16: "f16",
|
| 28 |
+
torch.bfloat16: "bf16",
|
| 29 |
+
torch.float32: "f32",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
# Models supported by Neuron.
|
| 33 |
+
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
|
| 34 |
+
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
| 35 |
+
"LlamaForSampling", "LlamaForCausalLM"),
|
| 36 |
+
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
| 37 |
+
"MistralForSampling", "MistralForCausalLM")
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class NeuronCausalLM(nn.Module):
|
| 42 |
+
|
| 43 |
+
def __init__(self,
|
| 44 |
+
config: PretrainedConfig,
|
| 45 |
+
on_device_sampling_disabled: bool = False) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.config = config
|
| 48 |
+
self.logits_processor = LogitsProcessor(config.vocab_size,
|
| 49 |
+
logits_as_input=True)
|
| 50 |
+
|
| 51 |
+
self.on_device_sampling_disabled = on_device_sampling_disabled
|
| 52 |
+
if self.on_device_sampling_disabled:
|
| 53 |
+
# Use default sampler
|
| 54 |
+
self.sampler = Sampler()
|
| 55 |
+
|
| 56 |
+
# Lazy initialized
|
| 57 |
+
self.model: nn.Module
|
| 58 |
+
|
| 59 |
+
def forward(
|
| 60 |
+
self,
|
| 61 |
+
input_ids: torch.Tensor,
|
| 62 |
+
positions: torch.Tensor,
|
| 63 |
+
input_block_ids: torch.Tensor,
|
| 64 |
+
) -> torch.Tensor:
|
| 65 |
+
logits = self.model(input_ids,
|
| 66 |
+
cache_ids=positions,
|
| 67 |
+
start_ids=input_block_ids)
|
| 68 |
+
return logits
|
| 69 |
+
|
| 70 |
+
def compute_logits(self, hidden_states: torch.Tensor,
|
| 71 |
+
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
| 72 |
+
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
| 73 |
+
return logits
|
| 74 |
+
|
| 75 |
+
def sample(
|
| 76 |
+
self,
|
| 77 |
+
logits: torch.Tensor,
|
| 78 |
+
sampling_metadata: SamplingMetadata,
|
| 79 |
+
) -> Optional[SamplerOutput]:
|
| 80 |
+
|
| 81 |
+
if self.on_device_sampling_disabled:
|
| 82 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 83 |
+
return next_tokens
|
| 84 |
+
|
| 85 |
+
# On-device sampling outputs the token ids directly.
|
| 86 |
+
sampled_token_ids = logits.flatten()
|
| 87 |
+
next_tokens = []
|
| 88 |
+
sample_idx = 0
|
| 89 |
+
for seq_group in sampling_metadata.seq_groups:
|
| 90 |
+
samples = []
|
| 91 |
+
for seq_id in seq_group.seq_ids:
|
| 92 |
+
token_id = sampled_token_ids[sample_idx].item()
|
| 93 |
+
samples.append(
|
| 94 |
+
SequenceOutput(parent_seq_id=seq_id,
|
| 95 |
+
output_token=token_id,
|
| 96 |
+
logprobs={token_id: Logprob(token_id)}))
|
| 97 |
+
sample_idx += 1
|
| 98 |
+
next_tokens.append(
|
| 99 |
+
CompletionSequenceGroupOutput(samples=samples,
|
| 100 |
+
prompt_logprobs=None))
|
| 101 |
+
|
| 102 |
+
return SamplerOutput(outputs=next_tokens)
|
| 103 |
+
|
| 104 |
+
def load_weights(self, model_name_or_path: str, **kwargs):
|
| 105 |
+
arch = _get_model_architecture(self.config)
|
| 106 |
+
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
|
| 107 |
+
_NEURON_SUPPORTED_MODELS[arch])
|
| 108 |
+
neuronx_module = importlib.import_module(neuronx_module_path)
|
| 109 |
+
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
| 110 |
+
|
| 111 |
+
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
|
| 112 |
+
**kwargs)
|
| 113 |
+
self.model.to_neuron()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_model_architecture(config: PretrainedConfig) -> str:
|
| 117 |
+
architectures = getattr(config, "architectures", [])
|
| 118 |
+
for arch in architectures:
|
| 119 |
+
if arch in _NEURON_SUPPORTED_MODELS:
|
| 120 |
+
return arch
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"Model architectures {architectures} are not supported on Neuron "
|
| 123 |
+
f"for now. Supported architectures: "
|
| 124 |
+
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
|
| 128 |
+
env_value = os.getenv(env)
|
| 129 |
+
if env_value is None:
|
| 130 |
+
return default_value
|
| 131 |
+
buckets_remove_empty = filter(
|
| 132 |
+
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
|
| 133 |
+
buckets_int = map(int, buckets_remove_empty)
|
| 134 |
+
buckets_list = list(buckets_int)
|
| 135 |
+
return buckets_list
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _get_default_neuron_config(model_config: ModelConfig,
|
| 139 |
+
parallel_config: ParallelConfig,
|
| 140 |
+
scheduler_config: SchedulerConfig):
|
| 141 |
+
from transformers_neuronx.config import ContinuousBatchingConfig
|
| 142 |
+
from transformers_neuronx.constants import LAYOUT_BSH
|
| 143 |
+
|
| 144 |
+
continuous_batching_config = ContinuousBatchingConfig(
|
| 145 |
+
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
| 146 |
+
quant_config = dict(
|
| 147 |
+
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
| 148 |
+
quantize_method="vector_dynamic")
|
| 149 |
+
neuron_quantization_config_builder = lambda quant: get_quantization_config(
|
| 150 |
+
quant).from_config(quant_config).get_quant_method(None, "")
|
| 151 |
+
# TODO: Add Paged attention config to the default neuron arguments.
|
| 152 |
+
default_neuron_args = dict(
|
| 153 |
+
collectives_layout=LAYOUT_BSH,
|
| 154 |
+
attention_layout=LAYOUT_BSH,
|
| 155 |
+
fuse_qkv=True,
|
| 156 |
+
quant=neuron_quantization_config_builder(model_config.quantization)
|
| 157 |
+
if model_config.quantization else None,
|
| 158 |
+
continuous_batching=continuous_batching_config,
|
| 159 |
+
weight_tiling=bool(model_config.quantization),
|
| 160 |
+
on_device_generation=_get_neuron_on_device_generation_config(
|
| 161 |
+
model_config))
|
| 162 |
+
return default_neuron_args
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
|
| 166 |
+
if not _is_neuron_on_device_sampling_disabled(model_config):
|
| 167 |
+
return copy.deepcopy(model_config.neuron_sampling_params)
|
| 168 |
+
return None
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
|
| 172 |
+
return not getattr(model_config, "neuron_sampling_params", None)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _get_neuron_config_after_override(default_neuron_config,
|
| 176 |
+
overridden_neuron_config):
|
| 177 |
+
from transformers_neuronx.config import NeuronConfig
|
| 178 |
+
overridden_neuron_config = overridden_neuron_config or {}
|
| 179 |
+
default_neuron_config.update(overridden_neuron_config)
|
| 180 |
+
return NeuronConfig(**default_neuron_config)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_neuron_model(model_config: ModelConfig,
|
| 184 |
+
parallel_config: ParallelConfig,
|
| 185 |
+
scheduler_config: SchedulerConfig) -> nn.Module:
|
| 186 |
+
|
| 187 |
+
# Create a model instance.
|
| 188 |
+
model = NeuronCausalLM(
|
| 189 |
+
model_config.hf_config,
|
| 190 |
+
_is_neuron_on_device_sampling_disabled(model_config))
|
| 191 |
+
|
| 192 |
+
default_neuron_config_args = _get_default_neuron_config(
|
| 193 |
+
model_config, parallel_config, scheduler_config)
|
| 194 |
+
|
| 195 |
+
neuron_config = _get_neuron_config_after_override(
|
| 196 |
+
default_neuron_config_args, model_config.override_neuron_config)
|
| 197 |
+
|
| 198 |
+
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
| 199 |
+
[scheduler_config.max_model_len])
|
| 200 |
+
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
| 201 |
+
[scheduler_config.max_model_len])
|
| 202 |
+
|
| 203 |
+
# Load the weights from the cached or downloaded files.
|
| 204 |
+
model.load_weights(model_config.model,
|
| 205 |
+
tp_degree=parallel_config.tensor_parallel_size,
|
| 206 |
+
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
| 207 |
+
neuron_config=neuron_config,
|
| 208 |
+
context_length_estimate=context_length_estimates,
|
| 209 |
+
n_positions=n_positions,
|
| 210 |
+
batch_size=scheduler_config.max_num_seqs)
|
| 211 |
+
|
| 212 |
+
return model.eval()
|
.venv/lib/python3.11/site-packages/vllm/model_executor/model_loader/utils.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Utilities for selecting and loading models."""
|
| 3 |
+
import contextlib
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import transformers
|
| 9 |
+
from torch import nn
|
| 10 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 11 |
+
|
| 12 |
+
from vllm.config import ModelConfig, ModelImpl
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 15 |
+
QuantizationConfig)
|
| 16 |
+
from vllm.model_executor.models import ModelRegistry
|
| 17 |
+
from vllm.model_executor.models.adapters import (as_classification_model,
|
| 18 |
+
as_embedding_model,
|
| 19 |
+
as_reward_model)
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@contextlib.contextmanager
|
| 25 |
+
def set_default_torch_dtype(dtype: torch.dtype):
|
| 26 |
+
"""Sets the default torch dtype to the given dtype."""
|
| 27 |
+
old_dtype = torch.get_default_dtype()
|
| 28 |
+
torch.set_default_dtype(dtype)
|
| 29 |
+
yield
|
| 30 |
+
torch.set_default_dtype(old_dtype)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_transformers_impl_compatible(
|
| 34 |
+
arch: str,
|
| 35 |
+
module: Optional[transformers.PreTrainedModel] = None) -> bool:
|
| 36 |
+
mod = module or getattr(transformers, arch, None)
|
| 37 |
+
if mod is None:
|
| 38 |
+
return False
|
| 39 |
+
if hasattr(mod, "supports_backend"):
|
| 40 |
+
return mod.is_backend_compatible()
|
| 41 |
+
else:
|
| 42 |
+
return mod._supports_flex_attn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    """Rewrite architectures to "TransformersModel" where a fallback applies.

    Mutates ``architectures`` in place and returns it. Raises ValueError when
    the Transformers fallback is requested (TRANSFORMERS) or required (AUTO)
    but the Transformers implementation is not compatible with vLLM.
    """
    for idx, arch in enumerate(architectures):
        # Already the generic fallback; nothing to resolve.
        if arch == "TransformersModel":
            continue
        # Prefer a trust-remote-code AutoModel class when the HF config
        # declares one in its auto_map.
        auto_map = getattr(model_config.hf_config, "auto_map", None)
        if auto_map is not None and "AutoModel" in auto_map:
            custom_module = get_class_from_dynamic_module(
                model_config.hf_config.auto_map["AutoModel"],
                model_config.model)
        else:
            custom_module = None
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        impl = model_config.model_impl
        if impl == ModelImpl.TRANSFORMERS:
            # Fallback explicitly requested: incompatibility is fatal.
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[idx] = "TransformersModel"
        elif impl == ModelImpl.AUTO:
            # Automatic fallback: incompatibility is fatal, otherwise warn.
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[idx] = "TransformersModel"
    return architectures
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    Returns a ``(model_cls, arch_name)`` tuple, after applying the quantized
    Mixtral special case, the Transformers fallback, and task adapters.
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]
    quant_unsupported = (model_config.quantization is not None
                         and model_config.quantization not in mixtral_supported)
    if quant_unsupported and "MixtralForCausalLM" in architectures:
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    is_vllm_supported = any(arch in vllm_supported_archs
                            for arch in architectures)
    needs_fallback = (not is_vllm_supported
                      or model_config.model_impl == ModelImpl.TRANSFORMERS)
    if needs_fallback:
        architectures = resolve_transformers_fallback(model_config,
                                                      architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    # Wrap the base model class for pooling-style tasks.
    task_adapters = {
        "embed": as_embedding_model,
        "classify": as_classification_model,
        "reward": as_reward_model,
    }
    adapter = task_adapters.get(model_config.task)
    if adapter is not None:
        model_cls = adapter(model_cls)

    return model_cls, arch
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the resolved architecture name for ``model_config``."""
    _, arch_name = get_model_architecture(model_config)
    return arch_name
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Maps a packed parameter name to the names of its constituent parts.
    packed_mapping: Dict[str, List[str]]
    # Reverse map: part name -> (packed name, index within the packed param).
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        for packed_name, parts in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(parts) == 1 and parts[0] == packed_name:
                continue
            self.inverse_packed_mapping.update(
                {part: (packed_name, idx)
                 for idx, part in enumerate(parts)})

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Return the (packed name, parts) whose key suffix-matches
        ``module_name``, or None when no packed entry matches."""
        return next(((key, parts)
                     for key, parts in self.packed_mapping.items()
                     if module_name.endswith(key)), None)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
    """
    mapping = getattr(model_class, "packed_modules_mapping", None)
    if mapping is None:
        # Without the mapping, fused/ignored module matching may be wrong.
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return
    # pass packed_modules_mapping by reference to quant_config
    quant_config.packed_modules_mapping = mapping
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/arctic.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Inference-only Snowflake Arctic model."""
|
| 3 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 9 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 10 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 11 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 12 |
+
get_tensor_model_parallel_world_size,
|
| 13 |
+
tensor_model_parallel_all_reduce)
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 16 |
+
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
| 17 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 18 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 19 |
+
QKVParallelLinear,
|
| 20 |
+
ReplicatedLinear,
|
| 21 |
+
RowParallelLinear)
|
| 22 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 23 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 24 |
+
from vllm.model_executor.layers.quantization.deepspeedfp import (
|
| 25 |
+
DeepSpeedFPConfig, DeepSpeedFPParameter)
|
| 26 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 27 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 28 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 29 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 30 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 31 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 32 |
+
from vllm.model_executor.utils import set_weight_attrs
|
| 33 |
+
from vllm.sequence import IntermediateTensors
|
| 34 |
+
from vllm.transformers_utils.configs.arctic import ArcticConfig
|
| 35 |
+
|
| 36 |
+
from .interfaces import SupportsPP
|
| 37 |
+
from .utils import (extract_layer_index, is_pp_missing_parameter,
|
| 38 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 39 |
+
maybe_prefix)
|
| 40 |
+
|
| 41 |
+
logger = init_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ArcticMLP(nn.Module):
    """Gated (SiLU) feed-forward block used by Arctic, both as a dense MLP
    and as a per-expert MLP (``expert_id`` identifies the expert, -1 for the
    dense case). Only the "silu" activation is supported."""

    def __init__(self,
                 config: ArcticConfig,
                 expert_id: int = -1,
                 is_residual_mlp: bool = False,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        # The residual MLP keeps hidden_size width; the regular MLP uses
        # the configured intermediate size.
        if is_residual_mlp:
            self.ffn_dim = self.hidden_size
        else:
            self.ffn_dim = config.intermediate_size

        # Fused gate+up projection (two ffn_dim outputs, column-parallel).
        self.w13 = MergedColumnParallelLinear(self.hidden_size,
                                              [self.ffn_dim] * 2,
                                              bias=False,
                                              quant_config=quant_config)
        # Down projection (row-parallel); reduce_results controls whether the
        # TP all-reduce happens here or is deferred to the caller.
        self.w2 = RowParallelLinear(self.ffn_dim,
                                    self.hidden_size,
                                    bias=False,
                                    reduce_results=reduce_results,
                                    quant_config=quant_config)
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states):
        projected, _ = self.w13(hidden_states)
        activated = self.act_fn(projected)
        down, _ = self.w2(activated)
        return down
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    Depending on the layer index, this is either a dense MLP (non-MoE
    layers) or a top-k routed mixture of experts. With a DeepSpeedFP
    quant config the expert weights are stored quantized and dequantized
    on the fly during forward.
    """

    def __init__(self,
                 config: ArcticConfig,
                 tp_size: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()

        # Layer index is recovered from the module prefix string.
        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # Expert intermediate size is sharded across TP ranks.
        self.intermediate_size = config.intermediate_size // self.tp_size

        # Every moe_layer_frequency-th layer is a MoE layer.
        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        # Some other parameters
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            # Dense layer: a single shared MLP, no routing.
            self.mlp = ArcticMLP(config,
                                 quant_config=quant_config,
                                 reduce_results=reduce_results,
                                 prefix=f"{prefix}.mlp")
        else:
            # Router producing per-token expert logits (replicated on all
            # TP ranks).
            self.gate = ReplicatedLinear(self.hidden_size,
                                         self.num_experts,
                                         bias=False,
                                         params_dtype=self.params_dtype,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.gate")
            if self.is_quant:
                # Quantized expert weights: ws holds the fused w1/w3
                # (gate+up), w2s holds the down projection.
                self.ws = DeepSpeedFPParameter(
                    torch.Size((self.num_experts, 2 * self.intermediate_size,
                                self.hidden_size)),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
                self.w2s = DeepSpeedFPParameter(
                    torch.Size((self.num_experts, self.hidden_size,
                                self.intermediate_size)),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
            else:
                # NOTE(review): device is hard-coded to "cuda" here —
                # presumably fine for the supported deployment, but it
                # bypasses the usual device-placement machinery.
                self.ws = nn.Parameter(
                    torch.empty(self.num_experts,
                                2 * self.intermediate_size,
                                self.hidden_size,
                                device="cuda",
                                dtype=self.params_dtype))
                self.w2s = nn.Parameter(
                    torch.empty(self.num_experts,
                                self.hidden_size,
                                self.intermediate_size,
                                device="cuda",
                                dtype=self.params_dtype))
            # Route checkpoint shards through our custom loader below.
            set_weight_attrs(self.ws, {
                "weight_loader": self.weight_loader,
            })
            set_weight_attrs(self.w2s, {
                "weight_loader": self.weight_loader,
            })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy one expert shard (w1/w2/w3) into the fused parameter,
        slicing out this TP rank's portion of the intermediate dim."""
        tp_rank = get_tensor_model_parallel_rank()
        # For quantized params, work on a dequantized copy and re-quantize
        # after the slice assignment.
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 (gate) occupies the first half of the fused ws rows.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            # w3 (up) occupies the second half of the fused ws rows.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            # w2 (down) is sharded along its input (column) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to top-k experts and run the fused expert kernel."""
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Renormalizing top-k weights only matters when k > 1.
        do_normalize = self.top_k > 1
        topk_weights, topk_ids = fused_topk(hidden_states,
                                            router_logits,
                                            self.top_k,
                                            renormalize=do_normalize)
        # topk_ids: (num_tokens, k)
        if self.is_quant:
            if 2 * num_tokens <= self.num_experts:
                # If much fewer tokens than experts, use selective dequantize.
                ws_dequantized = self.ws.ds_selective_dequantize(
                    topk_ids.flatten())
                w2s_dequantized = self.w2s.ds_selective_dequantize(
                    topk_ids.flatten())
                # We gathered the experts to the tokens so update the mapping.
                topk_ids = torch.arange(
                    0,
                    topk_ids.numel(),
                    device=topk_ids.device,
                ).reshape(topk_ids.shape)
            else:
                # Otherwise dequantize everything once.
                ws_dequantized = self.ws.ds_dequantize()
                w2s_dequantized = self.w2s.ds_dequantize()

        final_hidden_states = fused_experts(
            hidden_states,
            ws_dequantized if self.is_quant else self.ws,
            w2s_dequantized if self.is_quant else self.w2s,
            topk_weights,
            topk_ids,
            inplace=True)
        if self.reduce_results and self.tp_size > 1:
            # Combine partial expert outputs across TP ranks.
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # Dispatch: routed experts on MoE layers, shared MLP otherwise.
        if self.is_moe_layer:
            final_hidden_states = self.local_moe_fused(hidden_states)
        else:
            final_hidden_states = self.mlp(hidden_states)
        return final_hidden_states
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ArcticAttention(nn.Module):
    """Multi-head attention for Arctic with rotary embeddings and
    tensor-parallel Q/K/V and output projections."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads must divide evenly across TP ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Partition KV heads across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: each KV head is replicated on a
            # group of ranks, so ranks must divide evenly by KV heads.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank flattened Q and KV widths used to split the fused QKV.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim**-0.5

        # Fused, column-parallel QKV projection.
        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=False,
                                          quant_config=quant_config)
        # Row-parallel output projection with the TP all-reduce included.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Project, split fused QKV, apply RoPE, attend, then project back.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class ArcticDecoderLayer(nn.Module):
    """One Arctic transformer layer: self-attention followed by either a
    plain MoE/MLP block or, on residual layers, a parallel residual-MLP +
    MoE combination with a manually placed all-reduce."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        # The residual-MLP path only exists on MoE layers when enabled.
        self.use_residual = config.use_residual and is_moe_layer
        self.self_attn = ArcticAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        # When use_residual, the MoE defers its TP all-reduce so that the
        # residual-MLP output can be summed in first (see forward()).
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        if self.use_residual:
            self.residual_layernorm = RMSNorm(config.hidden_size,
                                              eps=config.rms_norm_eps)
            # reduce_results=False: the all-reduce is done once in forward()
            # after summing with the MoE output.
            self.residual_mlp = ArcticMLP(config,
                                          is_residual_mlp=True,
                                          reduce_results=False,
                                          prefix=f"{prefix}.residual_mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # --- attention sub-block with residual connection ---
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            # Parallel paths: residual MLP on the post-attention state, and
            # MoE on the *pre-attention* input (post_attention_layernorm is
            # applied to residual_input here by design). Their partial sums
            # are combined, all-reduced once across TP ranks, then added to
            # the attention residual.
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            # Standard pre-norm MoE/MLP sub-block with residual connection.
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic transformer backbone: token embedding, decoder layers, and
    final RMSNorm, with pipeline-parallel layer partitioning."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size)
        # make_layers assigns only this PP rank's slice of decoder layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First PP rank embeds tokens (or uses precomputed embeddings).
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later PP ranks receive hidden states from the previous rank.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            # Hand off to the next pipeline stage.
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class ArcticForCausalLM(nn.Module, SupportsPP):
    """Arctic causal language model: backbone + LM head, logits processing,
    sampling, and a custom checkpoint-weight loader for the fused/packed
    parameters."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.unpadded_vocab_size = config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the computed logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping per-shard names onto the fused
        qkv_proj / w13 / expert parameters. Returns the set of parameter
        names that were loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        mlp_params_mapping: List[Tuple[str, str, int]] = []
        expert_params_mapping: List[Tuple[str, str, int]] = []
        num_layers = self.config.num_hidden_layers

        # Build the per-layer name remaps: residual MLP w1/w3 fuse into w13;
        # even layers also fuse the dense block_sparse_moe.mlp, odd layers
        # map per-expert w1/w2/w3 shards onto the fused ws/w2s tensors.
        # NOTE(review): the even/odd split presumably mirrors
        # moe_layer_frequency == 2 — confirm against the Arctic config.
        for layer in range(num_layers):
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w1.weight", 0))
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w3.weight", 1))
            if layer % 2 == 0:
                # MLP layers
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0))
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1))
            else:
                # MoE layers
                for expert_id in range(self.config.num_local_experts):
                    expert_params_mapping.append(
                        ("ws", f"experts.{expert_id}.w1.weight", expert_id))
                    expert_params_mapping.append(
                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id))
                    expert_params_mapping.append(
                        ("ws", f"experts.{expert_id}.w3.weight", expert_id))

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading.")
        # Each weight is tried against the mapping tables in order; the
        # for/else cascade falls through to the next table only when no
        # entry matched (the `break` skips the `else`).
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip params that live on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id \
                            in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # Expert loader needs the original shard name to pick
                        # the w1/w2/w3 slice (see ArcticMoE.weight_loader).
                        weight_loader(param,
                                      loaded_weight,
                                      weight_name,
                                      expert_id=shard_id)
                        break
                    else:
                        # Plain (unfused) parameter.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]

                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bart.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Derived from BART implementation posted on HuggingFace; license below:
|
| 4 |
+
#
|
| 5 |
+
# coding=utf-8
|
| 6 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
|
| 7 |
+
# All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""PyTorch BART model."""
|
| 21 |
+
import math
|
| 22 |
+
from typing import Iterable, List, Optional, Tuple
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
from transformers import BartConfig
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
| 30 |
+
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
| 31 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 32 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 33 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 34 |
+
QKVParallelLinear,
|
| 35 |
+
RowParallelLinear)
|
| 36 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 37 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 38 |
+
QuantizationConfig)
|
| 39 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 40 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 41 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 42 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 43 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 44 |
+
from vllm.sequence import IntermediateTensors
|
| 45 |
+
|
| 46 |
+
from .utils import maybe_prefix
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_bsz_seq_len(input_ids):
    """Infer (batch_size, seq_len) from an input-ids tensor.

    A 1-D tensor is treated as one flattened batch, yielding
    ``(1, total_num_tokens)``; for higher-rank tensors the first two
    dimensions are returned unchanged.
    """
    if input_ids.dim() == 1:
        return 1, input_ids.numel()
    return input_ids.shape[:2]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """Learned positional embeddings with BART's legacy offset of 2.

    BART reserves the first two rows of the position table (a historical
    artifact of its padding-idx handling), so position ids are shifted by
    ``offset`` before lookup and the table is allocated two rows larger.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Keep the historical +2 shift so pretrained BART checkpoints
        # line up with this table; other models do not need this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Look up embeddings for the offset-shifted position ids."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class BartScaledWordEmbedding(VocabParallelEmbedding):
    """Token embedding whose output is multiplied by a constant scale.

    BART scales word embeddings by ``sqrt(d_model)`` when
    ``config.scale_embedding`` is set; baking the factor in here lets
    callers use the module as a drop-in embedding.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Multiplicative factor applied to every embedding lookup.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the scale factor."""
        embeddings = super().forward(input_ids)
        return embeddings * self.embed_scale
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class BartParallelLMHead(ParallelLMHead):
    """LM head that divides its output by the embedding scale factor.

    Effectively the inverse of :class:`BartScaledWordEmbedding`, so tied
    input/output embeddings remain consistent when scaling is enabled.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Divisor applied to every projection, mirroring the input-side
        # multiplication in BartScaledWordEmbedding.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Project and un-scale by ``embed_scale``."""
        logits = super().forward(input_ids)
        return logits / self.embed_scale
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class BartEncoderAttention(nn.Module):
    """Multi-head self-attention for BART encoder layers.

    QKV uses a column-parallel projection and the output a row-parallel
    projection, so attention heads are sharded across tensor-parallel
    ranks. The inner Attention op runs with attn_type=ENCODER.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # NOTE(review): `config` defaults to None but `config.d_model` is
        # read unconditionally — callers must always pass a BartConfig.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # BART uses no grouped-query attention: KV heads == query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        # Fused Q/K/V projection, column-parallel across TP ranks.
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        # Row-parallel output projection (all-reduces across TP ranks).
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank slice widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        # Fused projection, then split into per-rank Q/K/V slices.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class BartDecoderSelfAttention(nn.Module):
    """Multi-head causal self-attention for BART decoder layers.

    Identical parallel layout to BartEncoderAttention (column-parallel
    fused QKV, row-parallel output), but the inner Attention op runs with
    attn_type=DECODER so it participates in the decoder KV cache.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # NOTE(review): `config` defaults to None but `config.d_model` is
        # read unconditionally — callers must always pass a BartConfig.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # No grouped-query attention in BART: KV heads == query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        # Fused Q/K/V projection, column-parallel across TP ranks.
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        # Row-parallel output projection (all-reduces across TP ranks).
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank slice widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.DECODER)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        # Fused projection, then split into per-rank Q/K/V slices.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class BartCrossAttention(nn.Module):
    """Decoder-to-encoder cross-attention for BART decoder layers.

    Queries come from decoder hidden states; keys/values come from
    encoder hidden states (when provided). The inner Attention op runs
    with attn_type=ENCODER_DECODER, which caches encoder K/V so that
    subsequent decode steps can pass encoder_hidden_states=None.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # NOTE(review): `config` defaults to None but `config.d_model` is
        # read unconditionally — callers must always pass a BartConfig.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # No grouped-query attention in BART: KV heads == query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        # One fused Q/K/V projection is shared for both the decoder-side
        # query pass and the encoder-side key/value pass (see forward).
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        # Row-parallel output projection (all-reduces across TP ranks).
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank slice widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_DECODER)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        # (afeldman-nm 2024/07/22) TODO:
        # Need a more efficient solution for q/k/v
        # Full QKV is computed on decoder states but only Q is kept; the
        # K/V slices of this pass are discarded.
        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        if encoder_hidden_states is None:
            # Decode steps: encoder K/V are read from the cross-attn cache.
            k = None
            v = None
        else:
            # Prefill: compute encoder-side K/V (Q slice discarded).
            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                                    dim=-1)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class BartEncoderLayer(nn.Module):
    """One BART encoder block.

    Self-attention followed by a two-layer feed-forward network, each
    wrapped in a residual connection and *post*-LayerNorm (BART is a
    post-norm architecture).
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # FFN activation comes from the config (e.g. "gelu" for BART).
        # A redundant, unused `self.act = get_act_fn("gelu")` attribute
        # (forward() only ever uses `self.activation_fn`) was removed.
        self.activation_fn = get_act_fn(config.activation_function)

        ffn_hidden_size = self.embed_dim
        ffn_intermediate_size = config.encoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
            kv_cache:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Encoder layer output torch.Tensor
        """
        # Self-attention sublayer (residual + post-norm).
        residual = hidden_states
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sublayer (residual + post-norm).
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 overflow guard inherited from the HF implementation: clamp
        # inf/nan activations just inside the representable range.
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)

        return hidden_states
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class BartDecoderLayer(nn.Module):
    """One BART decoder block.

    Causal self-attention, then cross-attention over the encoder output,
    then a feed-forward network — each with a residual connection and
    *post*-LayerNorm.
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        '''
        afeldman-nm: personally I would call this "cross-attention",
        however I left the name as "encoder_attn" to maintain consistency
        with the name of the pretrained weights.
        '''
        # Fix: forward cache_config/quant_config so the cross-attention
        # uses the same KV-cache and quantization settings as the
        # self-attention (previously they silently defaulted to None).
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        # Fix: size the decoder FFN with decoder_ffn_dim, not
        # encoder_ffn_dim. The two are equal for stock BART checkpoints,
        # but asymmetric configs were previously mis-sized.
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states
                torch.Tensor of *decoder* input embeddings.
            kv_cache:
                KV cache tensor
            attn_metadata:
                vLLM Attention metadata structure
            encoder_hidden_states
                torch.Tensor of *encoder* input embeddings.
        Returns:
            Decoder layer output torch.Tensor
        """
        residual = decoder_hidden_states

        # Self Attention (residual + post-norm).
        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block (residual + post-norm).

        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected (residual + post-norm).
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`BartEncoderLayer`].
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        embed_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings
        # BART optionally scales embeddings by sqrt(d_model).
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    embed_dim,
                                                    embed_scale=embed_scale)

        # Tie to an externally provided embedding table (weight sharing).
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([
            BartEncoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.encoder_layers)
        ])

        # Post-embedding LayerNorm (applied once, before the layer stack).
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Decoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds
        inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        # Sum token + position embeddings, then the one-off embedding norm.
        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        # Each layer consumes its own slice of the per-layer KV caches.
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
            )

        return hidden_states
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`BartDecoderLayer`]
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        # BART optionally scales token embeddings by sqrt(d_model).
        scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=scale)

        # Tie to an externally provided embedding table when given.
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList([
            BartDecoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.decoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(self, decoder_input_ids: torch.Tensor,
                decoder_positions: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the BART decoder stack.

        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you provide it.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Decoder output torch.Tensor
        """
        token_embeds = self.embed_tokens(decoder_input_ids)

        # Learned absolute position embeddings.
        pos_embeds = self.embed_positions(decoder_positions)
        pos_embeds = pos_embeds.to(token_embeds.device)

        hidden_states = self.layernorm_embedding(token_embeds + pos_embeds)

        # Each decoder layer cross-attends to the encoder output and
        # consumes its own slice of the KV cache.
        for decoder_layer, layer_kv_cache in zip(self.layers, kv_caches):
            hidden_states = decoder_layer(
                decoder_hidden_states=hidden_states,
                kv_cache=layer_kv_cache,
                attn_metadata=attn_metadata,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class BartModel(nn.Module):
    """BART encoder/decoder backbone (no LM head)."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id

        # Reserve extra vocabulary slots when LoRA adapters are enabled.
        extra_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = BartEncoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the encoder (when encoder tokens are present) then the decoder.

        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """
        encoder_hidden_states = None

        # Only run encoder attention when a non-zero number of encoder
        # tokens is provided as input.
        if encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        # Decoder output is the final hidden-state tensor.
        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
class BartForConditionalGeneration(nn.Module):
    """BART with a (tied) LM head for conditional generation."""

    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.config = config
        self.model = BartModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # LM head shares the embedding scaling convention of the backbone.
        scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the underlying :class:`BartModel`.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Project hidden states through the tied LM head.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    # HF shard names -> vLLM fused-QKV parameter and shard id.
    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    # Legacy HF naming -> vLLM naming.
    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }

    def _rename_key(self, key: str):
        """Strip the base-model prefix and apply legacy name substitutions."""
        prefix = f"{self.base_model_prefix}."
        if key.startswith(prefix):
            key = key[len(prefix):]

        for hf_name, vllm_name in self.params_mapping.items():
            key = key.replace(hf_name, vllm_name)

        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> Tuple[str, Optional[str]]:
        """Map a per-shard HF name onto the fused parameter name + shard id."""
        for fragment, mapping in self.stacked_params_mapping.items():
            if fragment in name:
                return (name.replace(fragment, mapping["param_name"]),
                        mapping["shard_id"])
        return name, None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, sharing one embedding across the three
        tied tables (encoder/decoder embed_tokens and lm_head)."""
        model_params = dict(self.model.named_parameters())
        top_params = dict(self.named_parameters())

        shared_weight = None
        shared_shard_id = None

        for name, loaded_weight in list(weights):

            name = self._rename_key(name)
            name, shard_id = self._rename_stacked_param(name)

            if ('shared.weight' in name
                    or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
                    or 'lm_head.weight' in name):
                assert shared_weight is None, (
                    "Conflicting embedding weights.")
                shared_weight = loaded_weight
                shared_shard_id = shard_id
                continue

            # Skip the specific downstream task weight.
            if name.startswith('cls.'):
                continue
            # use Pooler instead.
            if name.startswith('pooler.'):
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in model_params:
                continue

            param = model_params[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if shard_id:
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)

        # The single shared embedding must have been seen exactly once.
        assert shared_weight is not None

        # Assign the shared weight to all three tied parameters.
        tied_params = (model_params['encoder.embed_tokens.weight'],
                       model_params['decoder.embed_tokens.weight'],
                       top_params['lm_head.weight'])
        for param in tied_params:
            loader = getattr(param, "weight_loader", default_weight_loader)
            if shared_shard_id:
                loader(param, shared_weight, shared_shard_id)
            else:
                loader(param, shared_weight)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bert.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Iterable, List, Optional, Set, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import BertConfig
|
| 8 |
+
|
| 9 |
+
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
| 10 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 11 |
+
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
| 12 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 13 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 14 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 15 |
+
QKVParallelLinear,
|
| 16 |
+
RowParallelLinear)
|
| 17 |
+
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
| 18 |
+
PoolingType)
|
| 19 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 20 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 21 |
+
VocabParallelEmbedding)
|
| 22 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 23 |
+
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
| 24 |
+
from vllm.sequence import IntermediateTensors, PoolerOutput
|
| 25 |
+
from vllm.transformers_utils.config import (
|
| 26 |
+
get_cross_encoder_activation_function)
|
| 27 |
+
|
| 28 |
+
from .interfaces import SupportsCrossEncoding
|
| 29 |
+
from .utils import WeightsMapper, maybe_prefix
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BertEmbedding(nn.Module):
    """Token + position + token-type embeddings followed by LayerNorm."""

    def __init__(self, config: "BertConfig"):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        # Only absolute position embeddings are implemented here.
        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE: `seq_lens` is accepted for interface compatibility but is
        # not read in this implementation.
        input_shape = input_ids.size()

        # Token embeddings.
        token_embeds = self.word_embeddings(input_ids)

        # Absolute position embeddings.
        pos_embeds = self.position_embeddings(position_ids)

        # Default to segment 0 when token types are not supplied.
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,
                                         dtype=torch.long,
                                         device=token_embeds.device)
        type_embeds = self.token_type_embeddings(token_type_ids)

        return self.LayerNorm(token_embeds + type_embeds + pos_embeds)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class BertPooler(nn.Module):
    """Pool a sequence by transforming the first token's hidden state."""

    def __init__(self, config: "BertConfig"):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pooling" here means: take the first token's hidden state,
        # project it, and squash with tanh.
        cls_state = hidden_states[0, :]
        return self.activation(self.dense(cls_state))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@support_torch_compile
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` BertLayer modules."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.layer = nn.ModuleList([
            BertLayer(config=config,
                      cache_config=cache_config,
                      quant_config=quant_config,
                      prefix=f"{prefix}.layer.{layer_idx}")
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Run each layer with its matching KV cache slice.
        for bert_layer, layer_kv_cache in zip(self.layer, kv_caches):
            hidden_states = bert_layer(hidden_states, layer_kv_cache,
                                       attn_metadata)
        return hidden_states
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class BertLayer(nn.Module):
    """One transformer block: self-attention -> intermediate MLP -> output."""

    def __init__(self,
                 config: "BertConfig",
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention")

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate")

        self.output = BertOutput(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 layer_norm_eps=config.layer_norm_eps,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ):
        # Attention sublayer, then MLP with residual handled by BertOutput.
        attn_out = self.attention(hidden_states, kv_cache, attn_metadata)
        mlp_out = self.intermediate(attn_out)
        return self.output(mlp_out, attn_out)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class BertAttention(nn.Module):
    """BERT attention sublayer: self-attention followed by the residual
    output projection + LayerNorm.

    Args:
        hidden_size: model hidden dimension.
        num_attention_heads: number of attention heads.
        layer_norm_eps: epsilon for the output LayerNorm.
        cache_config: optional vLLM KV-cache configuration.
        quant_config: optional quantization configuration.
        prefix: module name prefix used for quant-config lookups.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # BUGFIX: the self-attention submodule previously used
        # prefix=f"{prefix}.output" (copy-pasted from self.output below),
        # which gave it the wrong module name for prefix-based lookups.
        # It must match its attribute name, ".self".
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self-attention, then residual projection + LayerNorm.
        self_output = self.self(hidden_states, kv_cache, attn_metadata)
        return self.output(self_output, hidden_states)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class BertSelfAttention(nn.Module):
    """Multi-head self-attention with a fused, tensor-parallel QKV
    projection and vLLM's encoder-only attention kernel."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_world_size = get_tensor_model_parallel_world_size()

        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # BERT has no grouped-query attention: KV heads == query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Fused QKV projection, then split into per-role tensors.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return self.attn(q, k, v, kv_cache, attn_metadata)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class BertSelfOutput(nn.Module):
    """Attention output projection with residual add + LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = RowParallelLinear(input_size=hidden_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        # Project, add the residual, and normalize.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class BertIntermediate(nn.Module):
    """Feed-forward up-projection followed by the configured activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = ColumnParallelLinear(input_size=hidden_size,
                                          output_size=intermediate_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.dense")
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        up_proj, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(up_proj)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class BertOutput(nn.Module):
    """Feed-forward down-projection with residual add + LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.dense = RowParallelLinear(input_size=intermediate_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        # Project back to hidden_size, add the residual, and normalize.
        down_proj, _ = self.dense(hidden_states)
        return self.LayerNorm(down_proj + input_tensor)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class BertModel(nn.Module):
    """BERT backbone: embeddings, encoder stack, and optional pooler."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 add_pooling_layer: bool = False):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Either use precomputed embeddings or build them from token ids.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            assert hasattr(attn_metadata, "seq_lens_tensor")
            hidden_states = self.embeddings(
                input_ids=input_ids,
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids)
        return self.encoder(hidden_states, kv_caches, attn_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing q/k/v shards into qkv_proj."""
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Drop pooler weights when no pooling layer was built.
            if self.pooler is None and "pooler" in name:
                continue
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class BertEmbeddingModel(nn.Module):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """
    # Strip the "model." prefix used by the HF checkpoint layout.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        pooler_config = vllm_config.model_config.pooler_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self._pooler = self._build_pooler(pooler_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the underlying BertModel and return hidden states."""
        return self.model(input_ids=input_ids,
                          position_ids=positions,
                          kv_caches=kv_caches,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors,
                          attn_metadata=attn_metadata)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool hidden states into per-sequence embeddings."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Remap HF names and load weights, dropping the unused lm_head."""
        weights = self.hf_to_vllm_mapper.apply(weights)
        # The LM head is not needed for embedding; filter it out lazily.
        weights = ((name, data) for name, data in weights
                   if not name.startswith("lm_head."))
        self.model.load_weights(weights)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        # Subclasses may override to swap the embedding class or backbone.
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # Default: CLS-token pooling with L2 normalization.
        return Pooler.from_config_with_defaults(pooler_config,
                                                pooling_type=PoolingType.CLS,
                                                normalize=True,
                                                softmax=False)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """BERT with a sequence-classification head, usable as a cross-encoder.

    This class encapsulates a pooled BertModel plus a linear classifier
    mapping the pooled representation to ``num_labels`` scores.

    Attributes:
        bert: BertModel (built with its pooling layer) used for forward
            operations.
        classifier: Linear head from hidden_size to num_labels.
        _pooler: CrossEncodingPooler combining the BERT pooler and the
            classifier head.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        # Activation applied to cross-encoder scores (config-dependent).
        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

        self.num_labels = config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding,
                              add_pooling_layer=True)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self._pooler = CrossEncodingPooler(config, self.classifier,
                                           self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Split weights: "bert."-prefixed go to the backbone, the rest
        (classifier head) are loaded here; other leftovers are ignored."""

        self_weights = []

        def weight_filter():
            # Stream backbone weights through; stash the rest for later.
            for name, weight in weights:
                if name.startswith("bert."):
                    yield (name[len("bert."):], weight)
                else:
                    self_weights.append((name, weight))

        self.bert.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Produce classification scores via the cross-encoding pooler."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the BERT backbone; pooling/classification happen in pooler()."""
        return self.bert(input_ids=input_ids,
                         position_ids=positions,
                         kv_caches=kv_caches,
                         inputs_embeds=inputs_embeds,
                         intermediate_tensors=intermediate_tensors,
                         attn_metadata=attn_metadata,
                         token_type_ids=token_type_ids)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/blip2.py
ADDED
|
@@ -0,0 +1,736 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
| 5 |
+
TypedDict, Union)
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
|
| 10 |
+
apply_chunking_to_forward)
|
| 11 |
+
|
| 12 |
+
from vllm.attention import AttentionMetadata
|
| 13 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 14 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 15 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 16 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 17 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 18 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 19 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 20 |
+
NestedTensors)
|
| 21 |
+
from vllm.multimodal.parse import MultiModalDataItems
|
| 22 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 23 |
+
BaseProcessingInfo, PromptReplacement,
|
| 24 |
+
PromptReplacementDetails)
|
| 25 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 26 |
+
from vllm.sequence import IntermediateTensors
|
| 27 |
+
|
| 28 |
+
from .blip import BlipVisionModel
|
| 29 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 30 |
+
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
| 31 |
+
maybe_prefix, merge_multimodal_embeddings)
|
| 32 |
+
|
| 33 |
+
# We use this internally as placeholders since there is no image token
|
| 34 |
+
# defined on the HuggingFace repo
|
| 35 |
+
_IMAGE_TOKEN_ID = 50265
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Blip2ImagePixelInputs(TypedDict):
    """Raw image pixels to be run through the vision encoder."""
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Blip2ImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, bypassing the vision encoder."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Blip2QFormerMultiHeadAttention(nn.Module):
    """Multi-head attention for the Q-Former.

    Performs self-attention over ``hidden_states``; when constructed with
    ``is_cross_attention=True`` and called with ``encoder_hidden_states``,
    keys/values are projected from the vision-encoder output instead
    (whose width is ``config.encoder_hidden_size``).
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # 1/sqrt(head_dim) scaling applied to the raw attention scores.
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # In cross-attention, keys/values read from the encoder hidden size.
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        # NOTE(review): quant_config and cache_config are accepted but not
        # referenced in this module's body — presumably kept for interface
        # parity with other layers.
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (..., seq, all_head_size) -> (batch, heads, seq, head_size)
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # Cross-attention is selected by the presence of encoder states.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Scaled dot-product attention: softmax(QK^T * scaling) V.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads back: (batch, heads, seq, head) -> (batch, seq, all_head).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Blip2QFormerSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual LayerNorm."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        width = config.hidden_size
        # Attribute names follow the HF checkpoint layout for weight loading.
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection around the attention block, then normalize.
        return self.LayerNorm(projected + input_tensor)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Blip2QFormerAttention(nn.Module):
    """Attention sub-block: multi-head attention followed by the
    projection/residual output layer."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Attend (self- or cross-) and apply the residual output layer."""
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        # Residual comes from the pre-attention hidden states.
        attention_output = self.output(self_output, hidden_states)

        return attention_output
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Blip2QFormerIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size + activation."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        # Checkpoint-compatible attribute names.
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class Blip2QFormerOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size, with
    dropout and a residual LayerNorm."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        # Attribute names follow the HF checkpoint layout for weight loading.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection around the feed-forward block, then normalize.
        return self.LayerNorm(contracted + input_tensor)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Blip2QFormerLayer(nn.Module):
    """One Q-Former block: self-attention, optional cross-attention to the
    image features, and a chunked feed-forward network.

    Cross-attention is only instantiated on layers where
    ``layer_idx % config.cross_attention_frequency == 0``.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Feed-forward chunking splits along the sequence dimension.
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            # The first `query_length` positions are learned query tokens;
            # only they cross-attend to the image features.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            # Any positions past query_length (text tokens) go through the
            # plain feed-forward path and are re-concatenated.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        # NOTE(review): `self.intermediate` and `self.output` are never
        # defined in this class (only `intermediate_query`/`output_query`
        # are), so this path would raise AttributeError if reached.
        # It is only taken when inputs extend past query_length — confirm
        # this never happens for this model before relying on it.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        # Feed-forward applied to the query-token positions.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class Blip2QFormerEncoder(nn.Module):
    """Sequential stack of ``Blip2QFormerLayer`` blocks."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        """Thread the hidden states through every layer in order."""
        for block in self.layer:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
|
| 362 |
+
class Blip2QFormerModel(nn.Module):
    """Q-Former: normalizes the learned query embeddings and runs the
    encoder stack with cross-attention over the vision features."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        # Every input position is a query token here.
        normed_queries = self.dropout(self.layernorm(query_embeds))
        return self.encoder(
            normed_queries,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_embeds.shape[1],
        )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class Blip2ProcessingInfo(BaseProcessingInfo):
    """Multimodal processing metadata for BLIP-2."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Blip2Config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # BLIP-2 supports at most one image per prompt.
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        # The image token count is fixed, independent of seq_len.
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        # Each image contributes exactly num_query_tokens placeholder tokens.
        hf_config = self.get_hf_config()
        return hf_config.num_query_tokens
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
    """Builds dummy inputs (maximum-size images) for memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config

        # Use the largest image the vision tower accepts.
        max_image_size = vision_config.image_size
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images)
        }

        # Empty prompt: placeholders are inserted by the processor.
        return ProcessorInputs(
            prompt_text="",
            mm_data=mm_data,
        )
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
    """Maps BLIP-2 prompts + images into model inputs, inserting the
    internal image placeholder tokens before the BOS token."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            # HF processor always adds placeholders even when there's no image
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Either pixel values or precomputed embeddings may be present.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        image_token_id = vocab["<image>"]
        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        # Replace the BOS token with (image placeholders + BOS); only the
        # placeholder run counts as image features.
        return [
            PromptReplacement(
                modality="image",
                target=[bos_token_id],
                replacement=PromptReplacementDetails(
                    full=image_tokens + [bos_token_id],
                    features=image_tokens,
                ),
            )
        ]
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                        info=Blip2ProcessingInfo,
                                        dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """BLIP-2: vision tower -> Q-Former -> projection -> language model.

    Image features from the vision tower are compressed into a fixed number
    of query embeddings by the Q-Former, projected into the language model's
    hidden size, and merged into the text sequence at the image placeholder
    token positions (``_IMAGE_TOKEN_ID``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config, quant_config)

        # Learnable queries consumed by the Q-Former; their count
        # (config.num_query_tokens) equals the number of image tokens
        # inserted into the prompt.
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        # Maps Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check pixel values have shape (batch_size, 3, size, size)."""
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        """Extract pixel values or precomputed image embeddings from kwargs.

        Returns None when the request carries no image data; raises
        ValueError when the data has the wrong type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode raw pixels with the vision tower."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        """Turn parsed image inputs into language-model-space embeddings."""

        # Precomputed embeddings bypass the vision tower and Q-Former.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        # Broadcast the learned queries over the batch, then let the
        # Q-Former attend to the encoded image features.
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for this request, or None if text-only."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # Overwrite the placeholder (_IMAGE_TOKEN_ID) positions with the
            # projected image embeddings.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run forward pass for BLIP-2.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        Tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since it corresponds to the number of query
        embeddings outputted by the Q-Former and inputted to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """

        # Non-first pipeline ranks receive hidden states via
        # intermediate_tensors, so embeddings are not computed here.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Logits come from the wrapped language model's head.
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Delegate to the generic loader, which maps checkpoint names onto
        # this module tree.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/bloom.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
| 20 |
+
import math
|
| 21 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
from transformers import BloomConfig
|
| 26 |
+
|
| 27 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 28 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 29 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 30 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 31 |
+
get_tensor_model_parallel_world_size)
|
| 32 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 33 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 34 |
+
QKVParallelLinear,
|
| 35 |
+
RowParallelLinear)
|
| 36 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 37 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 38 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 39 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 40 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 41 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsPP
|
| 46 |
+
from .utils import (is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 48 |
+
maybe_prefix)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
| 52 |
+
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
| 53 |
+
base = torch.tensor(
|
| 54 |
+
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
| 55 |
+
dtype=torch.float32,
|
| 56 |
+
)
|
| 57 |
+
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
| 58 |
+
slopes = torch.pow(base, powers)
|
| 59 |
+
|
| 60 |
+
if closest_power_of_2 != total_num_heads:
|
| 61 |
+
extra_base = torch.tensor(
|
| 62 |
+
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
| 63 |
+
dtype=torch.float32,
|
| 64 |
+
)
|
| 65 |
+
num_remaining_heads = min(closest_power_of_2,
|
| 66 |
+
total_num_heads - closest_power_of_2)
|
| 67 |
+
extra_powers = torch.arange(start=1,
|
| 68 |
+
end=1 + 2 * num_remaining_heads,
|
| 69 |
+
step=2,
|
| 70 |
+
dtype=torch.int32)
|
| 71 |
+
slopes = torch.cat(
|
| 72 |
+
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 73 |
+
return slopes
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class BloomAttention(nn.Module):
    """Multi-head self-attention for BLOOM using ALiBi positional biases.

    Positions are encoded through the ALiBi slopes handed to the attention
    backend, so the explicit ``position_ids`` argument is unused in
    :meth:`forward`.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        # Hidden size must divide evenly into heads.
        assert self.head_dim * self.total_num_heads == self.hidden_size

        # Heads are sharded evenly across tensor-parallel ranks.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and slice them.
        # Each rank only keeps the slopes for its own head range.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # Splitting [q | k | v] along the last dim is valid here because
        # load_weights permutes BLOOM's fused QKV weights into this layout.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class BloomMLP(nn.Module):
    """Feed-forward block of a BLOOM layer.

    Projects the hidden state up to 4x its size, applies GELU, and
    projects back down. The up-projection is column-parallel and the
    down-projection row-parallel across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        embed_dim = config.hidden_size
        # Up-projection: hidden_size -> 4 * hidden_size.
        self.dense_h_to_4h = ColumnParallelLinear(
            embed_dim,
            4 * embed_dim,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu")
        # Down-projection: 4 * hidden_size -> hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            4 * embed_dim,
            embed_dim,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.dense_h_to_4h(x)
        hidden = self.gelu_impl(hidden)
        out, _ = self.dense_4h_to_h(hidden)
        return out
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class BloomBlock(nn.Module):
    """One BLOOM transformer layer: pre-LN attention then pre-LN MLP.

    ``apply_residual_connection_post_layernorm`` selects whether the
    residual branch is taken from the layernorm output (post-LN residual)
    or from the layer input (standard pre-LN residual).
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        # Residual-placement flag from the HF config (see class docstring).
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM transformer stack: embeddings + LN, blocks, final LN.

    Supports pipeline parallelism: only this rank's layer slice
    (``start_layer``..``end_layer``) is instantiated, and non-first/last
    ranks exchange hidden states via :class:`IntermediateTensors`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        # (BLOOM applies a layernorm directly after the word embeddings.)
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: BloomBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Embedding lookup followed by the embedding layernorm.
        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First PP rank computes (or receives precomputed) embeddings.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later PP ranks continue from the previous rank's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                # kv_caches is indexed relative to this rank's layer slice.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class BloomForCausalLM(nn.Module, SupportsPP):
    """BLOOM causal-LM: :class:`BloomModel` plus a (usually tied) LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        # BLOOM normally ties the LM head to the input embeddings.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF checkpoint weights, permuting fused QKV tensors.

        Returns the set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                # Skipped: lm_head is either tied to the embeddings or
                # created fresh in __init__, so the checkpoint copy is
                # redundant here.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            if is_pp_missing_parameter(name, self):
                # Parameter lives on a different pipeline-parallel rank.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # view -> transpose -> reshape swaps the (num_heads, 3)
                    # grouping into (3, num_heads) without copying data
                    # semantics beyond the permutation.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/chameleon.py
ADDED
|
@@ -0,0 +1,1161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
| 5 |
+
Tuple, TypedDict, Union)
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
| 11 |
+
ChameleonVQVAEConfig)
|
| 12 |
+
|
| 13 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 14 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 15 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 16 |
+
from vllm.logger import init_logger
|
| 17 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 18 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 19 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 20 |
+
QKVParallelLinear,
|
| 21 |
+
RowParallelLinear)
|
| 22 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 23 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 24 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 25 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 26 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 27 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 28 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 29 |
+
default_weight_loader, row_parallel_weight_loader)
|
| 30 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 31 |
+
from vllm.model_executor.utils import set_weight_attrs
|
| 32 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 33 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 34 |
+
NestedTensors)
|
| 35 |
+
from vllm.multimodal.parse import MultiModalDataItems
|
| 36 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 37 |
+
BaseProcessingInfo, PromptReplacement,
|
| 38 |
+
PromptReplacementDetails)
|
| 39 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 40 |
+
from vllm.sequence import IntermediateTensors
|
| 41 |
+
|
| 42 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 43 |
+
from .utils import (is_pp_missing_parameter,
|
| 44 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 45 |
+
maybe_prefix, merge_multimodal_embeddings)
|
| 46 |
+
|
| 47 |
+
logger = init_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ChameleonImagePixelInputs(TypedDict):
    """Schema of the pixel-value inputs handed to the Chameleon VQ-VAE."""
    # Discriminator tag so the model can distinguish input kinds.
    type: Literal["pixel_values"]
    # Raw image pixels, already preprocessed by the HF processor.
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Static processing metadata for Chameleon multimodal inputs."""

    def get_hf_config(self):
        # Typed accessor for the HF ChameleonConfig.
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self):
        return self.ctx.get_hf_processor(ChameleonProcessor)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Chameleon accepts at most one image per prompt.
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        """Number of discrete image tokens emitted per image."""
        return self.get_hf_processor().image_seq_length

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        # Every image expands to the same fixed-length token sequence.
        return {"image": self.get_num_image_tokens()}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ChameleonDummyInputsBuilder(
        BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds worst-case dummy prompts and images for memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.info.get_hf_config()

        # The VQ-VAE operates on fixed square inputs of this resolution.
        side = hf_config.vq_config.resolution
        n_images = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=n_images)

        return ProcessorInputs(
            prompt_text="<image>" * n_images,
            mm_data={"image": dummy_images},
        )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class ChameleonMultiModalProcessor(
        BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Expands each ``<image>`` placeholder in the prompt into the full
    image-token sequence expected by the Chameleon model."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # Text-only prompts bypass the HF processor (which expects images)
        # and are tokenized directly instead.
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # HF processor adds sep token for chat mode; mirror that here so
        # token-only processing matches the text path.
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        sep_token_id = vocab[tokenizer.sep_token]  # type: ignore

        return prompt_tokens + [sep_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # One pixel_values entry per image, batched along dim 0.
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_start_id = vocab[processor.image_start_token]
        image_token_id = vocab[processor.image_token]
        image_end_id = vocab[processor.image_end_token]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        # Each placeholder becomes <start> + N image tokens + <end>;
        # only the N inner tokens carry image features.
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_start_id] + image_tokens + [image_end_id]),
                    features=image_tokens,
                ),
            )
        ]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm over the head dimension with row-parallel affine params.

    ``hidden_size`` is a ``(num_heads, head_dim)`` tuple: statistics are
    computed only over the trailing ``head_dim`` axis, while ``weight`` and
    ``bias`` keep the full 2-D shape so each tensor-parallel rank loads its
    own slice of heads (hence the row-parallel weight loaders).

    Fix: ``forward`` previously hard-coded ``eps=1e-5`` and ignored any
    ``eps`` passed to the ``nn.LayerNorm`` constructor; it now honors
    ``self.eps`` (whose default is 1e-5, so default behavior is unchanged).
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Normalize only over the trailing (head_dim) axis; the leading
        # num_heads axis is sharded across TP ranks.
        self.normalized_shape = (hidden_size[-1], )

        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run the parameter-free normalization first, then apply the affine
        # transform manually so the (num_heads, head_dim)-shaped weight and
        # bias broadcast over the normalized activations.
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=self.eps)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Gated SiLU feed-forward network (SwiGLU) with TP-sharded projections."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Fused gate+up projection: a single GEMM emits both halves.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # SiluAndMul splits its input in half: silu(gate) * up.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused_gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(fused_gate_up)
        x, _ = self.down_proj(x)
        return x
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Multi-head attention with rotary embeddings and Chameleon's per-head
    q/k layer norm, sharded for tensor parallelism."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Query heads must divide evenly across TP ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Per-rank KV head count (at least 1 when replicating).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the fused QKV projection's output slices.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        # Chameleon-specific: per-head normalization of q/k before rope.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        # Flatten the (heads, head_dim) axes back into one hidden axis.
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Norm first, then rope — order matters for Chameleon.
        q, k = self._apply_qk_norm(q, k)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class ChameleonDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer (RMSNorm -> attention,
    RMSNorm -> MLP) using vLLM's fused residual-add RMSNorm."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Propagate the pre-scaling context length into the rope-scaling
        # dict when the config provides it.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # RMSNorm fuses the residual add from the previous sub-layer;
        # on the first call there is no residual yet.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        # The residual add for this MLP output happens in the NEXT layer's
        # input_layernorm (fused pattern), so both tensors are returned.
        return hidden_states, residual
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class ChameleonSwinDecoderLayer(nn.Module):
    """Post-norm ("swin"-style) decoder layer variant: normalization is
    applied AFTER each sub-layer, before the residual add — unlike the
    pre-norm ChameleonDecoderLayer."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Propagate the pre-scaling context length into the rope-scaling
        # dict when the config provides it.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # Post-norm: attention first, then norm, then residual add.
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Residual is already folded in here; the second return value keeps
        # the signature parallel with ChameleonDecoderLayer.
        return hidden_states, residual
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through estimator."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Commitment-loss weight (beta in the VQ-VAE paper).
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        # (B, C, H, W) -> (B, H, W, C), then flatten into rows of C-dim vectors.
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = hidden_state.view(-1, self.embedding_dim)

        codebook = self.embedding.weight
        # Squared L2 distance via expansion: |z|^2 + |e|^2 - 2 z.e
        distances = (torch.sum(flat**2, dim=1, keepdim=True) +
                     torch.sum(codebook**2, dim=1) -
                     2 * torch.einsum("bd,dn->bn", flat,
                                      codebook.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # Codebook loss plus beta-weighted commitment loss.
        loss = torch.mean(
            (quantized.detach() - hidden_state)**2) + self.beta * torch.mean(
                (quantized - hidden_state.detach())**2)

        # Straight-through estimator: gradients bypass the quantization.
        quantized = hidden_state + (quantized - hidden_state).detach()

        # (B, H, W, C) -> (B, C, H, W)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, min_encoding_indices
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Halves spatial resolution via a stride-2 conv with asymmetric padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # torch convs only support symmetric padding, so pad one pixel on
        # the right and bottom edges manually before the stride-2 conv.
        padded = F.pad(hidden_states,
                       pad=(0, 1, 0, 1),
                       mode="constant",
                       value=0)
        return self.conv(padded)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """GroupNorm + swish + conv residual block of the VQ-VAE encoder.

    When the channel count changes, the residual branch is projected with
    either a 3x3 conv (``conv_shortcut=True``) or a 1x1 conv.

    Fix: the convolution/norm layers previously used the raw ``out_channels``
    argument, so the documented ``out_channels=None`` default crashed; they
    now use the resolved ``self.out_channels`` (identical behavior for all
    callers that pass a value).
    """

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Defaulting to in_channels makes the block shape-preserving.
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     self.out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=self.out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels,
                                     self.out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     self.out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    self.out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # Swish/SiLU written as x * sigmoid(x); in-place on the norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Project the residual when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head spatial self-attention over feature-map positions,
    implemented with 1x1 convs, followed by a residual connection."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # 1x1 convs act as per-position linear projections.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        normed = self.norm(hidden_states)
        query_states = self.q(normed)
        key_states = self.k(normed)
        value_states = self.v(normed)

        batch_size, channels, height, width = query_states.shape
        spatial = height * width
        # Queries as rows: (B, C, H*W) -> (B, HW, C).
        query_states = query_states.reshape(batch_size, channels,
                                            spatial).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, spatial)
        # Scaled dot-product attention over spatial positions.
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        value_states = value_states.reshape(batch_size, channels, spatial)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        return residual + self.proj_out(attn_output)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional encoder of the VQ-VAE: stacked downsampling resnet
    stages, a middle block, and a projection to latent channels."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Input multiplier of stage i is the output multiplier of stage
        # i-1 (1 for the first stage).
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                # Attach spatial attention only at the configured
                # resolutions (and only for "vanilla" attention).
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every stage except the last halves the spatial resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # Middle: resnet -> (optional) attention -> resnet, channel-preserving.
        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        # double_latent doubles the output channels (mean/variance style).
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        # Match the dtype of the (frozen) encoder weights.
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1])
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        # Swish activation (x * sigmoid(x)).
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Frozen VQ-VAE used to tokenize images into discrete codes."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        # Chameleon's VQ model is frozen; put it in eval mode up front.
        self.eval()

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images into (quantized latents, embedding loss, indices)."""
        latents = self.quant_conv(self.encoder(pixel_values))
        return self.quantize(latents)
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
|
| 798 |
+
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        # Inverse of vocab_map: token id -> token name.
        return {token_id: name for name, token_id in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        # Sorted ids of all vocabulary entries that encode VQGAN image codes.
        return sorted(token_id for name, token_id in self.vocab_map.items()
                      if name.startswith("IMGIMG"))

    @cached_property
    def bpe2img(self):
        # Letters A..J in the token name encode the digits 0..9.
        letter_to_digit = {chr(ord("A") + d): str(d) for d in range(10)}

        def decode_name(token_name: str) -> str:
            # Strip the "IMGIMG" prefix and the trailing character, then
            # translate each digit-letter; other characters pass through.
            body = token_name[len("IMGIMG"):-1]
            return "".join(letter_to_digit.get(ch, ch) for ch in body)

        return {
            token_id: int(decode_name(self.val2name[token_id]))
            for token_id in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {img_code: bpe_id for bpe_id, img_code in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        sorted_bpe_ids = torch.tensor(sorted(self.bpe2img.keys()))
        sorted_img_codes = torch.tensor(sorted(self.bpe2img.values()))
        return sorted_bpe_ids, sorted_img_codes

    @cached_property
    def img2bpe_mapping_tensor(self):
        # Dense lookup table: index = image code, value = BPE token id.
        table = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for img_code, bpe_id in self.img2bpe.items():
            table[img_code] = bpe_id
        return table

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        # The lookup table lives on CPU; index there and move the result
        # back to the batch's original device.
        original_device = img_batch.device
        bpe_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return bpe_tokens.to(original_device)
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
class ChameleonModel(nn.Module):
    """Chameleon transformer backbone plus the frozen VQ-VAE image tokenizer."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        # Translates VQGAN image codes to their BPE token ids.
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        # config.swin_norm selects the Swin-style decoder layer variant.
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer

        # Only the layers assigned to this pipeline-parallel rank are built.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(config=config,
                                         cache_config=cache_config,
                                         quant_config=quant_config,
                                         prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for the given token ids."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
        special tokens.
        """
        # NOTE(review): no boi/eoi wrapping is visible in this body, despite
        # the docstring above — confirm whether callers wrap separately.
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        # Flatten per-image codes to one token sequence per batch element.
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First pipeline stage: start from embeddings (precomputed or
            # looked up from input_ids).
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later stages resume from the previous stage's tensors.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            # Hand the running state off to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """Chameleon multimodal model with an LM head for text generation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            # Share the output projection with the input embedding matrix.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check pixel values are shaped (batch, 3, resolution, resolution).

        Raises:
            ValueError: if the trailing dimensions do not match the VQ config.
        """
        vq_config: ChameleonVQVAEConfig = self.config.vq_config
        expected_dims = (3, vq_config.resolution, vq_config.resolution)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        """Extract and validate pixel_values from multimodal kwargs."""
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        # Remove the N dimension until multiple images are supported.
        pixel_values = pixel_values.squeeze(1)

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Tokenize images via the VQ model and embed the resulting tokens."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        assert self.model.vqmodel is not None
        image_tokens = self.model.get_image_tokens(image_input["data"].to(
            self.config.torch_dtype))
        vision_embeddings = self.model.get_input_embeddings(image_tokens)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed input ids, splicing vision embeddings over image tokens."""
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.model.vocabulary_mapping.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        if intermediate_tensors is not None:
            # Non-first pipeline stages consume intermediate tensors, not
            # embeddings.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens which does not include special
        # begin-image and end-image tokens
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing sharded QKV/MLP projections.

        Returns the set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # for-else: no stacked mapping matched; load directly.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            if use_default_weight_loading and name in params_dict:
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/chatglm.py
ADDED
|
@@ -0,0 +1,801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/THUDM/CogAgent
|
| 5 |
+
"""Inference-only CogAgent model compatible with THUDM weights."""
|
| 6 |
+
from argparse import Namespace
|
| 7 |
+
from array import array
|
| 8 |
+
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
|
| 9 |
+
TypedDict)
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.nn import LayerNorm
|
| 15 |
+
|
| 16 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 17 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 18 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 19 |
+
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
| 20 |
+
InputContext, token_inputs)
|
| 21 |
+
from vllm.logger import init_logger
|
| 22 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 23 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 24 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 25 |
+
QKVParallelLinear,
|
| 26 |
+
RowParallelLinear)
|
| 27 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 28 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 29 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 30 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 31 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 32 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 33 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 34 |
+
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
| 35 |
+
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
| 36 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 37 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 38 |
+
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
|
| 39 |
+
NestedTensors)
|
| 40 |
+
from vllm.multimodal.utils import cached_get_tokenizer
|
| 41 |
+
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
| 42 |
+
SequenceData)
|
| 43 |
+
from vllm.transformers_utils.configs import ChatGLMConfig
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
| 46 |
+
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 48 |
+
maybe_prefix)
|
| 49 |
+
|
| 50 |
+
logger = init_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def calculate_image_placeholder(vision_config):
    """Number of placeholder tokens a single image occupies.

    Derived as ((image_size / patch_size) / 2) squared, matching the
    vision tower's patch grid after a 2x spatial downsample.
    """
    patches_per_side = vision_config["image_size"] // vision_config["patch_size"]
    return (patches_per_side // 2) ** 2
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: ModalityData[object],
) -> Dict:
    """Map a raw image object to multimodal kwargs for GLM-4V.

    Runs the model's HF tokenizer chat template on a one-image user turn;
    the template preprocesses the image and returns its pixel values.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        # Log the offending input before re-raising for upstream handling.
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalKwargs({'pixel_values': pixel_values})
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite the embeddings between each BOI/EOI pair (inclusive)
    with the flattened vision embeddings, in place, and return them."""
    starts = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
    ends = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(starts, ends):
        # Each begin-of-image marker must precede its end-of-image marker.
        assert start < end
        image_mask[start:end + 1] = True

    flat_vision = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class GLMImagePixelInputs(TypedDict):
    """Pixel-value inputs passed to the GLM-4V vision tower."""
    pixel_values: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_max_glmv_image_tokens(ctx: InputContext):
    """Upper bound on placeholder tokens one image can consume."""
    hf_config = ctx.get_hf_config(ChatGLMConfig)

    vision_config = getattr(hf_config, 'vision_config', None)
    if vision_config is None:
        # Text-only checkpoint: a single token suffices.
        return 1
    if isinstance(vision_config, dict):
        return calculate_image_placeholder(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]) -> DummyData:
    """Build dummy data for memory profiling: a blank image and a token
    sequence with BOI/EOI markers around the image placeholder span."""
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        # Text-only model: plain zero tokens, no multimodal data.
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        seq_data = SequenceData(token_ids)
        return DummyData(seq_data, None)
    elif isinstance(vision_config, dict):
        image_size = vision_config["image_size"]
        image_placeholder_length = calculate_image_placeholder(vision_config)
        # BOI + placeholders + EOI, then pad with zeros up to seq_len.
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
                          [0] * image_placeholder_length +
                          [hf_config.eoi_token_id])
        token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                           [0] * (seq_len - image_placeholder_length - 2))
        seq_data = SequenceData(token_ids)

        mm_data = {
            "image": Image.new("RGB", (image_size, image_size), color=0)
        }

        return DummyData(seq_data, mm_data)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return the indices of every occurrence of *target* in *input_ids*."""
    positions: List[int] = []
    for idx, token in enumerate(input_ids):
        if token == target:
            positions.append(idx)
    return positions
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand each image's BOI..EOI span to the full placeholder length.

    Re-tokenizes the prompt through the HF chat template (which handles the
    image), then repeats the token following each BOI marker so the sequence
    reserves one position per image patch embedding.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        # No image attached: nothing to expand.
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    input_ids = inputs["prompt_token_ids"]

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt'],
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    # Replace the original ids with the template's tokenization.
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    # Every begin-of-image marker must have a matching end-of-image marker.
    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything up to and including the BOI token.
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        # Repeat the token right after BOI once per placeholder position.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        final_processed_position = eoi_position

    # Copy the remainder, starting at the last EOI token.
    new_input_ids.extend(input_ids[final_processed_position:])

    prompt = inputs.get("prompt")
    if prompt is None:
        # No text prompt was provided; reconstruct one from the token ids.
        prompt = tokenizer.decode(new_input_ids)

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class GLMAttention(nn.Module):
    """Self-attention block for ChatGLM with tensor-parallel projections
    and optional multi-query (grouped KV) attention."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query attention, KV heads are a (smaller) group count.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Fused QKV projection; bias follows the checkpoint's config flags.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        is_neox_style = not config.original_rope
        self.rotary_emb = get_rope(
            self.head_dim,
            # Partial rotary embedding: only half of each head dimension.
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Fused QKV projection, then split into per-role tensors.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Apply rotary position embedding to queries and keys.
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class GLMMLP(nn.Module):
|
| 321 |
+
"""MLP.
|
| 322 |
+
|
| 323 |
+
MLP will take the input with h hidden state, project it to 4*h
|
| 324 |
+
hidden dimension, perform nonlinear transformation, and project the
|
| 325 |
+
state back into h hidden dimension.
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
def __init__(
|
| 329 |
+
self,
|
| 330 |
+
config: ChatGLMConfig,
|
| 331 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 332 |
+
prefix: str = "",
|
| 333 |
+
):
|
| 334 |
+
super().__init__()
|
| 335 |
+
|
| 336 |
+
self.add_bias = config.add_bias_linear
|
| 337 |
+
|
| 338 |
+
# Project to 4h.
|
| 339 |
+
self.dense_h_to_4h = MergedColumnParallelLinear(
|
| 340 |
+
config.hidden_size,
|
| 341 |
+
[config.ffn_hidden_size] * 2,
|
| 342 |
+
bias=config.add_bias_linear,
|
| 343 |
+
quant_config=quant_config,
|
| 344 |
+
prefix=f"{prefix}.dense_h_to_4h",
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
self.activation_func = SiluAndMul()
|
| 348 |
+
|
| 349 |
+
# Project back to h.
|
| 350 |
+
self.dense_4h_to_h = RowParallelLinear(
|
| 351 |
+
config.ffn_hidden_size,
|
| 352 |
+
config.hidden_size,
|
| 353 |
+
bias=config.add_bias_linear,
|
| 354 |
+
quant_config=quant_config,
|
| 355 |
+
prefix=f"{prefix}.dense_4h_to_h",
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def forward(self, hidden_states):
|
| 359 |
+
# [s, b, 4hp]
|
| 360 |
+
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
|
| 361 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
| 362 |
+
# [s, b, h]
|
| 363 |
+
output, _ = self.dense_4h_to_h(intermediate_parallel)
|
| 364 |
+
return output
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class GLMBlock(nn.Module):
|
| 368 |
+
"""A single transformer layer.
|
| 369 |
+
|
| 370 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
| 371 |
+
output of the same size.
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(
|
| 375 |
+
self,
|
| 376 |
+
config: ChatGLMConfig,
|
| 377 |
+
cache_config: Optional[CacheConfig] = None,
|
| 378 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 379 |
+
prefix: str = "",
|
| 380 |
+
):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.apply_residual_connection_post_layernorm = (
|
| 383 |
+
config.apply_residual_connection_post_layernorm)
|
| 384 |
+
|
| 385 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
| 386 |
+
|
| 387 |
+
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
| 388 |
+
# Layernorm on the input data.
|
| 389 |
+
self.input_layernorm = layer_norm_func(config.hidden_size,
|
| 390 |
+
eps=config.layernorm_epsilon)
|
| 391 |
+
|
| 392 |
+
# Self attention.
|
| 393 |
+
self.self_attention = GLMAttention(config,
|
| 394 |
+
cache_config,
|
| 395 |
+
quant_config,
|
| 396 |
+
prefix=f"{prefix}.self_attention")
|
| 397 |
+
self.hidden_dropout = config.hidden_dropout
|
| 398 |
+
|
| 399 |
+
# Layernorm on the attention output
|
| 400 |
+
self.post_attention_layernorm = layer_norm_func(
|
| 401 |
+
config.hidden_size, eps=config.layernorm_epsilon)
|
| 402 |
+
|
| 403 |
+
# MLP
|
| 404 |
+
self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
|
| 405 |
+
|
| 406 |
+
def forward(
|
| 407 |
+
self,
|
| 408 |
+
hidden_states: torch.Tensor,
|
| 409 |
+
position_ids: torch.Tensor,
|
| 410 |
+
kv_cache: torch.Tensor,
|
| 411 |
+
attn_metadata: AttentionMetadata,
|
| 412 |
+
) -> torch.Tensor:
|
| 413 |
+
# hidden_states: [num_tokens, h]
|
| 414 |
+
# Layer norm at the beginning of the transformer layer.
|
| 415 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
| 416 |
+
# Self attention.
|
| 417 |
+
attention_output = self.self_attention(
|
| 418 |
+
hidden_states=layernorm_output,
|
| 419 |
+
position_ids=position_ids,
|
| 420 |
+
kv_cache=kv_cache,
|
| 421 |
+
attn_metadata=attn_metadata,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Residual connection.
|
| 425 |
+
if self.apply_residual_connection_post_layernorm:
|
| 426 |
+
residual = layernorm_output
|
| 427 |
+
else:
|
| 428 |
+
residual = hidden_states
|
| 429 |
+
|
| 430 |
+
layernorm_input = residual + attention_output
|
| 431 |
+
|
| 432 |
+
# Layer norm post the self attention.
|
| 433 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
| 434 |
+
|
| 435 |
+
# Second residual connection.
|
| 436 |
+
if self.apply_residual_connection_post_layernorm:
|
| 437 |
+
residual = layernorm_output
|
| 438 |
+
else:
|
| 439 |
+
residual = layernorm_input
|
| 440 |
+
|
| 441 |
+
output = self.mlp(layernorm_output) + residual
|
| 442 |
+
|
| 443 |
+
return output
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class GLMTransformer(nn.Module):
|
| 447 |
+
"""Transformer class."""
|
| 448 |
+
|
| 449 |
+
def __init__(
|
| 450 |
+
self,
|
| 451 |
+
config: ChatGLMConfig,
|
| 452 |
+
cache_config: Optional[CacheConfig] = None,
|
| 453 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 454 |
+
prefix: str = "",
|
| 455 |
+
):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.post_layer_norm = config.post_layer_norm
|
| 458 |
+
|
| 459 |
+
# Number of layers.
|
| 460 |
+
self.num_layers = config.num_layers
|
| 461 |
+
|
| 462 |
+
# Transformer layers.
|
| 463 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 464 |
+
self.num_layers,
|
| 465 |
+
lambda prefix: GLMBlock(
|
| 466 |
+
config, cache_config, quant_config, prefix=prefix),
|
| 467 |
+
prefix=f"{prefix}.layers",
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if self.post_layer_norm:
|
| 471 |
+
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
| 472 |
+
# Final layer norm before output.
|
| 473 |
+
self.final_layernorm = layer_norm_func(
|
| 474 |
+
config.hidden_size, eps=config.layernorm_epsilon)
|
| 475 |
+
|
| 476 |
+
self.make_empty_intermediate_tensors = (
|
| 477 |
+
make_empty_intermediate_tensors_factory(["hidden_states"],
|
| 478 |
+
config.hidden_size))
|
| 479 |
+
|
| 480 |
+
def forward(
|
| 481 |
+
self,
|
| 482 |
+
hidden_states: torch.Tensor,
|
| 483 |
+
position_ids: torch.Tensor,
|
| 484 |
+
kv_caches: List[torch.Tensor],
|
| 485 |
+
attn_metadata: AttentionMetadata,
|
| 486 |
+
) -> torch.Tensor:
|
| 487 |
+
for i in range(self.start_layer, self.end_layer):
|
| 488 |
+
layer = self.layers[i]
|
| 489 |
+
hidden_states = layer(
|
| 490 |
+
hidden_states=hidden_states,
|
| 491 |
+
position_ids=position_ids,
|
| 492 |
+
kv_cache=kv_caches[i - self.start_layer],
|
| 493 |
+
attn_metadata=attn_metadata,
|
| 494 |
+
)
|
| 495 |
+
# Final layer norm.
|
| 496 |
+
if get_pp_group().is_last_rank and self.post_layer_norm:
|
| 497 |
+
hidden_states = self.final_layernorm(hidden_states)
|
| 498 |
+
|
| 499 |
+
return hidden_states
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class ChatGLMModel(nn.Module):
|
| 503 |
+
|
| 504 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 505 |
+
super().__init__()
|
| 506 |
+
|
| 507 |
+
config = vllm_config.model_config.hf_config
|
| 508 |
+
cache_config = vllm_config.cache_config
|
| 509 |
+
quant_config = vllm_config.quant_config
|
| 510 |
+
|
| 511 |
+
self.config = config
|
| 512 |
+
|
| 513 |
+
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
| 514 |
+
config.hidden_size,
|
| 515 |
+
quant_config=quant_config,
|
| 516 |
+
prefix=f"{prefix}.embedding")
|
| 517 |
+
|
| 518 |
+
self.num_layers = config.num_layers
|
| 519 |
+
self.multi_query_group_num = config.multi_query_group_num
|
| 520 |
+
self.kv_channels = config.kv_channels
|
| 521 |
+
self.encoder = GLMTransformer(config,
|
| 522 |
+
cache_config,
|
| 523 |
+
quant_config,
|
| 524 |
+
prefix=f"{prefix}.encoder")
|
| 525 |
+
|
| 526 |
+
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
| 527 |
+
config.hidden_size,
|
| 528 |
+
quant_config=quant_config,
|
| 529 |
+
prefix=f"{prefix}.output_layer")
|
| 530 |
+
|
| 531 |
+
vision_config_flag = getattr(config, 'vision_config', None)
|
| 532 |
+
if vision_config_flag is not None:
|
| 533 |
+
self.vision_config = Namespace(**config.vision_config)
|
| 534 |
+
self.vision = EVA2CLIPModel(self.config,
|
| 535 |
+
quant_config,
|
| 536 |
+
prefix=f"{prefix}.vision")
|
| 537 |
+
else:
|
| 538 |
+
self.vision = None
|
| 539 |
+
|
| 540 |
+
self.make_empty_intermediate_tensors = (
|
| 541 |
+
self.encoder.make_empty_intermediate_tensors)
|
| 542 |
+
|
| 543 |
+
def _parse_and_validate_image_input(
|
| 544 |
+
self, **kwargs: object) -> GLMImagePixelInputs:
|
| 545 |
+
|
| 546 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 547 |
+
if pixel_values is not None and self.vision is not None:
|
| 548 |
+
if isinstance(pixel_values, torch.Tensor):
|
| 549 |
+
if pixel_values.ndim > 2:
|
| 550 |
+
pixel_values = torch.concat(list(pixel_values))
|
| 551 |
+
elif isinstance(pixel_values, list):
|
| 552 |
+
return torch.concat(pixel_values)
|
| 553 |
+
else:
|
| 554 |
+
raise TypeError("""pixel_values must be a torch.Tensor
|
| 555 |
+
or a list of torch.Tensor
|
| 556 |
+
""")
|
| 557 |
+
return GLMImagePixelInputs(pixel_values=pixel_values)
|
| 558 |
+
|
| 559 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
| 560 |
+
image_input = self._parse_and_validate_image_input(**kwargs)
|
| 561 |
+
if image_input["pixel_values"] is None:
|
| 562 |
+
return None
|
| 563 |
+
pixel_values = image_input["pixel_values"].to(
|
| 564 |
+
dtype=self.config.torch_dtype)
|
| 565 |
+
vision_embeddings = self.vision(pixel_values)
|
| 566 |
+
return vision_embeddings
|
| 567 |
+
|
| 568 |
+
def get_input_embeddings(
|
| 569 |
+
self,
|
| 570 |
+
input_ids: torch.Tensor,
|
| 571 |
+
multimodal_embeddings: Optional[NestedTensors] = None,
|
| 572 |
+
) -> torch.Tensor:
|
| 573 |
+
inputs_embeds = self.embedding(input_ids)
|
| 574 |
+
if multimodal_embeddings is not None:
|
| 575 |
+
inputs_embeds = merge_glm_vision_embeddings(
|
| 576 |
+
input_ids=input_ids,
|
| 577 |
+
inputs_embeds=inputs_embeds,
|
| 578 |
+
vision_embeddings=multimodal_embeddings,
|
| 579 |
+
boi_token_id=self.config.boi_token_id,
|
| 580 |
+
eoi_token_id=self.config.eoi_token_id)
|
| 581 |
+
return inputs_embeds
|
| 582 |
+
|
| 583 |
+
def forward(
|
| 584 |
+
self,
|
| 585 |
+
input_ids: torch.Tensor,
|
| 586 |
+
positions: torch.Tensor,
|
| 587 |
+
kv_caches: List[torch.Tensor],
|
| 588 |
+
attn_metadata: AttentionMetadata,
|
| 589 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 590 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 591 |
+
**kwargs: object,
|
| 592 |
+
) -> torch.Tensor:
|
| 593 |
+
|
| 594 |
+
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
| 595 |
+
# condition is for v0 compatibility.
|
| 596 |
+
if intermediate_tensors is None and inputs_embeds is None:
|
| 597 |
+
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
| 598 |
+
inputs_embeds = self.get_input_embeddings(input_ids,
|
| 599 |
+
vision_embeddings)
|
| 600 |
+
input_ids = None
|
| 601 |
+
else:
|
| 602 |
+
inputs_embeds = intermediate_tensors["hidden_states"]
|
| 603 |
+
|
| 604 |
+
# Run encoder.
|
| 605 |
+
hidden_states = self.encoder(
|
| 606 |
+
hidden_states=inputs_embeds,
|
| 607 |
+
position_ids=positions,
|
| 608 |
+
kv_caches=kv_caches,
|
| 609 |
+
attn_metadata=attn_metadata,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
if not get_pp_group().is_last_rank:
|
| 613 |
+
return IntermediateTensors({"hidden_states": hidden_states})
|
| 614 |
+
return hidden_states
|
| 615 |
+
|
| 616 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 617 |
+
torch.Tensor]]) -> Set[str]:
|
| 618 |
+
stacked_params_mapping = [
|
| 619 |
+
# (param_name, shard_name, shard_id)
|
| 620 |
+
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
|
| 621 |
+
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
|
| 622 |
+
]
|
| 623 |
+
params_dict = dict(self.named_parameters())
|
| 624 |
+
loaded_params: Set[str] = set()
|
| 625 |
+
|
| 626 |
+
for name, loaded_weight in weights:
|
| 627 |
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
| 628 |
+
if weight_name not in name:
|
| 629 |
+
continue
|
| 630 |
+
name = name.replace(weight_name, param_name)
|
| 631 |
+
# Skip loading extra bias for GPTQ models.
|
| 632 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 633 |
+
continue
|
| 634 |
+
if is_pp_missing_parameter(name, self):
|
| 635 |
+
continue
|
| 636 |
+
param = params_dict[name]
|
| 637 |
+
weight_loader = param.weight_loader
|
| 638 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 639 |
+
break
|
| 640 |
+
else:
|
| 641 |
+
if "rotary_pos_emb.inv_freq" in name:
|
| 642 |
+
continue
|
| 643 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 644 |
+
continue
|
| 645 |
+
if is_pp_missing_parameter(name, self):
|
| 646 |
+
continue
|
| 647 |
+
param = params_dict[name]
|
| 648 |
+
weight_loader = getattr(param, "weight_loader",
|
| 649 |
+
default_weight_loader)
|
| 650 |
+
weight_loader(param, loaded_weight)
|
| 651 |
+
loaded_params.add(name)
|
| 652 |
+
return loaded_params
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
| 656 |
+
|
| 657 |
+
hf_to_vllm_mapper = WeightsMapper(
|
| 658 |
+
orig_to_new_substr={".word_embeddings": ""}, )
|
| 659 |
+
|
| 660 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 661 |
+
super().__init__()
|
| 662 |
+
config = vllm_config.model_config.hf_config
|
| 663 |
+
quant_config = vllm_config.quant_config
|
| 664 |
+
lora_config = vllm_config.lora_config
|
| 665 |
+
multimodal_config = vllm_config.model_config.multimodal_config
|
| 666 |
+
self.config = config
|
| 667 |
+
self.lora_config = lora_config
|
| 668 |
+
self.multimodal_config = multimodal_config
|
| 669 |
+
|
| 670 |
+
self.quant_config = quant_config
|
| 671 |
+
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
| 672 |
+
8192)
|
| 673 |
+
self.transformer = ChatGLMModel(vllm_config=vllm_config,
|
| 674 |
+
prefix=maybe_prefix(
|
| 675 |
+
prefix, "transformer"))
|
| 676 |
+
if self.config.tie_word_embeddings:
|
| 677 |
+
self.transformer.output_layer.weight = (
|
| 678 |
+
self.transformer.embedding.weight)
|
| 679 |
+
self.lm_head = self.transformer.output_layer
|
| 680 |
+
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
| 681 |
+
self.sampler = get_sampler()
|
| 682 |
+
|
| 683 |
+
def forward(self,
|
| 684 |
+
input_ids: torch.Tensor,
|
| 685 |
+
positions: torch.Tensor,
|
| 686 |
+
kv_caches: List[torch.Tensor],
|
| 687 |
+
attn_metadata: AttentionMetadata,
|
| 688 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 689 |
+
**kwargs) -> torch.Tensor:
|
| 690 |
+
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
| 691 |
+
attn_metadata, intermediate_tensors,
|
| 692 |
+
**kwargs)
|
| 693 |
+
return hidden_states
|
| 694 |
+
|
| 695 |
+
def compute_logits(
|
| 696 |
+
self,
|
| 697 |
+
hidden_states: torch.Tensor,
|
| 698 |
+
sampling_metadata: SamplingMetadata,
|
| 699 |
+
) -> Optional[torch.Tensor]:
|
| 700 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 701 |
+
sampling_metadata)
|
| 702 |
+
return logits
|
| 703 |
+
|
| 704 |
+
def sample(
|
| 705 |
+
self,
|
| 706 |
+
logits: torch.Tensor,
|
| 707 |
+
sampling_metadata: SamplingMetadata,
|
| 708 |
+
) -> Optional[SamplerOutput]:
|
| 709 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 710 |
+
return next_tokens
|
| 711 |
+
|
| 712 |
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
| 713 |
+
loader = AutoWeightsLoader(self)
|
| 714 |
+
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class ChatGLM(ChatGLMBaseModel):
|
| 718 |
+
packed_modules_mapping = {
|
| 719 |
+
"query_key_value": ["query_key_value"],
|
| 720 |
+
"dense_h_to_4h": ["dense_h_to_4h"]
|
| 721 |
+
}
|
| 722 |
+
# LoRA specific attributes
|
| 723 |
+
supported_lora_modules = [
|
| 724 |
+
"query_key_value",
|
| 725 |
+
"dense",
|
| 726 |
+
"dense_h_to_4h",
|
| 727 |
+
"dense_4h_to_h",
|
| 728 |
+
]
|
| 729 |
+
|
| 730 |
+
embedding_modules = {}
|
| 731 |
+
embedding_padding_modules = []
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
|
| 735 |
+
|
| 736 |
+
packed_modules_mapping = {
|
| 737 |
+
"query_key_value": ["query_key_value"],
|
| 738 |
+
"dense_h_to_4h": ["dense_h_to_4h"],
|
| 739 |
+
"merged_proj": ["gate_proj", "dense_h_to_4h"]
|
| 740 |
+
}
|
| 741 |
+
# LoRA specific attributes
|
| 742 |
+
supported_lora_modules = [
|
| 743 |
+
"query_key_value",
|
| 744 |
+
"dense",
|
| 745 |
+
"dense_h_to_4h",
|
| 746 |
+
"dense_4h_to_h",
|
| 747 |
+
# vision
|
| 748 |
+
"fc1",
|
| 749 |
+
"fc2",
|
| 750 |
+
"merged_proj",
|
| 751 |
+
"linear_proj"
|
| 752 |
+
]
|
| 753 |
+
|
| 754 |
+
embedding_modules = {}
|
| 755 |
+
embedding_padding_modules = []
|
| 756 |
+
|
| 757 |
+
def get_mm_mapping(self) -> MultiModelKeys:
|
| 758 |
+
"""
|
| 759 |
+
Get the module prefix in multimodal models
|
| 760 |
+
"""
|
| 761 |
+
return MultiModelKeys.from_string_field(
|
| 762 |
+
language_model="transformer.encoder",
|
| 763 |
+
connector="transformer.vision.linear_proj",
|
| 764 |
+
tower_model="transformer.vision.transformer")
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
|
| 768 |
+
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
| 769 |
+
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
| 770 |
+
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
| 771 |
+
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
| 772 |
+
SupportsMultiModal):
|
| 773 |
+
# Ensure that the LoRA support check passes when the class is not
|
| 774 |
+
# initialized, but set all these attributes to empty.
|
| 775 |
+
# These will be updated when an instance class is selected
|
| 776 |
+
packed_modules_mapping = {}
|
| 777 |
+
supported_lora_modules = []
|
| 778 |
+
embedding_modules = {}
|
| 779 |
+
embedding_padding_modules = []
|
| 780 |
+
|
| 781 |
+
def __new__(
|
| 782 |
+
cls,
|
| 783 |
+
vllm_config: VllmConfig,
|
| 784 |
+
prefix: str = "",
|
| 785 |
+
) -> None:
|
| 786 |
+
config = vllm_config.model_config.hf_config
|
| 787 |
+
|
| 788 |
+
# Initialize VL
|
| 789 |
+
if hasattr(config, "vision_config"): # noqa: SIM108
|
| 790 |
+
instance_cls = ChatGLMV
|
| 791 |
+
# Initialize LLM
|
| 792 |
+
else:
|
| 793 |
+
instance_cls = ChatGLM
|
| 794 |
+
|
| 795 |
+
# quant_config references base class members,
|
| 796 |
+
# so update values before init is called
|
| 797 |
+
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
|
| 798 |
+
cls.supported_lora_modules += instance_cls.supported_lora_modules
|
| 799 |
+
cls.embedding_modules.update(instance_cls.embedding_modules)
|
| 800 |
+
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
|
| 801 |
+
return instance_cls(vllm_config=vllm_config, prefix=prefix)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Inference-only Deepseek model."""
|
| 25 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers import PretrainedConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 33 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 34 |
+
get_tensor_model_parallel_world_size,
|
| 35 |
+
tensor_model_parallel_all_reduce)
|
| 36 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 37 |
+
from vllm.model_executor.layers.fused_moe import fused_moe
|
| 38 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 39 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 40 |
+
QKVParallelLinear,
|
| 41 |
+
ReplicatedLinear,
|
| 42 |
+
RowParallelLinear)
|
| 43 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 44 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 45 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 46 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 47 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 48 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 49 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 50 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 51 |
+
from vllm.sequence import IntermediateTensors
|
| 52 |
+
|
| 53 |
+
from .interfaces import SupportsPP
|
| 54 |
+
from .utils import (extract_layer_index, is_pp_missing_parameter,
|
| 55 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 56 |
+
maybe_prefix)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class DeepseekMLP(nn.Module):
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
hidden_size: int,
|
| 64 |
+
intermediate_size: int,
|
| 65 |
+
hidden_act: str,
|
| 66 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 67 |
+
reduce_results: bool = True,
|
| 68 |
+
prefix: str = "",
|
| 69 |
+
) -> None:
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 72 |
+
hidden_size, [intermediate_size] * 2,
|
| 73 |
+
bias=False,
|
| 74 |
+
quant_config=quant_config)
|
| 75 |
+
self.down_proj = RowParallelLinear(intermediate_size,
|
| 76 |
+
hidden_size,
|
| 77 |
+
bias=False,
|
| 78 |
+
quant_config=quant_config,
|
| 79 |
+
reduce_results=reduce_results)
|
| 80 |
+
if hidden_act != "silu":
|
| 81 |
+
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
| 82 |
+
"Only silu is supported for now.")
|
| 83 |
+
self.act_fn = SiluAndMul()
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
gate_up, _ = self.gate_up_proj(x)
|
| 87 |
+
x = self.act_fn(gate_up)
|
| 88 |
+
x, _ = self.down_proj(x)
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class DeepseekMoE(nn.Module):
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
config: PretrainedConfig,
|
| 97 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 98 |
+
prefix: str = "",
|
| 99 |
+
):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.config = config
|
| 102 |
+
self.rank = get_tensor_model_parallel_rank()
|
| 103 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 104 |
+
self.n_routed_experts = config.n_routed_experts
|
| 105 |
+
self.top_k = config.num_experts_per_tok
|
| 106 |
+
if self.tp_size > self.n_routed_experts:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"Tensor parallel size {self.tp_size} is greater than "
|
| 109 |
+
f"the number of experts {self.n_routed_experts}.")
|
| 110 |
+
|
| 111 |
+
self.experts = nn.ModuleList([
|
| 112 |
+
DeepseekMLP(hidden_size=config.hidden_size,
|
| 113 |
+
intermediate_size=config.moe_intermediate_size,
|
| 114 |
+
hidden_act=config.hidden_act,
|
| 115 |
+
quant_config=quant_config,
|
| 116 |
+
reduce_results=False)
|
| 117 |
+
for idx in range(self.n_routed_experts)
|
| 118 |
+
])
|
| 119 |
+
self.pack_params()
|
| 120 |
+
|
| 121 |
+
self.gate = ReplicatedLinear(config.hidden_size,
|
| 122 |
+
self.n_routed_experts,
|
| 123 |
+
bias=False,
|
| 124 |
+
quant_config=None)
|
| 125 |
+
|
| 126 |
+
if config.n_shared_experts is not None:
|
| 127 |
+
intermediate_size = (config.moe_intermediate_size *
|
| 128 |
+
config.n_shared_experts)
|
| 129 |
+
self.shared_experts = DeepseekMLP(
|
| 130 |
+
hidden_size=config.hidden_size,
|
| 131 |
+
intermediate_size=intermediate_size,
|
| 132 |
+
hidden_act=config.hidden_act,
|
| 133 |
+
quant_config=quant_config,
|
| 134 |
+
reduce_results=False,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def pack_params(self):
|
| 138 |
+
w1 = []
|
| 139 |
+
w2 = []
|
| 140 |
+
for expert in self.experts:
|
| 141 |
+
w1.append(expert.gate_up_proj.weight)
|
| 142 |
+
w2.append(expert.down_proj.weight)
|
| 143 |
+
self.w1 = torch._utils._flatten_dense_tensors(w1)
|
| 144 |
+
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
|
| 145 |
+
for data, param in zip(w1s, w1):
|
| 146 |
+
param.data = data
|
| 147 |
+
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
|
| 148 |
+
|
| 149 |
+
self.w2 = torch._utils._flatten_dense_tensors(w2)
|
| 150 |
+
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
|
| 151 |
+
for data, param in zip(w2s, w2):
|
| 152 |
+
param.data = data
|
| 153 |
+
|
| 154 |
+
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
|
| 155 |
+
|
| 156 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
num_tokens, hidden_dim = hidden_states.shape
|
| 158 |
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 159 |
+
if self.config.n_shared_experts is not None:
|
| 160 |
+
shared_output = self.shared_experts(hidden_states)
|
| 161 |
+
# router_logits: (num_tokens, n_experts)
|
| 162 |
+
router_logits, _ = self.gate(hidden_states)
|
| 163 |
+
final_hidden_states = fused_moe(hidden_states,
|
| 164 |
+
self.w1,
|
| 165 |
+
self.w2,
|
| 166 |
+
router_logits,
|
| 167 |
+
self.top_k,
|
| 168 |
+
renormalize=self.config.norm_topk_prob,
|
| 169 |
+
inplace=True)
|
| 170 |
+
|
| 171 |
+
if self.config.n_shared_experts is not None:
|
| 172 |
+
final_hidden_states = final_hidden_states + shared_output
|
| 173 |
+
final_hidden_states = tensor_model_parallel_all_reduce(
|
| 174 |
+
final_hidden_states)
|
| 175 |
+
|
| 176 |
+
return final_hidden_states.view(num_tokens, hidden_dim)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class DeepseekAttention(nn.Module):
|
| 180 |
+
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
hidden_size: int,
|
| 184 |
+
num_heads: int,
|
| 185 |
+
num_kv_heads: int,
|
| 186 |
+
rope_theta: float = 10000,
|
| 187 |
+
rope_scaling: Optional[Dict[str, Any]] = None,
|
| 188 |
+
max_position_embeddings: int = 8192,
|
| 189 |
+
cache_config: Optional[CacheConfig] = None,
|
| 190 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 191 |
+
prefix: str = "",
|
| 192 |
+
) -> None:
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.hidden_size = hidden_size
|
| 195 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 196 |
+
self.total_num_heads = num_heads
|
| 197 |
+
assert self.total_num_heads % tp_size == 0
|
| 198 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 199 |
+
self.total_num_kv_heads = num_kv_heads
|
| 200 |
+
if self.total_num_kv_heads >= tp_size:
|
| 201 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 202 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 203 |
+
assert self.total_num_kv_heads % tp_size == 0
|
| 204 |
+
else:
|
| 205 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 206 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 207 |
+
assert tp_size % self.total_num_kv_heads == 0
|
| 208 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 209 |
+
self.head_dim = hidden_size // self.total_num_heads
|
| 210 |
+
self.q_size = self.num_heads * self.head_dim
|
| 211 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 212 |
+
self.scaling = self.head_dim**-0.5
|
| 213 |
+
self.rope_theta = rope_theta
|
| 214 |
+
self.max_position_embeddings = max_position_embeddings
|
| 215 |
+
|
| 216 |
+
self.qkv_proj = QKVParallelLinear(
|
| 217 |
+
hidden_size,
|
| 218 |
+
self.head_dim,
|
| 219 |
+
self.total_num_heads,
|
| 220 |
+
self.total_num_kv_heads,
|
| 221 |
+
bias=False,
|
| 222 |
+
quant_config=quant_config,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self.o_proj = RowParallelLinear(
|
| 226 |
+
self.total_num_heads * self.head_dim,
|
| 227 |
+
hidden_size,
|
| 228 |
+
bias=False,
|
| 229 |
+
quant_config=quant_config,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self.rotary_emb = get_rope(
|
| 233 |
+
self.head_dim,
|
| 234 |
+
rotary_dim=self.head_dim,
|
| 235 |
+
max_position=max_position_embeddings,
|
| 236 |
+
base=rope_theta,
|
| 237 |
+
rope_scaling=rope_scaling,
|
| 238 |
+
)
|
| 239 |
+
self.attn = Attention(self.num_heads,
|
| 240 |
+
self.head_dim,
|
| 241 |
+
self.scaling,
|
| 242 |
+
num_kv_heads=self.num_kv_heads,
|
| 243 |
+
cache_config=cache_config,
|
| 244 |
+
quant_config=quant_config,
|
| 245 |
+
prefix=f"{prefix}.attn")
|
| 246 |
+
|
| 247 |
+
def forward(
|
| 248 |
+
self,
|
| 249 |
+
positions: torch.Tensor,
|
| 250 |
+
hidden_states: torch.Tensor,
|
| 251 |
+
kv_cache: torch.Tensor,
|
| 252 |
+
attn_metadata: AttentionMetadata,
|
| 253 |
+
) -> torch.Tensor:
|
| 254 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
| 255 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 256 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 257 |
+
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
| 258 |
+
output, _ = self.o_proj(attn_output)
|
| 259 |
+
return output
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class DeepseekDecoderLayer(nn.Module):
|
| 263 |
+
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
config: PretrainedConfig,
|
| 267 |
+
cache_config: Optional[CacheConfig] = None,
|
| 268 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 269 |
+
prefix: str = "",
|
| 270 |
+
) -> None:
|
| 271 |
+
super().__init__()
|
| 272 |
+
layer_idx = extract_layer_index(prefix)
|
| 273 |
+
self.hidden_size = config.hidden_size
|
| 274 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 275 |
+
rope_scaling = getattr(config, "rope_scaling", None)
|
| 276 |
+
max_position_embeddings = getattr(config, "max_position_embeddings",
|
| 277 |
+
8192)
|
| 278 |
+
self.self_attn = DeepseekAttention(
|
| 279 |
+
hidden_size=self.hidden_size,
|
| 280 |
+
num_heads=config.num_attention_heads,
|
| 281 |
+
num_kv_heads=config.num_key_value_heads,
|
| 282 |
+
rope_theta=rope_theta,
|
| 283 |
+
rope_scaling=rope_scaling,
|
| 284 |
+
max_position_embeddings=max_position_embeddings,
|
| 285 |
+
cache_config=cache_config,
|
| 286 |
+
quant_config=quant_config,
|
| 287 |
+
prefix=f"{prefix}.self_attn",
|
| 288 |
+
)
|
| 289 |
+
if (config.n_routed_experts is not None
|
| 290 |
+
and layer_idx >= config.first_k_dense_replace
|
| 291 |
+
and layer_idx % config.moe_layer_freq == 0):
|
| 292 |
+
self.mlp = DeepseekMoE(config=config,
|
| 293 |
+
quant_config=quant_config,
|
| 294 |
+
prefix=f"{prefix}.mlp")
|
| 295 |
+
else:
|
| 296 |
+
self.mlp = DeepseekMLP(
|
| 297 |
+
hidden_size=config.hidden_size,
|
| 298 |
+
intermediate_size=config.intermediate_size,
|
| 299 |
+
hidden_act=config.hidden_act,
|
| 300 |
+
quant_config=quant_config,
|
| 301 |
+
prefix=f"{prefix}.mlp",
|
| 302 |
+
)
|
| 303 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 304 |
+
eps=config.rms_norm_eps)
|
| 305 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 306 |
+
eps=config.rms_norm_eps)
|
| 307 |
+
|
| 308 |
+
def forward(
|
| 309 |
+
self,
|
| 310 |
+
positions: torch.Tensor,
|
| 311 |
+
hidden_states: torch.Tensor,
|
| 312 |
+
kv_cache: torch.Tensor,
|
| 313 |
+
attn_metadata: AttentionMetadata,
|
| 314 |
+
residual: Optional[torch.Tensor],
|
| 315 |
+
) -> torch.Tensor:
|
| 316 |
+
# Self Attention
|
| 317 |
+
if residual is None:
|
| 318 |
+
residual = hidden_states
|
| 319 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 320 |
+
else:
|
| 321 |
+
hidden_states, residual = self.input_layernorm(
|
| 322 |
+
hidden_states, residual)
|
| 323 |
+
hidden_states = self.self_attn(
|
| 324 |
+
positions=positions,
|
| 325 |
+
hidden_states=hidden_states,
|
| 326 |
+
kv_cache=kv_cache,
|
| 327 |
+
attn_metadata=attn_metadata,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Fully Connected
|
| 331 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 332 |
+
hidden_states, residual)
|
| 333 |
+
hidden_states = self.mlp(hidden_states)
|
| 334 |
+
return hidden_states, residual
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class DeepseekModel(nn.Module):
|
| 338 |
+
|
| 339 |
+
fall_back_to_pt_during_load = False
|
| 340 |
+
|
| 341 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 342 |
+
super().__init__()
|
| 343 |
+
|
| 344 |
+
config = vllm_config.model_config.hf_config
|
| 345 |
+
cache_config = vllm_config.cache_config
|
| 346 |
+
quant_config = vllm_config.quant_config
|
| 347 |
+
|
| 348 |
+
self.padding_idx = config.pad_token_id
|
| 349 |
+
self.vocab_size = config.vocab_size
|
| 350 |
+
|
| 351 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 352 |
+
config.vocab_size,
|
| 353 |
+
config.hidden_size,
|
| 354 |
+
)
|
| 355 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 356 |
+
config.num_hidden_layers,
|
| 357 |
+
lambda prefix: DeepseekDecoderLayer(
|
| 358 |
+
config, cache_config, quant_config=quant_config, prefix=prefix
|
| 359 |
+
),
|
| 360 |
+
prefix=f"{prefix}.layers")
|
| 361 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 362 |
+
self.make_empty_intermediate_tensors = (
|
| 363 |
+
make_empty_intermediate_tensors_factory(
|
| 364 |
+
["hidden_states", "residual"], config.hidden_size))
|
| 365 |
+
|
| 366 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 367 |
+
return self.embed_tokens(input_ids)
|
| 368 |
+
|
| 369 |
+
def forward(
|
| 370 |
+
self,
|
| 371 |
+
input_ids: torch.Tensor,
|
| 372 |
+
positions: torch.Tensor,
|
| 373 |
+
kv_caches: List[torch.Tensor],
|
| 374 |
+
attn_metadata: AttentionMetadata,
|
| 375 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 376 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 377 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 378 |
+
if get_pp_group().is_first_rank:
|
| 379 |
+
if inputs_embeds is not None:
|
| 380 |
+
hidden_states = inputs_embeds
|
| 381 |
+
else:
|
| 382 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 383 |
+
residual = None
|
| 384 |
+
else:
|
| 385 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 386 |
+
residual = intermediate_tensors["residual"]
|
| 387 |
+
for i in range(self.start_layer, self.end_layer):
|
| 388 |
+
layer = self.layers[i]
|
| 389 |
+
hidden_states, residual = layer(positions, hidden_states,
|
| 390 |
+
kv_caches[i - self.start_layer],
|
| 391 |
+
attn_metadata, residual)
|
| 392 |
+
if not get_pp_group().is_last_rank:
|
| 393 |
+
return IntermediateTensors({
|
| 394 |
+
"hidden_states": hidden_states,
|
| 395 |
+
"residual": residual
|
| 396 |
+
})
|
| 397 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 398 |
+
return hidden_states
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class DeepseekForCausalLM(nn.Module, SupportsPP):
|
| 402 |
+
|
| 403 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 404 |
+
super().__init__()
|
| 405 |
+
config = vllm_config.model_config.hf_config
|
| 406 |
+
quant_config = vllm_config.quant_config
|
| 407 |
+
self.config = config
|
| 408 |
+
self.quant_config = quant_config
|
| 409 |
+
self.model = DeepseekModel(vllm_config=vllm_config,
|
| 410 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 411 |
+
self.lm_head = ParallelLMHead(config.vocab_size,
|
| 412 |
+
config.hidden_size,
|
| 413 |
+
quant_config=quant_config)
|
| 414 |
+
if self.config.tie_word_embeddings:
|
| 415 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 416 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 417 |
+
self.sampler = get_sampler()
|
| 418 |
+
self.make_empty_intermediate_tensors = (
|
| 419 |
+
self.model.make_empty_intermediate_tensors)
|
| 420 |
+
|
| 421 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 422 |
+
return self.model.get_input_embeddings(input_ids)
|
| 423 |
+
|
| 424 |
+
def forward(
|
| 425 |
+
self,
|
| 426 |
+
input_ids: torch.Tensor,
|
| 427 |
+
positions: torch.Tensor,
|
| 428 |
+
kv_caches: List[torch.Tensor],
|
| 429 |
+
attn_metadata: AttentionMetadata,
|
| 430 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 431 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 432 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 433 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 434 |
+
attn_metadata, intermediate_tensors,
|
| 435 |
+
inputs_embeds)
|
| 436 |
+
return hidden_states
|
| 437 |
+
|
| 438 |
+
def compute_logits(
|
| 439 |
+
self,
|
| 440 |
+
hidden_states: torch.Tensor,
|
| 441 |
+
sampling_metadata: SamplingMetadata,
|
| 442 |
+
) -> Optional[torch.Tensor]:
|
| 443 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 444 |
+
sampling_metadata)
|
| 445 |
+
return logits
|
| 446 |
+
|
| 447 |
+
def sample(
|
| 448 |
+
self,
|
| 449 |
+
logits: Optional[torch.Tensor],
|
| 450 |
+
sampling_metadata: SamplingMetadata,
|
| 451 |
+
) -> Optional[SamplerOutput]:
|
| 452 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 453 |
+
return next_tokens
|
| 454 |
+
|
| 455 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 456 |
+
torch.Tensor]]) -> Set[str]:
|
| 457 |
+
stacked_params_mapping = [
|
| 458 |
+
# (param_name, shard_name, shard_id)
|
| 459 |
+
("qkv_proj", "q_proj", "q"),
|
| 460 |
+
("qkv_proj", "k_proj", "k"),
|
| 461 |
+
("qkv_proj", "v_proj", "v"),
|
| 462 |
+
("gate_up_proj", "gate_proj", 0),
|
| 463 |
+
("gate_up_proj", "up_proj", 1),
|
| 464 |
+
]
|
| 465 |
+
|
| 466 |
+
params_dict = dict(self.named_parameters())
|
| 467 |
+
loaded_params: Set[str] = set()
|
| 468 |
+
for name, loaded_weight in weights:
|
| 469 |
+
if "rotary_emb.inv_freq" in name:
|
| 470 |
+
continue
|
| 471 |
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
| 472 |
+
if weight_name not in name:
|
| 473 |
+
continue
|
| 474 |
+
name = name.replace(weight_name, param_name)
|
| 475 |
+
# Skip loading extra bias for GPTQ models.
|
| 476 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 477 |
+
continue
|
| 478 |
+
# Skip experts that are not assigned to this worker.
|
| 479 |
+
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
| 480 |
+
and name not in params_dict):
|
| 481 |
+
continue
|
| 482 |
+
if is_pp_missing_parameter(name, self):
|
| 483 |
+
continue
|
| 484 |
+
param = params_dict[name]
|
| 485 |
+
weight_loader = param.weight_loader
|
| 486 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 487 |
+
break
|
| 488 |
+
else:
|
| 489 |
+
# Skip loading extra bias for GPTQ models.
|
| 490 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 491 |
+
continue
|
| 492 |
+
# Skip experts that are not assigned to this worker.
|
| 493 |
+
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
| 494 |
+
and name not in params_dict):
|
| 495 |
+
continue
|
| 496 |
+
if is_pp_missing_parameter(name, self):
|
| 497 |
+
continue
|
| 498 |
+
param = params_dict[name]
|
| 499 |
+
weight_loader = getattr(param, "weight_loader",
|
| 500 |
+
default_weight_loader)
|
| 501 |
+
weight_loader(param, loaded_weight)
|
| 502 |
+
loaded_params.add(name)
|
| 503 |
+
return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/eagle.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Iterable, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from vllm.attention.backends.abstract import AttentionMetadata
|
| 9 |
+
from vllm.config import VllmConfig
|
| 10 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 11 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 12 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 13 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
| 14 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 15 |
+
from vllm.model_executor.models import ModelRegistry
|
| 16 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 17 |
+
from vllm.sequence import IntermediateTensors
|
| 18 |
+
|
| 19 |
+
from .utils import maybe_prefix
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DummyInputLayerNorm(nn.Module):
|
| 23 |
+
|
| 24 |
+
def __init__(self, weight=None, bias=None):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.weight = nn.Parameter(weight) if weight is not None else None
|
| 27 |
+
self.bias = nn.Parameter(bias) if bias is not None else None
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
return x
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DummyOutputNorm(nn.Module):
|
| 34 |
+
|
| 35 |
+
def forward(self, x, residual):
|
| 36 |
+
if residual is None:
|
| 37 |
+
return x
|
| 38 |
+
else:
|
| 39 |
+
return x, residual
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class EAGLE(nn.Module):
|
| 43 |
+
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
| 44 |
+
Reference implementation: https://github.com/SafeAILab/EAGLE
|
| 45 |
+
|
| 46 |
+
Differences from reference implementation:
|
| 47 |
+
1. In reference, LlamaDecoderLayer implementation doesn't have
|
| 48 |
+
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
|
| 49 |
+
Following this approach, our implementation also disables
|
| 50 |
+
the input_layernorm for the first decoder layer.
|
| 51 |
+
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
| 52 |
+
decoder layer is fixed to be LlamaDecoderLayer.
|
| 53 |
+
3. We have an optional token_map which reduces draft vocab to most
|
| 54 |
+
frequently used tokens to give some additional speed-up by reducing
|
| 55 |
+
sampling overhead. This is disabled unless the checkpoint file has
|
| 56 |
+
explicit token_map tensor and config has an optional attribute
|
| 57 |
+
truncated_vocab_size < vocab_size. To use this technique, one has to find
|
| 58 |
+
the top-k most frequent tokens in target dataset and add that as a tensor
|
| 59 |
+
in the draft checkpoint (using key token_map). Also, the draft config
|
| 60 |
+
needs to have truncated_vocab_size (=k) as an attribute."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 63 |
+
super().__init__()
|
| 64 |
+
config = vllm_config.model_config.hf_config
|
| 65 |
+
self.config = config
|
| 66 |
+
|
| 67 |
+
architectures = getattr(self.config.model, "architectures", [])
|
| 68 |
+
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
|
| 69 |
+
|
| 70 |
+
self.model = model_cls(vllm_config=vllm_config,
|
| 71 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 72 |
+
|
| 73 |
+
self.fc = nn.Linear(config.model.hidden_size * 2,
|
| 74 |
+
config.model.hidden_size,
|
| 75 |
+
bias=getattr(self.config, "eagle_fc_bias", False))
|
| 76 |
+
|
| 77 |
+
# Modify layer normalization and residual connections as suggested
|
| 78 |
+
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
|
| 79 |
+
# While weights and biases are generally not needed,
|
| 80 |
+
# they are retained here to support certain unit tests
|
| 81 |
+
# (e.g., spec_decode/e2e/test_eagle_correctness.py).
|
| 82 |
+
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
|
| 83 |
+
weight=self.model.model.layers[0].input_layernorm.weight)
|
| 84 |
+
self.model.model.norm = DummyOutputNorm()
|
| 85 |
+
|
| 86 |
+
self.orig_vocab_size = config.vocab_size
|
| 87 |
+
self.truncated_vocab_size = config.truncated_vocab_size
|
| 88 |
+
self.unpadded_vocab_size = self.truncated_vocab_size
|
| 89 |
+
|
| 90 |
+
self.lm_head = ParallelLMHead(
|
| 91 |
+
self.unpadded_vocab_size,
|
| 92 |
+
config.hidden_size,
|
| 93 |
+
org_num_embeddings=self.truncated_vocab_size,
|
| 94 |
+
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
logit_scale = getattr(config, "logit_scale", 1.0)
|
| 98 |
+
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
| 99 |
+
self.truncated_vocab_size,
|
| 100 |
+
logit_scale)
|
| 101 |
+
|
| 102 |
+
# Token map is a idx to token mapping to reduce the vocab size for
|
| 103 |
+
# the draft model. Using smaller vocab size for draft, containing
|
| 104 |
+
# only most frequent tokens reduces the speculation overhead. This
|
| 105 |
+
# doesn't affect the acceptance rate much and thus gives more speed
|
| 106 |
+
# -up. By default, this is disabled and is only used if the EAGLE
|
| 107 |
+
# checkpoint file has token_map tensor.
|
| 108 |
+
self.token_map = None
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def sampler(self):
|
| 112 |
+
return self.model.sampler
|
| 113 |
+
|
| 114 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
return self.model.model.get_input_embeddings(input_ids)
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
input_ids: torch.Tensor,
|
| 120 |
+
positions: torch.Tensor,
|
| 121 |
+
kv_caches: List[torch.Tensor],
|
| 122 |
+
attn_metadata: AttentionMetadata,
|
| 123 |
+
previous_hidden_states: torch.Tensor,
|
| 124 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 125 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 126 |
+
) -> torch.Tensor:
|
| 127 |
+
|
| 128 |
+
if inputs_embeds is None:
|
| 129 |
+
inputs_embeds = self.get_input_embeddings(input_ids)
|
| 130 |
+
|
| 131 |
+
inputs_embeds = self.fc(
|
| 132 |
+
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
| 133 |
+
|
| 134 |
+
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
|
| 135 |
+
|
| 136 |
+
hidden_states = self.model.model(
|
| 137 |
+
input_ids=None,
|
| 138 |
+
inputs_embeds=inputs_embeds,
|
| 139 |
+
positions=positions,
|
| 140 |
+
kv_caches=kv_caches,
|
| 141 |
+
attn_metadata=attn_metadata,
|
| 142 |
+
intermediate_tensors=intermediate_tensors,
|
| 143 |
+
)
|
| 144 |
+
return hidden_states
|
| 145 |
+
|
| 146 |
+
def compute_logits(self, hidden_states: torch.Tensor,
|
| 147 |
+
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
| 148 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 149 |
+
sampling_metadata)
|
| 150 |
+
|
| 151 |
+
if self.token_map is not None:
|
| 152 |
+
_logits = logits
|
| 153 |
+
logits = -torch.inf * torch.ones(
|
| 154 |
+
size=(*_logits.shape[:-1], self.orig_vocab_size),
|
| 155 |
+
device=_logits.device,
|
| 156 |
+
dtype=_logits.dtype)
|
| 157 |
+
|
| 158 |
+
logits[..., self.token_map] = _logits
|
| 159 |
+
|
| 160 |
+
return logits
|
| 161 |
+
|
| 162 |
+
def sample(
|
| 163 |
+
self,
|
| 164 |
+
logits: torch.Tensor,
|
| 165 |
+
sampling_metadata: SamplingMetadata,
|
| 166 |
+
) -> Optional[SamplerOutput]:
|
| 167 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 168 |
+
return next_tokens
|
| 169 |
+
|
| 170 |
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
| 171 |
+
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
|
| 172 |
+
# due to missing lm_head weights and its config being that of a
|
| 173 |
+
# Llama model. Here's a compatible version with the same weights:
|
| 174 |
+
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
|
| 175 |
+
# Also, here's an example script for converting trained EAGLE
|
| 176 |
+
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
|
| 177 |
+
model_weights = {}
|
| 178 |
+
for name, loaded_weight in weights:
|
| 179 |
+
if name == "token_map":
|
| 180 |
+
if self.config.truncated_vocab_size < self.config.vocab_size:
|
| 181 |
+
self.token_map = nn.Parameter(loaded_weight,
|
| 182 |
+
requires_grad=False)
|
| 183 |
+
elif name.startswith("fc.weight"):
|
| 184 |
+
weight_loader = getattr(self.fc.weight, "weight_loader",
|
| 185 |
+
default_weight_loader)
|
| 186 |
+
weight_loader(self.fc.weight, loaded_weight)
|
| 187 |
+
elif name.startswith("fc.bias"):
|
| 188 |
+
if self.fc.bias is not None:
|
| 189 |
+
weight_loader = getattr(self.fc.bias, "weight_loader",
|
| 190 |
+
default_weight_loader)
|
| 191 |
+
weight_loader(self.fc.bias, loaded_weight)
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError("Found bias in the loaded weights "
|
| 194 |
+
"but the model config doesn't have bias")
|
| 195 |
+
elif name.startswith("model.lm_head.") or name.startswith(
|
| 196 |
+
"model.model."):
|
| 197 |
+
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
| 198 |
+
elif name.startswith("lm_head.") or name.startswith("model."):
|
| 199 |
+
model_weights[name] = loaded_weight
|
| 200 |
+
else:
|
| 201 |
+
model_weights[f"model.{name}"] = loaded_weight
|
| 202 |
+
|
| 203 |
+
lm_head_weight = model_weights.pop("lm_head.weight")
|
| 204 |
+
|
| 205 |
+
if self.token_map is not None and\
|
| 206 |
+
lm_head_weight.shape[0] > self.token_map.shape[0]:
|
| 207 |
+
|
| 208 |
+
lm_head_weight = lm_head_weight[self.token_map]
|
| 209 |
+
|
| 210 |
+
weight_loader = getattr(self.lm_head.weight, "weight_loader",
|
| 211 |
+
default_weight_loader)
|
| 212 |
+
weight_loader(self.lm_head.weight, lm_head_weight)
|
| 213 |
+
|
| 214 |
+
self.model.load_weights(model_weights.items())
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/falcon.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
|
| 7 |
+
# reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""PyTorch Falcon model."""
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import nn
|
| 27 |
+
from torch.nn import LayerNorm
|
| 28 |
+
from transformers import FalconConfig as HF_FalconConfig
|
| 29 |
+
|
| 30 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 31 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 32 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 33 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 34 |
+
get_tensor_model_parallel_world_size,
|
| 35 |
+
tensor_model_parallel_all_reduce)
|
| 36 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 37 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 38 |
+
QKVParallelLinear,
|
| 39 |
+
RowParallelLinear)
|
| 40 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 41 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 42 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 43 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 44 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 45 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 46 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 47 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 48 |
+
from vllm.sequence import IntermediateTensors
|
| 49 |
+
from vllm.transformers_utils.configs import RWConfig
|
| 50 |
+
|
| 51 |
+
from .interfaces import SupportsPP
|
| 52 |
+
from .utils import (is_pp_missing_parameter,
|
| 53 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 54 |
+
maybe_prefix)
|
| 55 |
+
|
| 56 |
+
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
| 60 |
+
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
| 61 |
+
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
| 62 |
+
dtype=torch.float32)
|
| 63 |
+
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
| 64 |
+
slopes = torch.pow(base, powers)
|
| 65 |
+
|
| 66 |
+
if closest_power_of_2 != total_num_heads:
|
| 67 |
+
extra_base = torch.tensor(
|
| 68 |
+
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
| 69 |
+
dtype=torch.float32)
|
| 70 |
+
num_remaining_heads = min(closest_power_of_2,
|
| 71 |
+
total_num_heads - closest_power_of_2)
|
| 72 |
+
extra_powers = torch.arange(1,
|
| 73 |
+
1 + 2 * num_remaining_heads,
|
| 74 |
+
2,
|
| 75 |
+
dtype=torch.int32)
|
| 76 |
+
slopes = torch.cat(
|
| 77 |
+
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 78 |
+
|
| 79 |
+
return slopes
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class FalconAttention(nn.Module):
    """Tensor-parallel multi-head attention block for Falcon.

    Depending on the HF config this behaves as classic MHA, multi-query
    attention (1 KV head), or grouped-query attention (new decoder
    architecture), and uses either rotary embeddings, ALiBi slopes, or
    neither (the three options are mutually exclusive per the config).
    """

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always sharded evenly across TP ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        # KV head count selects the attention variant:
        # GQA (new arch), MQA (multi_query), or full MHA otherwise.
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Fused QKV projection; bias is deferred (skip_bias_add) so forward
        # can add it explicitly after the matmul.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        # When attention and MLP run in parallel, the all-reduce is deferred
        # to the decoder layer so both results share one reduction.
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            # NOTE(review): unlike the final else-branch, this Attention is
            # built without cache_config — confirm this is intentional.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        elif self.use_alibi:
            # Each TP rank only keeps the slopes for its own head slice.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            # Pre-multiply slopes by the attention scale so the kernel can
            # apply the bias without extra scaling.
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            # NOTE(review): cache_config is also omitted here — verify.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention; returns (output, deferred dense bias).

        The dense bias is returned un-added (skip_bias_add) so the decoder
        layer can fold it in after its single shared all-reduce.
        """
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias
|
| 207 |
+
class FalconMLP(nn.Module):
    """Falcon feed-forward block: hidden -> 4*hidden -> GELU -> hidden.

    Bias addition on the output projection is deferred (skip_bias_add) and
    returned to the caller, so the decoder layer can add it after its
    single shared all-reduce when attention and MLP run in parallel.
    """

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        embed_dim = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(embed_dim,
                                                  4 * embed_dim,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  quant_config=quant_config)
        self.act = get_act_fn("gelu")
        # Defer the all-reduce to the decoder layer when attention and MLP
        # are computed in parallel (parallel_attn / new decoder arch).
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * embed_dim,
            embed_dim,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        hidden, up_bias = self.dense_h_to_4h(x)
        if up_bias is not None:
            hidden += up_bias
        activated = self.act(hidden)
        out, down_bias = self.dense_4h_to_h(activated)
        return out, down_bias
| 242 |
+
|
| 243 |
+
class FalconDecoderLayer(nn.Module):
    """One Falcon transformer layer.

    Supports three layouts selected by the config: sequential
    attention-then-MLP, parallel attention+MLP with a single shared
    layernorm, and the new decoder architecture with two layernorms
    (one each for attention and MLP).
    """

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attention")
        self.mlp = FalconMLP(config, quant_config)
        self.config = config

        # Older Falcon configs predate this field; normalize it so the
        # forward pass can branch on it uniformly.
        # NOTE: this mutates the shared config object in place.
        if (not hasattr(config, "num_ln_in_parallel_attn")):
            config.num_ln_in_parallel_attn = None

        if (config.num_ln_in_parallel_attn is None
                and config.new_decoder_architecture):
            config.num_ln_in_parallel_attn = 2

        if not config.parallel_attn:
            self.post_attention_layernorm = LayerNorm(
                hidden_size, eps=config.layer_norm_epsilon)
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
        else:
            if config.num_ln_in_parallel_attn == 2:
                # The layer norm before self-attention
                self.ln_attn = LayerNorm(hidden_size,
                                         eps=config.layer_norm_epsilon)
                # The layer norm before the MLP
                self.ln_mlp = LayerNorm(hidden_size,
                                        eps=config.layer_norm_epsilon)
            else:
                self.input_layernorm = LayerNorm(hidden_size,
                                                 eps=config.layer_norm_epsilon)

        # True only for the sequential layout; in the parallel layouts the
        # all-reduce is performed once here in forward() for both branches.
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.num_ln_in_parallel_attn == 2:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Sequential layout: bias is added now; parallel layouts defer it
        # until after the shared all-reduce below.
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        if (self.config.new_decoder_architecture and self.config.parallel_attn
                and self.config.num_ln_in_parallel_attn == 1):
            mlp_layernorm_out = attention_layernorm_out

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            # Biases are rank-local, so add them after the all-reduce.
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output
+
|
| 346 |
+
@support_torch_compile
class FalconModel(nn.Module):
    """Falcon transformer stack: embeddings, decoder layers, final norm.

    Pipeline-parallel aware: only this rank's layer slice is materialized
    (via make_layers), and non-first/non-last ranks exchange hidden states
    through IntermediateTensors.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: FalconDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # First PP rank embeds tokens (unless embeddings are precomputed);
        # later ranks consume the previous rank's hidden states.
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches is indexed locally: one entry per layer on this rank.
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
| 412 |
+
|
| 413 |
+
class FalconForCausalLM(nn.Module, SupportsPP):
    """Falcon causal-LM head on top of FalconModel, with checkpoint
    loading that re-packs HF's interleaved fused-QKV layout into vLLM's
    [Q..., K..., V...] layout."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = FalconModel(vllm_config=vllm_config,
                                       prefix=maybe_prefix(
                                           prefix, "transformer"))
        # only Falcon-11B doesn't share lm_head weight with word embeddings
        # and previous Falcon model doesn't have tie_word_embeddings config
        # so we set tie_word_embeddings to True by default
        self.tie_word_embeddings = (config.tie_word_embeddings
                                    if config.tie_word_embeddings is not None
                                    else True)
        if self.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the set of loaded param names.

        HF Falcon stores fused QKV grouped per KV head as
        [q_0..q_{g-1}, k, v] repeated for each KV head; the block below
        reorders it to the contiguous [all Q, all K, all V] layout that
        QKVParallelLinear expects.
        """
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        # remove_duplicate=False keeps tied lm_head/word_embeddings entries.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name == "lm_head.weight" and self.tie_word_embeddings:
                # Falcon uses tied embeddings except Falcon-11b.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                loaded_weight_shape = loaded_weight.shape
                if output_dim is not None:
                    # View as (..., kv_heads, q_per_kv + 2, head_dim, ...)
                    # so Q, K, V can be sliced apart along the group axis.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] +
                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                         -1) + loaded_weight_shape[output_dim + 1:])
                    wq = loaded_weight.narrow(
                        output_dim + 1, 0,
                        num_query_heads_per_kv_head).reshape(
                            *loaded_weight_shape[:output_dim], -1,
                            *loaded_weight_shape[output_dim + 1:])
                    wk = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    wv = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head + 1,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/florence2.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Iterable, List, Optional, Set, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from vllm.attention import AttentionMetadata
|
| 10 |
+
from vllm.config import VllmConfig
|
| 11 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 12 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 13 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 14 |
+
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
|
| 15 |
+
BartParallelLMHead,
|
| 16 |
+
BartScaledWordEmbedding)
|
| 17 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 18 |
+
from vllm.sequence import IntermediateTensors
|
| 19 |
+
|
| 20 |
+
from .utils import AutoWeightsLoader
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Florence2LanguageModel(nn.Module):
    """BART-style encoder-decoder language backbone used by Florence-2.

    Shares one scaled word embedding between encoder and decoder when
    the config requests tied embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        self.encoder = BartEncoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

        if self.config.tie_word_embeddings:
            # Tie both embedding tables to the shared weight.
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        # decoder outputs consists of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata)

        return decoder_outputs
+
|
| 96 |
+
|
| 97 |
+
class Florence2LanguageForConditionalGeneration(nn.Module):
    """Florence-2 language model with a BART-style parallel LM head,
    logits processing, and checkpoint loading that fuses separate
    q/k/v projections into the stacked qkv_proj parameter."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config

        self.config = config
        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")
        # BART scales embeddings by sqrt(d_model) when configured to.
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights, mapping separate q/k/v shards onto the fused
        qkv_proj parameter; returns the set of loaded param names."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly, skipping tensors
                # that have no corresponding parameter here.
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
+
|
| 195 |
+
|
| 196 |
+
class Florence2ForConditionalGeneration(nn.Module):
    """Florence-2 wrapper that currently delegates everything to the
    text backbone; the vision side is skipped at weight-load time."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(hf_config.text_config),
            prefix=f"{prefix}.language_model",
        )

    @property
    def sampler(self):
        # Expose the backbone's sampler as this model's sampler.
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate the forward pass to the language backbone.

        Decoder ids/positions come first; encoder ids/positions are
        keyword-only. Extra kwargs are accepted and ignored.
        """
        return self.language_model(input_ids, positions, encoder_input_ids,
                                   encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Delegated to the language backbone.
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # Delegated to the language backbone.
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Vision-related tensors are not modeled yet, so skip them.
        skip_prefixes = [
            'image_projection', "vision_tower", "image_proj_norm",
            "image_pos_embed", "visual_temporal_embed"
        ]
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/fuyu.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
|
| 4 |
+
# Copyright 2023 The vLLM team.
|
| 5 |
+
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
""" PyTorch Fuyu model."""
|
| 19 |
+
import math
|
| 20 |
+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
| 21 |
+
TypedDict)
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
| 26 |
+
FuyuProcessor)
|
| 27 |
+
|
| 28 |
+
from vllm.attention import AttentionMetadata
|
| 29 |
+
from vllm.config import VllmConfig
|
| 30 |
+
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
| 31 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 32 |
+
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
| 33 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 34 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 35 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 36 |
+
NestedTensors)
|
| 37 |
+
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
| 38 |
+
MultiModalDataItems)
|
| 39 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 40 |
+
BaseProcessingInfo, PromptReplacement,
|
| 41 |
+
PromptReplacementDetails)
|
| 42 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 46 |
+
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
| 47 |
+
merge_multimodal_embeddings)
|
| 48 |
+
|
| 49 |
+
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011  # placeholder token id inserted once per image patch
_NEWLINE_TOKEN_ID = 71019  # appended after each row of patches ("image newline") — verify against the Fuyu tokenizer
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class FuyuImagePatchInputs(TypedDict):
    """Validated image-patch inputs consumed by ``FuyuForCausalLM``.

    Produced by ``FuyuForCausalLM._parse_and_validate_image_input``.
    """

    type: Literal["image_patches"]

    # All patches of every image in the batch, flattened into one tensor
    # so they can be pushed through the vision projection in a single matmul.
    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    This is used to restore the first two dimensions of `flat_data`.
    """
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class FuyuProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Fuyu.

    Exposes the HF config/processor objects and computes how many
    placeholder tokens an image expands to in the prompt.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self):
        return self.ctx.get_hf_processor(FuyuProcessor)

    def get_image_processor(self) -> FuyuImageProcessor:
        return self.get_hf_processor().image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Fuyu supports at most one image per prompt.
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the worst-case number of prompt tokens one image can use."""
        target_width, target_height = self.get_image_size_with_most_features()

        max_ncols, max_nrows = self.get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
        # Each grid row contributes `ncols` patch tokens plus one newline
        # token (mirrors get_replacement_fuyu in the processor below).
        max_image_tokens = (max_ncols + 1) * max_nrows

        return {"image": max_image_tokens}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` — the patch-grid shape the HF image
        processor produces for an image of the given size."""
        image_processor = self.get_image_processor()
        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]

        # Images larger than the target canvas are scaled down with the
        # aspect ratio preserved; smaller images are left untouched.
        if not (image_width <= target_width and image_height <= target_height):
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        # NOTE(review): 30 is presumably the HF image processor's patch edge
        # length, which is not exposed via the config — confirm against
        # FuyuImageProcessor before changing.
        ncols = math.ceil(image_width / 30)
        nrows = math.ceil(image_height / 30)
        return ncols, nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        # The fixed target canvas yields the maximum number of patches.
        image_processor = self.get_image_processor()
        return ImageSize(width=image_processor.size["width"],
                         height=image_processor.size["height"])
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy multimodal inputs used to profile Fuyu's memory usage."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create an empty prompt plus dummy images at the resolution that
        yields the most image tokens."""
        dummy_size = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=dummy_size.width,
            height=dummy_size.height,
            num_images=image_count,
        )

        return ProcessorInputs(
            prompt_text="",
            mm_data={"image": dummy_images},
        )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Adapts the HF ``FuyuProcessor`` output to vLLM's multimodal format."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor; squeeze its extra batch dim on image patches."""
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, Pn, Px * Py * C)
            # New output: (num_images, Pn, Px * Py * C)
            assert (isinstance(image_patches, list)
                    and len(image_patches) == 1)
            assert (isinstance(image_patches[0], torch.Tensor)
                    and len(image_patches[0]) == len(images))

            processed_outputs["image_patches"] = image_patches[0]

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Append the beginning-of-answer token, matching what the HF
        processor does for text+image inputs."""
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # "<0x04>" is the beginning-of-answer byte token in Fuyu's vocab.
        boa_token_id = vocab["<0x04>"]

        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # One `image_patches` entry per image item in the batch.
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace the image marker token with the per-image patch grid."""
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE(review): named `eot` but read from the tokenizer's
        # bos_token_id — presumably the tokenizer's BOS doubles as the
        # marker targeted for replacement; confirm against the tokenizer.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            # One placeholder per patch column, plus a newline token per row.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptReplacementDetails(
                full=image_tokens + [bos_token_id],
                features=image_tokens,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear projection of raw image patches feeding a Persimmon
    language model; patch embeddings are merged into the text embedding
    stream at the image placeholder positions."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Flattened per-patch feature length: patch_size^2 * channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Single linear projection from raw patch pixels to the LM's
        # hidden size; gather_output so every TP rank sees full embeddings.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        # Delegate sampling to the wrapped language model.
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the trailing (per-patch) dim and cast to the vision
        projection's weight dtype.

        Raises:
            ValueError: If any patch row's last dim != C * H * W.
        """

        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w

        def _validate_shape(d: torch.Tensor):
            actual_dims = d.size(-1)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Extract `image_patches` from multimodal kwargs, or None for
        text-only requests."""
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is not None:
            if not isinstance(image_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image patches. "
                                 f"Got type: {type(image_patches)}")

            image_patches_flat = flatten_bn(image_patches)

            return FuyuImagePatchInputs(
                type="image_patches",
                # Concatenate all images' patches into one tensor;
                # patches_per_image records how to split them back apart.
                flat_data=self._validate_pixel_values(
                    flatten_bn(image_patches_flat, concat=True)),
                patches_per_image=[x.size(0) for x in image_patches_flat],
            )

        return None

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
        """Project all patches in one matmul, then split per image."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)
        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return per-image patch embeddings, or None if no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed text tokens; splice image embeddings in at the positions
        holding the image placeholder token."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        # Non-first pipeline ranks receive hidden states via
        # intermediate_tensors and must not re-embed the inputs.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Reuse the wrapped LM's head and logits processor.
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the set of loaded param names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
# Copyright (c) Google Inc.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""Inference-only Gemma model compatible with HuggingFace weights."""
|
| 18 |
+
from functools import cache
|
| 19 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
from transformers import GemmaConfig
|
| 24 |
+
|
| 25 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 26 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 27 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 28 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 29 |
+
from vllm.logger import init_logger
|
| 30 |
+
from vllm.model_executor.layers.activation import GeluAndMul
|
| 31 |
+
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
| 32 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 33 |
+
QKVParallelLinear,
|
| 34 |
+
RowParallelLinear)
|
| 35 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 36 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 37 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 38 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 39 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 40 |
+
VocabParallelEmbedding)
|
| 41 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 46 |
+
from .utils import (is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 48 |
+
maybe_prefix)
|
| 49 |
+
|
| 50 |
+
logger = init_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@cache
|
| 54 |
+
def _get_gemma_act_fn(
|
| 55 |
+
hidden_act: Optional[str],
|
| 56 |
+
hidden_activation: Optional[str],
|
| 57 |
+
) -> nn.Module:
|
| 58 |
+
if hidden_activation is None:
|
| 59 |
+
if hidden_act is not None:
|
| 60 |
+
logger.warning(
|
| 61 |
+
"Gemma's activation function was incorrectly set to exact GeLU "
|
| 62 |
+
"in the config JSON file when it was initially released. "
|
| 63 |
+
"Changing the activation function to approximate GeLU "
|
| 64 |
+
"(`gelu_pytorch_tanh`). If you want to use the legacy "
|
| 65 |
+
"`%s`, edit the config JSON to set "
|
| 66 |
+
"`hidden_activation=%s` instead of `hidden_act`. "
|
| 67 |
+
"See https://github.com/huggingface/transformers/pull/29402 "
|
| 68 |
+
"for more details.", hidden_act, hidden_act)
|
| 69 |
+
return GeluAndMul(approximate="tanh")
|
| 70 |
+
elif hidden_activation == "gelu_pytorch_tanh":
|
| 71 |
+
return GeluAndMul(approximate="tanh")
|
| 72 |
+
elif hidden_activation == "gelu":
|
| 73 |
+
return GeluAndMul(approximate="none")
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Activation function {hidden_act} is not "
|
| 76 |
+
"supported for Gemma models.")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class GemmaMLP(nn.Module):
    """Gemma's gated MLP: a fused gate/up column-parallel projection,
    a GeLU-and-multiply activation, and a row-parallel down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one fused column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class GemmaAttention(nn.Module):
    """Gemma multi-head attention with tensor-parallel QKV/output
    projections, rotary position embeddings, and vLLM's attention kernel."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        # Per-rank slice sizes used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Fused QKV projection -> RoPE -> attention -> output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer block: pre-norm attention and pre-norm MLP,
    with the residual stream threaded through the RMSNorm calls."""

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            # Older configs may lack `hidden_activation` entirely.
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, residual)`` for the next layer.

        When called with a residual, GemmaRMSNorm returns a (normed,
        new_residual) pair — presumably fusing the residual add with the
        norm; see the GemmaRMSNorm implementation to confirm.
        """
        # Self Attention
        if residual is None:
            # First layer of this pipeline stage: start the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@support_torch_compile
class GemmaModel(nn.Module):
    """Gemma transformer backbone: scaled token embeddings, a stack of
    decoder layers (possibly split across pipeline-parallel ranks), and a
    final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers assigns only this PP rank's slice of the layer stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GemmaDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        # Registered as a buffer so it follows the module's dtype/device.
        self.register_buffer("normalizer", torch.tensor(normalizer))
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's layers; non-last ranks return
        IntermediateTensors to forward to the next rank."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # Gemma scales embeddings by sqrt(hidden_size) (in-place).
            hidden_states *= self.normalizer
            residual = None
        else:
            # Mid-pipeline ranks resume from the previous rank's outputs.
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma decoder-only language model for vLLM.

    Ties the LM head to the input embedding table (asserted in ``__init__``)
    and supports LoRA and pipeline parallelism.
    """

    # Maps each fused parameter to the per-projection checkpoint names that
    # are packed into it at load time (see load_weights).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from the vLLM config bundle.

        ``prefix`` is prepended to submodule names so checkpoint weights map
        onto the right parameters.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings`
        # enabled; compute_logits below relies on this.
        assert config.tie_word_embeddings
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = GemmaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        # Delegate empty-tensor construction (used by pipeline parallelism)
        # to the inner model.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner GemmaModel."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass; returns hidden states (or a PP hand-off)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        The LM head shares weights with ``embed_tokens`` (tied embeddings,
        asserted in ``__init__``), so the embedding module is passed here.
        """
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing stacked projections.

        Returns the set of parameter names that were actually loaded; any
        remaining parameters are reported with a warning.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # First try to route the weight into a fused (stacked) param.
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for/else: runs only when no stacked mapping matched.
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/glm4_vision_encoder.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/THUDM/GLM-4
|
| 5 |
+
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
|
| 6 |
+
from argparse import Namespace
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import LayerNorm
|
| 12 |
+
|
| 13 |
+
from vllm.attention.layer import MultiHeadAttention
|
| 14 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 15 |
+
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
| 16 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 17 |
+
MergedColumnParallelLinear,
|
| 18 |
+
QKVParallelLinear,
|
| 19 |
+
ReplicatedLinear,
|
| 20 |
+
RowParallelLinear)
|
| 21 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 22 |
+
QuantizationConfig)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches, embed them, prepend a
    [CLS] token, and add learned absolute position embeddings."""

    def __init__(self, config):
        super().__init__()
        # A strided convolution is equivalent to tiling the image into
        # patch_size x patch_size squares and linearly projecting each tile.
        self.proj = nn.Conv2d(config.in_channels,
                              config.hidden_size,
                              kernel_size=config.patch_size,
                              stride=config.patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions,
                                               config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Embed *images* of shape (B, C, H, W) into tokens of shape (B, L, D),
        where L = num_patches + 1 (the leading [CLS] token)."""
        weight = self.proj.weight
        images = images.to(device=weight.device, dtype=weight.dtype)
        # (B, D, H', W') -> (B, H'*W', D)
        patches = self.proj(images).flatten(2).transpose(1, 2)
        cls_tok = self.cls_embedding.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_tok, patches), dim=1)
        # Broadcast the full position table over the batch dimension;
        # assumes num_positions == L.
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Attention(nn.Module):
    """Self-attention for the GLM-4v vision encoder with tensor-parallel
    QKV and output projections."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        # Attention heads are split evenly across tensor-parallel ranks.
        self.num_heads_per_rank = config.num_heads // self.tp_size
        self.head_dim = config.hidden_size // config.num_heads
        self.scale = self.head_dim**-0.5  # 1/sqrt(d_k) attention scaling

        # Fused Q/K/V projection; its per-rank output is laid out [q | k | v].
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            config.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
                                       self.scale)
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to *x* of shape (B, L, hidden_size)."""
        qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
        # Equal thirds of the fused projection: queries, keys, values.
        q, k, v = qkv.chunk(3, dim=-1)

        out = self.attn(q, k, v)
        output, _ = self.dense(out)
        output = self.output_dropout(output)
        return output
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class MLP(nn.Module):
    """Feed-forward block of the GLM-4v vision encoder.

    fc1 expands hidden_size -> intermediate_size (column-parallel), the
    configured activation is applied, and fc2 projects back to hidden_size
    (row-parallel).
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Parallel linear layers return (output, bias); the bias slot is
        # unused here.
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class TransformerLayer(nn.Module):
    """One GLM-4v vision encoder block (attention + MLP, each with a
    residual connection).

    NOTE: unlike the usual pre-/post-LN layouts, the LayerNorm here is
    applied to the *output* of each sublayer before the residual addition
    (this file is adapted from THUDM/GLM-4, which presumably uses the same
    placement — verify against upstream before changing).
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.input_layernorm = LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
        self.attention = Attention(config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.attention")
        self.mlp = MLP(config,
                       quant_config=quant_config,
                       prefix=f"{prefix}.mlp")
        self.post_attention_layernorm = LayerNorm(config.hidden_size,
                                                  eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        attention_input = hidden_states
        # Norm is applied to the attention output, then added to the input.
        attention_output = self.input_layernorm(
            self.attention(attention_input))
        hidden_states = attention_input + attention_output
        mlp_input = hidden_states
        # Same pattern for the MLP sublayer.
        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
        output = mlp_input + mlp_output
        return output
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Transformer(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` vision-encoder
    blocks applied one after another."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        blocks = [
            TransformerLayer(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        """Thread *hidden_states* through every block in order."""
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class GLU(nn.Module):
    """Gated projector that maps vision features into the LM hidden space."""

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """
        The original implementation is the same as:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )

        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```python
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
        ```python
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```python
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        self.linear_proj = ReplicatedLinear(in_features,
                                            config.hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.linear_proj")
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.act1 = nn.GELU()
        # SiluAndMul consumes the concatenated [gate | up] halves produced
        # by merged_proj and returns silu(gate) * up.
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj")

        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h")

    def forward(self, x):
        """Project *x* (last dim = in_features) to hidden_size via the
        gated FFN path."""
        x, _ = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        x, _ = self.merged_proj(x)
        x = self.act2(x)
        x, _ = self.dense_4h_to_h(x)
        return x
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class EVA2CLIPModel(nn.Module):
    """GLM-4v vision tower: patch embedding -> transformer -> 2x spatial
    downsampling conv -> GLU projector, with learned BOI/EOI boundary
    tokens wrapped around the image tokens."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        # vision_config is stored as a plain dict on the HF config; wrap it
        # so attributes can be accessed with dot notation.
        vision_config = Namespace(**config.vision_config)
        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.transformer")
        # Projects from the vision hidden size to the LM hidden size.
        self.linear_proj = GLU(config,
                               in_features=config.hidden_size,
                               quant_config=quant_config,
                               prefix=f"{prefix}.linear_proj")
        # 2x2/stride-2 conv halves each spatial dimension (4x fewer tokens).
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=config.hidden_size,
                              kernel_size=2,
                              stride=2)
        # Learned begin-of-image / end-of-image boundary embeddings.
        self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        x = self.patch_embedding(images)
        x = self.transformer(x)
        x = x[:, 1:]  # drop the [CLS] token before re-forming the 2D grid

        b, s, h = x.shape
        # Assumes the remaining s tokens form a square patch grid
        # (s must be a perfect square for the view below to be correct).
        grid_size = int(s**0.5)
        x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
        x = self.conv(x)

        # Back to (B, L', D) token layout for the projector.
        x = x.flatten(2).transpose(1, 2)
        x = self.linear_proj(x)
        boi = self.boi.expand(x.shape[0], -1, -1)
        eoi = self.eoi.expand(x.shape[0], -1, -1)
        x = torch.cat((boi, x, eoi), dim=1)
        x = x / self.scaling_factor
        return x
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt2.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 7 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
| 21 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
from transformers import GPT2Config
|
| 26 |
+
|
| 27 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 28 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 29 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 30 |
+
from vllm.distributed.parallel_state import (
|
| 31 |
+
get_pp_group, get_tensor_model_parallel_world_size)
|
| 32 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 33 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 34 |
+
QKVParallelLinear,
|
| 35 |
+
RowParallelLinear)
|
| 36 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 37 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 38 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 39 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 40 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 41 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsPP
|
| 46 |
+
from .utils import (is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 48 |
+
maybe_prefix)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class GPT2Attention(nn.Module):
    """GPT-2 multi-head self-attention, sharded across tensor-parallel
    ranks, backed by vLLM's paged Attention layer."""

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads must divide evenly across tensor-parallel ranks.
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5  # 1/sqrt(d_k) attention scaling

        # Fused QKV projection (HF GPT-2 uses a Conv1D named "c_attn").
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # The per-rank fused projection output is laid out [q | k | v].
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class GPT2MLP(nn.Module):
    """GPT-2 position-wise feed-forward network: c_fc expands to
    intermediate_size, the configured activation is applied, c_proj
    contracts back to hidden_size."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Expansion projection, column-parallel across TP ranks.
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # Contraction projection, row-parallel across TP ranks.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Parallel linear layers return (output, bias); bias slot unused.
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class GPT2Block(nn.Module):
    """One pre-LayerNorm GPT-2 transformer block: attention sublayer then
    MLP sublayer, each wrapped in a residual connection."""

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # HF convention: n_inner defaults to 4 * hidden size when unset.
        inner_dim = 4 * hidden_size if config.n_inner is None \
            else config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention sublayer (pre-norm) with residual connection.
        shortcut = hidden_states
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + shortcut

        # Feed-forward sublayer (pre-norm) with residual connection.
        shortcut = hidden_states
        hidden_states = shortcut + self.mlp(self.ln_2(hidden_states))
        return hidden_states
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer backbone: token + position embeddings, this
    rank's slice of GPT2Blocks, and the final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # GPT-2 variants using these features are not supported here.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.wte")
        # Learned absolute position embeddings (not tensor-parallelized).
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # make_layers assigns this rank only its pipeline-parallel slice
        # [start_layer, end_layer) of the blocks.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Token-embedding lookup only; position embeddings are added in
        ``forward``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the model; returns final hidden states
        on the last PP rank, otherwise an IntermediateTensors hand-off."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches holds only this rank's layers, hence the offset.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class GPT2LMHeadModel(nn.Module, SupportsPP):
|
| 253 |
+
|
| 254 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 255 |
+
super().__init__()
|
| 256 |
+
config = vllm_config.model_config.hf_config
|
| 257 |
+
quant_config = vllm_config.quant_config
|
| 258 |
+
self.config = config
|
| 259 |
+
self.quant_config = quant_config
|
| 260 |
+
self.transformer = GPT2Model(vllm_config=vllm_config,
|
| 261 |
+
prefix=maybe_prefix(
|
| 262 |
+
prefix, "transformer"))
|
| 263 |
+
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
| 264 |
+
self.config.hidden_size,
|
| 265 |
+
quant_config=quant_config,
|
| 266 |
+
prefix=f"{prefix}.lm_head")
|
| 267 |
+
if self.config.tie_word_embeddings:
|
| 268 |
+
self.lm_head = self.lm_head.tie_weights(self.transformer.wte)
|
| 269 |
+
|
| 270 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 271 |
+
self.sampler = get_sampler()
|
| 272 |
+
self.make_empty_intermediate_tensors = (
|
| 273 |
+
self.transformer.make_empty_intermediate_tensors)
|
| 274 |
+
|
| 275 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 276 |
+
return self.transformer.get_input_embeddings(input_ids)
|
| 277 |
+
|
| 278 |
+
def forward(
|
| 279 |
+
self,
|
| 280 |
+
input_ids: torch.Tensor,
|
| 281 |
+
positions: torch.Tensor,
|
| 282 |
+
kv_caches: List[torch.Tensor],
|
| 283 |
+
attn_metadata: AttentionMetadata,
|
| 284 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 285 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 286 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 287 |
+
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
| 288 |
+
attn_metadata, intermediate_tensors,
|
| 289 |
+
inputs_embeds)
|
| 290 |
+
return hidden_states
|
| 291 |
+
|
| 292 |
+
def compute_logits(
|
| 293 |
+
self,
|
| 294 |
+
hidden_states: torch.Tensor,
|
| 295 |
+
sampling_metadata: SamplingMetadata,
|
| 296 |
+
) -> Optional[torch.Tensor]:
|
| 297 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 298 |
+
sampling_metadata)
|
| 299 |
+
return logits
|
| 300 |
+
|
| 301 |
+
def sample(
|
| 302 |
+
self,
|
| 303 |
+
logits: torch.Tensor,
|
| 304 |
+
sampling_metadata: SamplingMetadata,
|
| 305 |
+
) -> Optional[SamplerOutput]:
|
| 306 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 307 |
+
return next_tokens
|
| 308 |
+
|
| 309 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 310 |
+
torch.Tensor]]) -> Set[str]:
|
| 311 |
+
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
| 312 |
+
loaded_params: Set[str] = set()
|
| 313 |
+
for name, loaded_weight in weights:
|
| 314 |
+
if ".attn.bias" in name or ".attn.masked_bias" in name:
|
| 315 |
+
# Skip attention mask.
|
| 316 |
+
# NOTE: "c_attn.bias" should not be skipped.
|
| 317 |
+
continue
|
| 318 |
+
if not name.startswith("transformer.") and not name.startswith(
|
| 319 |
+
"lm_head"):
|
| 320 |
+
name = "transformer." + name
|
| 321 |
+
|
| 322 |
+
if is_pp_missing_parameter(name, self):
|
| 323 |
+
continue
|
| 324 |
+
|
| 325 |
+
param = params_dict[name]
|
| 326 |
+
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
| 327 |
+
# Because of this, we need to transpose the weights.
|
| 328 |
+
# Note(zhuohan): the logic below might break quantized models.
|
| 329 |
+
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
| 330 |
+
if conv1d_weight_name not in name:
|
| 331 |
+
continue
|
| 332 |
+
if not name.endswith(".weight"):
|
| 333 |
+
continue
|
| 334 |
+
loaded_weight = loaded_weight.t()
|
| 335 |
+
weight_loader = getattr(param, "weight_loader",
|
| 336 |
+
default_weight_loader)
|
| 337 |
+
weight_loader(param, loaded_weight)
|
| 338 |
+
loaded_params.add(name)
|
| 339 |
+
return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_bigcode.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2023 CTranslate2, and Michael Feil
|
| 7 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 8 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
|
| 22 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
from transformers import GPTBigCodeConfig
|
| 27 |
+
|
| 28 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 29 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 30 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 31 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 32 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 33 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 34 |
+
QKVParallelLinear,
|
| 35 |
+
RowParallelLinear)
|
| 36 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 37 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 38 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 39 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 40 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 41 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 46 |
+
from .utils import (is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GPTBigCodeAttention(nn.Module):
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
config: GPTBigCodeConfig,
|
| 55 |
+
cache_config: Optional[CacheConfig] = None,
|
| 56 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 57 |
+
prefix: str = "",
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.hidden_size = config.hidden_size
|
| 61 |
+
total_num_heads = config.num_attention_heads
|
| 62 |
+
self.tensor_model_parallel_world_size = (
|
| 63 |
+
get_tensor_model_parallel_world_size())
|
| 64 |
+
assert total_num_heads % self.tensor_model_parallel_world_size == 0
|
| 65 |
+
self.num_heads = (total_num_heads //
|
| 66 |
+
self.tensor_model_parallel_world_size)
|
| 67 |
+
self.head_dim = self.hidden_size // total_num_heads
|
| 68 |
+
self.scale = self.head_dim**-0.5
|
| 69 |
+
|
| 70 |
+
self.multi_query = config.multi_query
|
| 71 |
+
if self.multi_query:
|
| 72 |
+
total_num_kv_heads = 1
|
| 73 |
+
self.num_kv_heads = 1
|
| 74 |
+
else:
|
| 75 |
+
total_num_kv_heads = total_num_heads
|
| 76 |
+
self.num_kv_heads = self.num_heads
|
| 77 |
+
self.kv_dim = self.head_dim * self.num_kv_heads
|
| 78 |
+
self.c_attn = QKVParallelLinear(
|
| 79 |
+
self.hidden_size,
|
| 80 |
+
self.head_dim,
|
| 81 |
+
total_num_heads,
|
| 82 |
+
total_num_kv_heads,
|
| 83 |
+
bias=True,
|
| 84 |
+
quant_config=quant_config,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.c_proj = RowParallelLinear(
|
| 88 |
+
self.hidden_size,
|
| 89 |
+
self.hidden_size,
|
| 90 |
+
bias=True,
|
| 91 |
+
quant_config=quant_config,
|
| 92 |
+
)
|
| 93 |
+
self.attn = Attention(self.num_heads,
|
| 94 |
+
self.head_dim,
|
| 95 |
+
scale=self.scale,
|
| 96 |
+
num_kv_heads=self.num_kv_heads,
|
| 97 |
+
cache_config=cache_config,
|
| 98 |
+
quant_config=quant_config,
|
| 99 |
+
prefix=f"{prefix}.attn")
|
| 100 |
+
|
| 101 |
+
def forward(
|
| 102 |
+
self,
|
| 103 |
+
hidden_states: torch.Tensor,
|
| 104 |
+
kv_cache: torch.Tensor,
|
| 105 |
+
attn_metadata: AttentionMetadata,
|
| 106 |
+
) -> torch.Tensor:
|
| 107 |
+
qkv, _ = self.c_attn(hidden_states)
|
| 108 |
+
q, k, v = qkv.split(
|
| 109 |
+
[
|
| 110 |
+
self.hidden_size // self.tensor_model_parallel_world_size,
|
| 111 |
+
self.kv_dim, self.kv_dim
|
| 112 |
+
],
|
| 113 |
+
dim=-1,
|
| 114 |
+
)
|
| 115 |
+
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
| 116 |
+
attn_output, _ = self.c_proj(attn_output)
|
| 117 |
+
return attn_output
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class GPTBigMLP(nn.Module):
|
| 121 |
+
|
| 122 |
+
def __init__(
|
| 123 |
+
self,
|
| 124 |
+
intermediate_size: int,
|
| 125 |
+
config: GPTBigCodeConfig,
|
| 126 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 127 |
+
):
|
| 128 |
+
super().__init__()
|
| 129 |
+
hidden_size = config.hidden_size
|
| 130 |
+
self.c_fc = ColumnParallelLinear(
|
| 131 |
+
hidden_size,
|
| 132 |
+
intermediate_size,
|
| 133 |
+
bias=True,
|
| 134 |
+
quant_config=quant_config,
|
| 135 |
+
)
|
| 136 |
+
self.c_proj = RowParallelLinear(
|
| 137 |
+
intermediate_size,
|
| 138 |
+
hidden_size,
|
| 139 |
+
bias=True,
|
| 140 |
+
quant_config=quant_config,
|
| 141 |
+
)
|
| 142 |
+
self.act = get_act_fn(config.activation_function)
|
| 143 |
+
|
| 144 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
hidden_states, _ = self.c_fc(hidden_states)
|
| 146 |
+
hidden_states = self.act(hidden_states)
|
| 147 |
+
hidden_states, _ = self.c_proj(hidden_states)
|
| 148 |
+
return hidden_states
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class GPTBigCodeBlock(nn.Module):
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
config: GPTBigCodeConfig,
|
| 156 |
+
cache_config: Optional[CacheConfig] = None,
|
| 157 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 158 |
+
prefix: str = "",
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
hidden_size = config.hidden_size
|
| 162 |
+
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
| 163 |
+
hidden_size)
|
| 164 |
+
|
| 165 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 166 |
+
self.attn = GPTBigCodeAttention(config,
|
| 167 |
+
cache_config,
|
| 168 |
+
quant_config,
|
| 169 |
+
prefix=f"{prefix}.attn")
|
| 170 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 171 |
+
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
| 172 |
+
|
| 173 |
+
def forward(
|
| 174 |
+
self,
|
| 175 |
+
hidden_states: torch.Tensor,
|
| 176 |
+
kv_cache: torch.Tensor,
|
| 177 |
+
attn_metadata: AttentionMetadata,
|
| 178 |
+
) -> torch.Tensor:
|
| 179 |
+
residual = hidden_states
|
| 180 |
+
hidden_states = self.ln_1(hidden_states)
|
| 181 |
+
attn_output = self.attn(
|
| 182 |
+
hidden_states=hidden_states,
|
| 183 |
+
kv_cache=kv_cache,
|
| 184 |
+
attn_metadata=attn_metadata,
|
| 185 |
+
)
|
| 186 |
+
# residual connection
|
| 187 |
+
hidden_states = attn_output + residual
|
| 188 |
+
|
| 189 |
+
residual = hidden_states
|
| 190 |
+
hidden_states = self.ln_2(hidden_states)
|
| 191 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
| 192 |
+
# residual connection
|
| 193 |
+
hidden_states = residual + feed_forward_hidden_states
|
| 194 |
+
return hidden_states
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@support_torch_compile
|
| 198 |
+
class GPTBigCodeModel(nn.Module):
|
| 199 |
+
|
| 200 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
config = vllm_config.model_config.hf_config
|
| 204 |
+
cache_config = vllm_config.cache_config
|
| 205 |
+
quant_config = vllm_config.quant_config
|
| 206 |
+
lora_config = vllm_config.lora_config
|
| 207 |
+
|
| 208 |
+
self.config = config
|
| 209 |
+
assert not config.add_cross_attention
|
| 210 |
+
|
| 211 |
+
self.embed_dim = config.hidden_size
|
| 212 |
+
lora_vocab = (lora_config.lora_extra_vocab_size *
|
| 213 |
+
(lora_config.max_loras or 1)) if lora_config else 0
|
| 214 |
+
self.vocab_size = config.vocab_size + lora_vocab
|
| 215 |
+
self.wte = VocabParallelEmbedding(self.vocab_size,
|
| 216 |
+
self.embed_dim,
|
| 217 |
+
org_num_embeddings=config.vocab_size)
|
| 218 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
| 219 |
+
self.start_layer, self.end_layer, self.h = make_layers(
|
| 220 |
+
config.num_hidden_layers,
|
| 221 |
+
lambda prefix: GPTBigCodeBlock(
|
| 222 |
+
config, cache_config, quant_config, prefix=prefix),
|
| 223 |
+
prefix=f"{prefix}.h",
|
| 224 |
+
)
|
| 225 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
| 226 |
+
self.make_empty_intermediate_tensors = (
|
| 227 |
+
make_empty_intermediate_tensors_factory(["hidden_states"],
|
| 228 |
+
config.n_embd))
|
| 229 |
+
|
| 230 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 231 |
+
return self.wte(input_ids)
|
| 232 |
+
|
| 233 |
+
def forward(
|
| 234 |
+
self,
|
| 235 |
+
input_ids: torch.Tensor,
|
| 236 |
+
position_ids: torch.Tensor,
|
| 237 |
+
kv_caches: List[torch.Tensor],
|
| 238 |
+
attn_metadata: AttentionMetadata,
|
| 239 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 240 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 242 |
+
if get_pp_group().is_first_rank:
|
| 243 |
+
if inputs_embeds is None:
|
| 244 |
+
inputs_embeds = self.get_input_embeddings(input_ids)
|
| 245 |
+
hidden_states = inputs_embeds + self.wpe(position_ids)
|
| 246 |
+
else:
|
| 247 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 248 |
+
|
| 249 |
+
for i in range(self.start_layer, self.end_layer):
|
| 250 |
+
layer = self.h[i]
|
| 251 |
+
hidden_states = layer(hidden_states,
|
| 252 |
+
kv_caches[i - self.start_layer],
|
| 253 |
+
attn_metadata)
|
| 254 |
+
|
| 255 |
+
if not get_pp_group().is_last_rank:
|
| 256 |
+
return IntermediateTensors({"hidden_states": hidden_states})
|
| 257 |
+
hidden_states = self.ln_f(hidden_states)
|
| 258 |
+
return hidden_states
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
| 262 |
+
packed_modules_mapping = {"c_attn": ["c_attn"]}
|
| 263 |
+
|
| 264 |
+
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
|
| 265 |
+
|
| 266 |
+
embedding_modules = {
|
| 267 |
+
"wte": "input_embeddings",
|
| 268 |
+
"lm_head": "output_embeddings",
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
embedding_padding_modules = []
|
| 272 |
+
|
| 273 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 274 |
+
super().__init__()
|
| 275 |
+
config = vllm_config.model_config.hf_config
|
| 276 |
+
quant_config = vllm_config.quant_config
|
| 277 |
+
lora_config = vllm_config.lora_config
|
| 278 |
+
|
| 279 |
+
self.config = config
|
| 280 |
+
self.lora_config = lora_config
|
| 281 |
+
|
| 282 |
+
self.quant_config = quant_config
|
| 283 |
+
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
|
| 284 |
+
prefix=prefix)
|
| 285 |
+
if self.config.tie_word_embeddings:
|
| 286 |
+
self.lm_head = self.transformer.wte
|
| 287 |
+
else:
|
| 288 |
+
self.lm_head = ParallelLMHead(
|
| 289 |
+
self.transformer.vocab_size,
|
| 290 |
+
self.transformer.embed_dim,
|
| 291 |
+
org_num_embeddings=self.config.vocab_size)
|
| 292 |
+
self.unpadded_vocab_size = config.vocab_size
|
| 293 |
+
if lora_config:
|
| 294 |
+
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
| 295 |
+
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
| 296 |
+
config.vocab_size)
|
| 297 |
+
self.sampler = get_sampler()
|
| 298 |
+
self.make_empty_intermediate_tensors = (
|
| 299 |
+
self.transformer.make_empty_intermediate_tensors)
|
| 300 |
+
|
| 301 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 302 |
+
return self.transformer.get_input_embeddings(input_ids)
|
| 303 |
+
|
| 304 |
+
def forward(
|
| 305 |
+
self,
|
| 306 |
+
input_ids: torch.Tensor,
|
| 307 |
+
positions: torch.Tensor,
|
| 308 |
+
kv_caches: List[torch.Tensor],
|
| 309 |
+
attn_metadata: AttentionMetadata,
|
| 310 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 311 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 312 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 313 |
+
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
| 314 |
+
attn_metadata, intermediate_tensors,
|
| 315 |
+
inputs_embeds)
|
| 316 |
+
return hidden_states
|
| 317 |
+
|
| 318 |
+
def compute_logits(
|
| 319 |
+
self,
|
| 320 |
+
hidden_states: torch.Tensor,
|
| 321 |
+
sampling_metadata: SamplingMetadata,
|
| 322 |
+
) -> Optional[torch.Tensor]:
|
| 323 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 324 |
+
sampling_metadata)
|
| 325 |
+
return logits
|
| 326 |
+
|
| 327 |
+
def sample(
|
| 328 |
+
self,
|
| 329 |
+
logits: torch.Tensor,
|
| 330 |
+
sampling_metadata: SamplingMetadata,
|
| 331 |
+
) -> Optional[SamplerOutput]:
|
| 332 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 333 |
+
return next_tokens
|
| 334 |
+
|
| 335 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 336 |
+
torch.Tensor]]) -> Set[str]:
|
| 337 |
+
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
| 338 |
+
loaded_params: Set[str] = set()
|
| 339 |
+
for name, loaded_weight in weights:
|
| 340 |
+
if "lm_head.weight" in name:
|
| 341 |
+
continue
|
| 342 |
+
if ".attn.bias" in name:
|
| 343 |
+
# Skip attention mask.
|
| 344 |
+
# NOTE: "c_attn.bias" should not be skipped.
|
| 345 |
+
continue
|
| 346 |
+
if is_pp_missing_parameter(name, self):
|
| 347 |
+
continue
|
| 348 |
+
param = params_dict[name]
|
| 349 |
+
weight_loader = getattr(param, "weight_loader",
|
| 350 |
+
default_weight_loader)
|
| 351 |
+
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
|
| 352 |
+
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
|
| 353 |
+
weight_loader(param, loaded_weight, 'q')
|
| 354 |
+
weight_loader(param, loaded_weight, 'k')
|
| 355 |
+
weight_loader(param, loaded_weight, 'v')
|
| 356 |
+
else:
|
| 357 |
+
weight_loader(param, loaded_weight)
|
| 358 |
+
loaded_params.add(name)
|
| 359 |
+
return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/granitemoe.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Inference-only GraniteMoe model."""
|
| 25 |
+
from typing import Iterable, List, Optional, Set, Tuple
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers.models.granitemoe import GraniteMoeConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 33 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 34 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 35 |
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
| 36 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 37 |
+
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
| 38 |
+
ReplicatedLinear,
|
| 39 |
+
RowParallelLinear)
|
| 40 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 41 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 42 |
+
QuantizationConfig)
|
| 43 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 44 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 45 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 46 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 47 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 48 |
+
from vllm.sequence import IntermediateTensors
|
| 49 |
+
|
| 50 |
+
from . import mixtral
|
| 51 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 52 |
+
from .utils import make_layers, maybe_prefix
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class GraniteMoeMoE(nn.Module):
|
| 56 |
+
"""A tensor-parallel MoE implementation for GraniteMoe that shards each
|
| 57 |
+
expert across all ranks.
|
| 58 |
+
Each expert's weights are sharded across all ranks and a fused MoE
|
| 59 |
+
kernel is used for the forward pass, and finally we reduce the outputs
|
| 60 |
+
across ranks.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self,
|
| 64 |
+
num_experts: int,
|
| 65 |
+
top_k: int,
|
| 66 |
+
hidden_size: int,
|
| 67 |
+
intermediate_size: int,
|
| 68 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 69 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 70 |
+
tp_size: Optional[int] = None,
|
| 71 |
+
prefix: str = ""):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.hidden_size = hidden_size
|
| 74 |
+
|
| 75 |
+
# Gate always runs at half / full precision for now.
|
| 76 |
+
self.gate = ReplicatedLinear(hidden_size,
|
| 77 |
+
num_experts,
|
| 78 |
+
bias=False,
|
| 79 |
+
params_dtype=params_dtype,
|
| 80 |
+
quant_config=None,
|
| 81 |
+
prefix=f"{prefix}.gate")
|
| 82 |
+
|
| 83 |
+
self.experts = FusedMoE(num_experts=num_experts,
|
| 84 |
+
top_k=top_k,
|
| 85 |
+
hidden_size=hidden_size,
|
| 86 |
+
intermediate_size=intermediate_size,
|
| 87 |
+
params_dtype=params_dtype,
|
| 88 |
+
reduce_results=True,
|
| 89 |
+
renormalize=True,
|
| 90 |
+
quant_config=quant_config,
|
| 91 |
+
tp_size=tp_size,
|
| 92 |
+
prefix=f"{prefix}.experts")
|
| 93 |
+
|
| 94 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
# NOTE: hidden_states can have either 1D or 2D shape.
|
| 96 |
+
orig_shape = hidden_states.shape
|
| 97 |
+
hidden_states = hidden_states.view(-1, self.hidden_size)
|
| 98 |
+
# router_logits: (num_tokens, n_experts)
|
| 99 |
+
router_logits, _ = self.gate(hidden_states)
|
| 100 |
+
final_hidden_states = self.experts(hidden_states, router_logits)
|
| 101 |
+
return final_hidden_states.view(orig_shape)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class GraniteMoeAttention(nn.Module):
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
hidden_size: int,
|
| 109 |
+
num_heads: int,
|
| 110 |
+
num_kv_heads: int,
|
| 111 |
+
max_position: int = 4096 * 32,
|
| 112 |
+
rope_theta: float = 10000,
|
| 113 |
+
cache_config: Optional[CacheConfig] = None,
|
| 114 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 115 |
+
attention_multiplier: Optional[float] = None,
|
| 116 |
+
prefix: str = "",
|
| 117 |
+
) -> None:
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.hidden_size = hidden_size
|
| 120 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 121 |
+
self.total_num_heads = num_heads
|
| 122 |
+
assert self.total_num_heads % tp_size == 0
|
| 123 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 124 |
+
self.total_num_kv_heads = num_kv_heads
|
| 125 |
+
if self.total_num_kv_heads >= tp_size:
|
| 126 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 127 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 128 |
+
assert self.total_num_kv_heads % tp_size == 0
|
| 129 |
+
else:
|
| 130 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 131 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 132 |
+
assert tp_size % self.total_num_kv_heads == 0
|
| 133 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 134 |
+
self.head_dim = hidden_size // self.total_num_heads
|
| 135 |
+
self.q_size = self.num_heads * self.head_dim
|
| 136 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 137 |
+
self.scaling = (attention_multiplier if attention_multiplier
|
| 138 |
+
is not None else self.head_dim**-1)
|
| 139 |
+
self.rope_theta = rope_theta
|
| 140 |
+
|
| 141 |
+
self.qkv_proj = QKVParallelLinear(
|
| 142 |
+
hidden_size,
|
| 143 |
+
self.head_dim,
|
| 144 |
+
self.total_num_heads,
|
| 145 |
+
self.total_num_kv_heads,
|
| 146 |
+
bias=False,
|
| 147 |
+
quant_config=quant_config,
|
| 148 |
+
prefix=f"{prefix}.qkv_proj",
|
| 149 |
+
)
|
| 150 |
+
self.o_proj = RowParallelLinear(
|
| 151 |
+
self.total_num_heads * self.head_dim,
|
| 152 |
+
hidden_size,
|
| 153 |
+
bias=False,
|
| 154 |
+
quant_config=quant_config,
|
| 155 |
+
prefix=f"{prefix}.o_proj",
|
| 156 |
+
)
|
| 157 |
+
self.rotary_emb = get_rope(
|
| 158 |
+
self.head_dim,
|
| 159 |
+
rotary_dim=self.head_dim,
|
| 160 |
+
max_position=max_position,
|
| 161 |
+
base=int(self.rope_theta),
|
| 162 |
+
is_neox_style=True,
|
| 163 |
+
)
|
| 164 |
+
self.attn = Attention(self.num_heads,
|
| 165 |
+
self.head_dim,
|
| 166 |
+
self.scaling,
|
| 167 |
+
num_kv_heads=self.num_kv_heads,
|
| 168 |
+
cache_config=cache_config,
|
| 169 |
+
quant_config=quant_config,
|
| 170 |
+
prefix=f"{prefix}.attn")
|
| 171 |
+
|
| 172 |
+
def forward(
|
| 173 |
+
self,
|
| 174 |
+
positions: torch.Tensor,
|
| 175 |
+
hidden_states: torch.Tensor,
|
| 176 |
+
kv_cache: torch.Tensor,
|
| 177 |
+
attn_metadata: AttentionMetadata,
|
| 178 |
+
) -> torch.Tensor:
|
| 179 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
| 180 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 181 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 182 |
+
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
| 183 |
+
output, _ = self.o_proj(attn_output)
|
| 184 |
+
return output
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class GraniteMoeDecoderLayer(nn.Module):
|
| 188 |
+
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
config: GraniteMoeConfig,
|
| 192 |
+
cache_config: Optional[CacheConfig] = None,
|
| 193 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 194 |
+
prefix: str = "",
|
| 195 |
+
) -> None:
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.hidden_size = config.hidden_size
|
| 198 |
+
# Requires transformers > 4.32.0
|
| 199 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 200 |
+
self.self_attn = GraniteMoeAttention(
|
| 201 |
+
hidden_size=self.hidden_size,
|
| 202 |
+
num_heads=config.num_attention_heads,
|
| 203 |
+
max_position=config.max_position_embeddings,
|
| 204 |
+
num_kv_heads=config.num_key_value_heads,
|
| 205 |
+
rope_theta=rope_theta,
|
| 206 |
+
cache_config=cache_config,
|
| 207 |
+
quant_config=quant_config,
|
| 208 |
+
prefix=f"{prefix}.self_attn",
|
| 209 |
+
attention_multiplier=config.attention_multiplier)
|
| 210 |
+
self.block_sparse_moe = GraniteMoeMoE(
|
| 211 |
+
num_experts=config.num_local_experts,
|
| 212 |
+
top_k=config.num_experts_per_tok,
|
| 213 |
+
hidden_size=config.hidden_size,
|
| 214 |
+
intermediate_size=config.intermediate_size,
|
| 215 |
+
quant_config=quant_config,
|
| 216 |
+
prefix=f"{prefix}.block_sparse_moe")
|
| 217 |
+
|
| 218 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 219 |
+
eps=config.rms_norm_eps)
|
| 220 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 221 |
+
eps=config.rms_norm_eps)
|
| 222 |
+
|
| 223 |
+
self.residual_multiplier = config.residual_multiplier
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self,
|
| 227 |
+
positions: torch.Tensor,
|
| 228 |
+
hidden_states: torch.Tensor,
|
| 229 |
+
kv_cache: torch.Tensor,
|
| 230 |
+
attn_metadata: AttentionMetadata,
|
| 231 |
+
) -> torch.Tensor:
|
| 232 |
+
# Self Attention
|
| 233 |
+
residual = hidden_states
|
| 234 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 235 |
+
hidden_states = self.self_attn(
|
| 236 |
+
positions=positions,
|
| 237 |
+
hidden_states=hidden_states,
|
| 238 |
+
kv_cache=kv_cache,
|
| 239 |
+
attn_metadata=attn_metadata,
|
| 240 |
+
)
|
| 241 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 242 |
+
residual = hidden_states
|
| 243 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 244 |
+
hidden_states = self.block_sparse_moe(hidden_states)
|
| 245 |
+
hidden_states = residual + hidden_states * self.residual_multiplier
|
| 246 |
+
|
| 247 |
+
return hidden_states
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@support_torch_compile
|
| 251 |
+
class GraniteMoeModel(nn.Module):
|
| 252 |
+
|
| 253 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 254 |
+
super().__init__()
|
| 255 |
+
|
| 256 |
+
config = vllm_config.model_config.hf_config
|
| 257 |
+
cache_config = vllm_config.cache_config
|
| 258 |
+
quant_config = vllm_config.quant_config
|
| 259 |
+
lora_config = vllm_config.lora_config
|
| 260 |
+
|
| 261 |
+
self.padding_idx = config.pad_token_id
|
| 262 |
+
lora_vocab = (lora_config.lora_extra_vocab_size *
|
| 263 |
+
(lora_config.max_loras or 1)) if lora_config else 0
|
| 264 |
+
self.vocab_size = config.vocab_size + lora_vocab
|
| 265 |
+
self.org_vocab_size = config.vocab_size
|
| 266 |
+
|
| 267 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 268 |
+
self.vocab_size,
|
| 269 |
+
config.hidden_size,
|
| 270 |
+
org_num_embeddings=config.vocab_size,
|
| 271 |
+
)
|
| 272 |
+
self.embedding_multiplier = config.embedding_multiplier
|
| 273 |
+
|
| 274 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 275 |
+
config.num_hidden_layers,
|
| 276 |
+
lambda prefix: GraniteMoeDecoderLayer(
|
| 277 |
+
config, cache_config, quant_config=quant_config, prefix=prefix
|
| 278 |
+
),
|
| 279 |
+
prefix=f"{prefix}.layers")
|
| 280 |
+
|
| 281 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 282 |
+
|
| 283 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 284 |
+
return self.embed_tokens(input_ids)
|
| 285 |
+
|
| 286 |
+
def forward(
|
| 287 |
+
self,
|
| 288 |
+
input_ids: torch.Tensor,
|
| 289 |
+
positions: torch.Tensor,
|
| 290 |
+
kv_caches: List[torch.Tensor],
|
| 291 |
+
attn_metadata: AttentionMetadata,
|
| 292 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 293 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 294 |
+
) -> torch.Tensor:
|
| 295 |
+
if get_pp_group().is_first_rank:
|
| 296 |
+
if inputs_embeds is not None:
|
| 297 |
+
hidden_states = inputs_embeds
|
| 298 |
+
else:
|
| 299 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 300 |
+
hidden_states *= self.embedding_multiplier
|
| 301 |
+
residual = None
|
| 302 |
+
else:
|
| 303 |
+
assert intermediate_tensors is not None
|
| 304 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 305 |
+
residual = intermediate_tensors["residual"]
|
| 306 |
+
for i in range(self.start_layer, self.end_layer):
|
| 307 |
+
layer = self.layers[i]
|
| 308 |
+
hidden_states = layer(positions, hidden_states,
|
| 309 |
+
kv_caches[i - self.start_layer],
|
| 310 |
+
attn_metadata)
|
| 311 |
+
if not get_pp_group().is_last_rank:
|
| 312 |
+
return IntermediateTensors({
|
| 313 |
+
"hidden_states": hidden_states,
|
| 314 |
+
"residual": residual
|
| 315 |
+
})
|
| 316 |
+
hidden_states = self.norm(hidden_states)
|
| 317 |
+
return hidden_states
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
| 321 |
+
fall_back_to_pt_during_load = False
|
| 322 |
+
|
| 323 |
+
packed_modules_mapping = {
|
| 324 |
+
"qkv_proj": [
|
| 325 |
+
"q_proj",
|
| 326 |
+
"k_proj",
|
| 327 |
+
"v_proj",
|
| 328 |
+
],
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
# LoRA specific attributes
|
| 332 |
+
supported_lora_modules = [
|
| 333 |
+
"qkv_proj",
|
| 334 |
+
"o_proj",
|
| 335 |
+
"embed_tokens",
|
| 336 |
+
"lm_head",
|
| 337 |
+
"layer",
|
| 338 |
+
]
|
| 339 |
+
embedding_modules = {
|
| 340 |
+
"embed_tokens": "input_embeddings",
|
| 341 |
+
"lm_head": "output_embeddings",
|
| 342 |
+
}
|
| 343 |
+
embedding_padding_modules = ["lm_head"]
|
| 344 |
+
|
| 345 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 346 |
+
super().__init__()
|
| 347 |
+
config = vllm_config.model_config.hf_config
|
| 348 |
+
quant_config = vllm_config.quant_config
|
| 349 |
+
lora_config = vllm_config.lora_config
|
| 350 |
+
|
| 351 |
+
self.config = config
|
| 352 |
+
self.lora_config = lora_config
|
| 353 |
+
self.quant_config = quant_config # Required by MixtralForCausalLM
|
| 354 |
+
|
| 355 |
+
self.model = GraniteMoeModel(vllm_config=vllm_config,
|
| 356 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 357 |
+
self.unpadded_vocab_size = config.vocab_size
|
| 358 |
+
if lora_config:
|
| 359 |
+
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
| 360 |
+
self.lm_head = ParallelLMHead(
|
| 361 |
+
self.unpadded_vocab_size,
|
| 362 |
+
config.hidden_size,
|
| 363 |
+
org_num_embeddings=config.vocab_size,
|
| 364 |
+
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
| 365 |
+
# We need bigger padding if using lora for kernel
|
| 366 |
+
# compatibility
|
| 367 |
+
if not lora_config else lora_config.lora_vocab_padding_size,
|
| 368 |
+
quant_config=quant_config,
|
| 369 |
+
)
|
| 370 |
+
if config.tie_word_embeddings:
|
| 371 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 372 |
+
|
| 373 |
+
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
| 374 |
+
config.vocab_size,
|
| 375 |
+
scale=1 /
|
| 376 |
+
self.config.logits_scaling)
|
| 377 |
+
|
| 378 |
+
self.sampler = get_sampler()
|
| 379 |
+
|
| 380 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 381 |
+
return self.model.get_input_embeddings(input_ids)
|
| 382 |
+
|
| 383 |
+
def forward(
|
| 384 |
+
self,
|
| 385 |
+
input_ids: torch.Tensor,
|
| 386 |
+
positions: torch.Tensor,
|
| 387 |
+
kv_caches: List[torch.Tensor],
|
| 388 |
+
attn_metadata: AttentionMetadata,
|
| 389 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 390 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 391 |
+
) -> torch.Tensor:
|
| 392 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 393 |
+
attn_metadata, intermediate_tensors,
|
| 394 |
+
inputs_embeds)
|
| 395 |
+
return hidden_states
|
| 396 |
+
|
| 397 |
+
def compute_logits(
|
| 398 |
+
self, hidden_states: torch.Tensor,
|
| 399 |
+
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
|
| 400 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 401 |
+
sampling_metadata)
|
| 402 |
+
return logits
|
| 403 |
+
|
| 404 |
+
def make_empty_intermediate_tensors(
|
| 405 |
+
self, batch_size: int, dtype: torch.dtype,
|
| 406 |
+
device: torch.device) -> IntermediateTensors:
|
| 407 |
+
return IntermediateTensors({
|
| 408 |
+
"hidden_states":
|
| 409 |
+
torch.zeros((batch_size, self.config.hidden_size),
|
| 410 |
+
dtype=dtype,
|
| 411 |
+
device=device),
|
| 412 |
+
"residual":
|
| 413 |
+
torch.zeros((batch_size, self.config.hidden_size),
|
| 414 |
+
dtype=dtype,
|
| 415 |
+
device=device),
|
| 416 |
+
})
|
| 417 |
+
|
| 418 |
+
def sample(
|
| 419 |
+
self,
|
| 420 |
+
logits: Optional[torch.Tensor],
|
| 421 |
+
sampling_metadata: SamplingMetadata,
|
| 422 |
+
) -> Optional[SamplerOutput]:
|
| 423 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 424 |
+
return next_tokens
|
| 425 |
+
|
| 426 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 427 |
+
torch.Tensor]]) -> Set[str]:
|
| 428 |
+
new_weights = {}
|
| 429 |
+
for n, p in weights:
|
| 430 |
+
if n.endswith('.block_sparse_moe.input_linear.weight'):
|
| 431 |
+
for e in range(p.size(0)):
|
| 432 |
+
w1_name = n.replace(
|
| 433 |
+
'.block_sparse_moe.input_linear.weight',
|
| 434 |
+
f".block_sparse_moe.experts.{e}.w1.weight")
|
| 435 |
+
w3_name = n.replace(
|
| 436 |
+
'.block_sparse_moe.input_linear.weight',
|
| 437 |
+
f".block_sparse_moe.experts.{e}.w3.weight")
|
| 438 |
+
w1_param, w3_param = p[e].chunk(2, dim=0)
|
| 439 |
+
assert w1_name not in new_weights
|
| 440 |
+
assert w3_name not in new_weights
|
| 441 |
+
new_weights[w1_name] = w1_param
|
| 442 |
+
new_weights[w3_name] = w3_param
|
| 443 |
+
elif n.endswith('.block_sparse_moe.output_linear.weight'):
|
| 444 |
+
for e in range(p.size(0)):
|
| 445 |
+
w2_name = n.replace(
|
| 446 |
+
'.block_sparse_moe.output_linear.weight',
|
| 447 |
+
f".block_sparse_moe.experts.{e}.w2.weight")
|
| 448 |
+
w2_param = p[e]
|
| 449 |
+
assert w2_name not in new_weights
|
| 450 |
+
new_weights[w2_name] = w2_param
|
| 451 |
+
elif n.endswith('.block_sparse_moe.router.layer.weight'):
|
| 452 |
+
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
|
| 453 |
+
".block_sparse_moe.gate.weight")
|
| 454 |
+
assert gate_name not in new_weights
|
| 455 |
+
new_weights[gate_name] = p
|
| 456 |
+
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
|
| 457 |
+
pass
|
| 458 |
+
else:
|
| 459 |
+
new_weights[n] = p
|
| 460 |
+
return mixtral.MixtralForCausalLM.load_weights(self,
|
| 461 |
+
new_weights.items())
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/h2ovl.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
|
| 4 |
+
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# H2OVL-Mississippi
|
| 7 |
+
# Copyright (c) 2024 H2O.AI
|
| 8 |
+
# Licensed under Apache 2.0 License [see LICENSE for details]
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
from typing import Mapping, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from transformers import PretrainedConfig
|
| 15 |
+
|
| 16 |
+
from vllm.logger import init_logger
|
| 17 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 18 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 19 |
+
from vllm.multimodal.inputs import MultiModalKwargs
|
| 20 |
+
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
| 21 |
+
MultiModalDataItems)
|
| 22 |
+
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
| 23 |
+
PromptReplacementDetails)
|
| 24 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
| 25 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 26 |
+
|
| 27 |
+
from .intern_vit import InternVisionModel
|
| 28 |
+
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
|
| 29 |
+
BaseInternVLProcessingInfo, BaseInternVLProcessor,
|
| 30 |
+
InternVLChatModel, InternVLDummyInputsBuilder,
|
| 31 |
+
InternVLMultiModalProcessor, build_transform,
|
| 32 |
+
find_closest_aspect_ratio, get_internvl_target_ratios)
|
| 33 |
+
|
| 34 |
+
logger = init_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def resolve_h2ovl_min_max_num(
|
| 38 |
+
*,
|
| 39 |
+
min_dynamic_patch: int,
|
| 40 |
+
max_dynamic_patch: int,
|
| 41 |
+
dynamic_image_size: bool,
|
| 42 |
+
use_thumbnail: bool,
|
| 43 |
+
) -> tuple[int, int]:
|
| 44 |
+
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
| 45 |
+
|
| 46 |
+
if use_thumbnail and max_dynamic_patch != 1:
|
| 47 |
+
max_dynamic_patch += 1
|
| 48 |
+
|
| 49 |
+
return min_dynamic_patch, max_dynamic_patch
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_h2ovl_target_ratios(
|
| 53 |
+
min_num: int,
|
| 54 |
+
max_num: int,
|
| 55 |
+
*,
|
| 56 |
+
prior_aspect_ratio: Optional[tuple[int, int]],
|
| 57 |
+
) -> list[tuple[int, int]]:
|
| 58 |
+
target_ratios = get_internvl_target_ratios(min_num, max_num)
|
| 59 |
+
|
| 60 |
+
# if prior_aspect_ratio is provided, filter the target ratios
|
| 61 |
+
if prior_aspect_ratio is not None:
|
| 62 |
+
target_ratios = [
|
| 63 |
+
ratio for ratio in target_ratios if prior_aspect_ratio[0] %
|
| 64 |
+
ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
return target_ratios
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# modified to include blocks generated in second pass
|
| 71 |
+
def calculate_h2ovl_targets(
|
| 72 |
+
*,
|
| 73 |
+
orig_width: int,
|
| 74 |
+
orig_height: int,
|
| 75 |
+
target_ratios: list[tuple[int, int]],
|
| 76 |
+
image_size: int,
|
| 77 |
+
use_thumbnail: bool,
|
| 78 |
+
) -> tuple[int, int, int, tuple[int, int]]:
|
| 79 |
+
aspect_ratio = orig_width / orig_height
|
| 80 |
+
|
| 81 |
+
# find the closest aspect ratio to the target
|
| 82 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
| 83 |
+
aspect_ratio,
|
| 84 |
+
target_ratios,
|
| 85 |
+
width=orig_width,
|
| 86 |
+
height=orig_height,
|
| 87 |
+
image_size=image_size,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# calculate the target width and height
|
| 91 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 92 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 93 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 94 |
+
|
| 95 |
+
# add thumbnail image if num_blocks != 1
|
| 96 |
+
if use_thumbnail and blocks != 1:
|
| 97 |
+
blocks += 1
|
| 98 |
+
|
| 99 |
+
return blocks, target_width, target_height, target_aspect_ratio
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
| 103 |
+
# refactored to handle prior_aspect_ratio
|
| 104 |
+
def dynamic_preprocess_h2ovl(
|
| 105 |
+
image: Image.Image,
|
| 106 |
+
*,
|
| 107 |
+
target_ratios: list[tuple[int, int]],
|
| 108 |
+
image_size: int,
|
| 109 |
+
use_thumbnail: bool,
|
| 110 |
+
) -> tuple[list[Image.Image], tuple[int, int]]:
|
| 111 |
+
orig_width, orig_height = image.size
|
| 112 |
+
|
| 113 |
+
# calculate the number of blocks without thumbnail
|
| 114 |
+
(
|
| 115 |
+
blocks,
|
| 116 |
+
target_width,
|
| 117 |
+
target_height,
|
| 118 |
+
target_aspect_ratio,
|
| 119 |
+
) = calculate_h2ovl_targets(
|
| 120 |
+
orig_width=orig_width,
|
| 121 |
+
orig_height=orig_height,
|
| 122 |
+
target_ratios=target_ratios,
|
| 123 |
+
image_size=image_size,
|
| 124 |
+
use_thumbnail=False,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# resize the image
|
| 128 |
+
resized_img = image.resize((target_width, target_height))
|
| 129 |
+
processed_images = []
|
| 130 |
+
for i in range(blocks):
|
| 131 |
+
box = (
|
| 132 |
+
(i % (target_width // image_size)) * image_size,
|
| 133 |
+
(i // (target_width // image_size)) * image_size,
|
| 134 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 135 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 136 |
+
)
|
| 137 |
+
# split the image
|
| 138 |
+
split_img = resized_img.crop(box)
|
| 139 |
+
processed_images.append(split_img)
|
| 140 |
+
|
| 141 |
+
assert len(processed_images) == blocks
|
| 142 |
+
|
| 143 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 144 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 145 |
+
processed_images.append(thumbnail_img)
|
| 146 |
+
|
| 147 |
+
return processed_images, target_aspect_ratio
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _preprocess_image(
|
| 151 |
+
image: Image.Image,
|
| 152 |
+
*,
|
| 153 |
+
input_size: int,
|
| 154 |
+
min_num: int,
|
| 155 |
+
max_num: int,
|
| 156 |
+
use_thumbnail: bool,
|
| 157 |
+
prior_aspect_ratio: Optional[tuple[int, int]],
|
| 158 |
+
) -> tuple[torch.Tensor, tuple[int, int]]:
|
| 159 |
+
target_ratios = get_h2ovl_target_ratios(
|
| 160 |
+
min_num,
|
| 161 |
+
max_num,
|
| 162 |
+
prior_aspect_ratio=prior_aspect_ratio,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
transform = build_transform(input_size=input_size)
|
| 166 |
+
images, target_aspect_ratio = dynamic_preprocess_h2ovl(
|
| 167 |
+
image,
|
| 168 |
+
image_size=input_size,
|
| 169 |
+
use_thumbnail=use_thumbnail,
|
| 170 |
+
target_ratios=target_ratios,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
pixel_values = torch.stack([transform(image) for image in images])
|
| 174 |
+
return pixel_values, target_aspect_ratio
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# refactored to use the _preprocess_image function
|
| 178 |
+
def image_to_pixel_values_h2ovl(
|
| 179 |
+
image: Image.Image,
|
| 180 |
+
*,
|
| 181 |
+
input_size: int,
|
| 182 |
+
min_num: int,
|
| 183 |
+
max_num: int,
|
| 184 |
+
use_thumbnail: bool,
|
| 185 |
+
use_msac: bool,
|
| 186 |
+
) -> torch.Tensor:
|
| 187 |
+
# when MSAC is turned on, we need to process the image twice
|
| 188 |
+
if use_msac:
|
| 189 |
+
# first pass
|
| 190 |
+
pixel_values1, aspect_ratio1 = _preprocess_image(
|
| 191 |
+
image,
|
| 192 |
+
input_size=input_size,
|
| 193 |
+
min_num=min_num,
|
| 194 |
+
max_num=max_num,
|
| 195 |
+
use_thumbnail=True,
|
| 196 |
+
prior_aspect_ratio=None,
|
| 197 |
+
)
|
| 198 |
+
# second pass
|
| 199 |
+
pixel_values2, _ = _preprocess_image(
|
| 200 |
+
image,
|
| 201 |
+
input_size=input_size,
|
| 202 |
+
min_num=3, # Hardcoded value
|
| 203 |
+
max_num=max_num,
|
| 204 |
+
use_thumbnail=True,
|
| 205 |
+
prior_aspect_ratio=aspect_ratio1,
|
| 206 |
+
)
|
| 207 |
+
# combine pixel values
|
| 208 |
+
pixel_values = torch.cat(
|
| 209 |
+
[pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)
|
| 210 |
+
|
| 211 |
+
else:
|
| 212 |
+
pixel_values, _ = _preprocess_image(
|
| 213 |
+
image,
|
| 214 |
+
input_size=input_size,
|
| 215 |
+
min_num=min_num,
|
| 216 |
+
max_num=max_num,
|
| 217 |
+
use_thumbnail=use_thumbnail,
|
| 218 |
+
prior_aspect_ratio=None,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
return pixel_values
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class H2OVLProcessor(BaseInternVLProcessor):
|
| 225 |
+
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
config: PretrainedConfig,
|
| 229 |
+
tokenizer: AnyTokenizer,
|
| 230 |
+
*,
|
| 231 |
+
max_dynamic_patch: Optional[int] = None,
|
| 232 |
+
dynamic_image_size: Optional[bool] = None,
|
| 233 |
+
use_msac: Optional[bool] = None,
|
| 234 |
+
) -> None:
|
| 235 |
+
super().__init__(
|
| 236 |
+
config,
|
| 237 |
+
tokenizer,
|
| 238 |
+
max_dynamic_patch=max_dynamic_patch,
|
| 239 |
+
dynamic_image_size=dynamic_image_size,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
if use_msac is None:
|
| 243 |
+
use_msac = config.use_msac
|
| 244 |
+
assert isinstance(use_msac, bool)
|
| 245 |
+
|
| 246 |
+
self.use_msac = use_msac
|
| 247 |
+
|
| 248 |
+
@property
|
| 249 |
+
def image_token_id(self) -> int:
|
| 250 |
+
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
| 251 |
+
|
| 252 |
+
def get_image_repl_features(
|
| 253 |
+
self,
|
| 254 |
+
feature_size: int,
|
| 255 |
+
num_patches: Optional[int],
|
| 256 |
+
) -> str:
|
| 257 |
+
return IMG_CONTEXT * feature_size
|
| 258 |
+
|
| 259 |
+
def get_image_repl_full(
|
| 260 |
+
self,
|
| 261 |
+
feature_size: int,
|
| 262 |
+
num_patches: Optional[int],
|
| 263 |
+
) -> str:
|
| 264 |
+
features = self.get_image_repl_features(feature_size, num_patches)
|
| 265 |
+
return IMG_START + features + IMG_END
|
| 266 |
+
|
| 267 |
+
def resolve_min_max_num(
|
| 268 |
+
self,
|
| 269 |
+
*,
|
| 270 |
+
max_dynamic_patch: Optional[int] = None,
|
| 271 |
+
dynamic_image_size: Optional[bool] = None,
|
| 272 |
+
use_thumbnail: Optional[bool] = None,
|
| 273 |
+
) -> tuple[int, int]:
|
| 274 |
+
min_dynamic_patch = self.min_dynamic_patch
|
| 275 |
+
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
|
| 276 |
+
is None else max_dynamic_patch)
|
| 277 |
+
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
|
| 278 |
+
is None else dynamic_image_size)
|
| 279 |
+
use_thumbnail = (self.use_thumbnail
|
| 280 |
+
if use_thumbnail is None else use_thumbnail)
|
| 281 |
+
|
| 282 |
+
return resolve_h2ovl_min_max_num(
|
| 283 |
+
min_dynamic_patch=min_dynamic_patch,
|
| 284 |
+
max_dynamic_patch=max_dynamic_patch,
|
| 285 |
+
dynamic_image_size=dynamic_image_size,
|
| 286 |
+
use_thumbnail=use_thumbnail,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def resolve_target_ratios(
|
| 290 |
+
self,
|
| 291 |
+
*,
|
| 292 |
+
max_dynamic_patch: Optional[int] = None,
|
| 293 |
+
dynamic_image_size: Optional[bool] = None,
|
| 294 |
+
use_thumbnail: Optional[bool] = None,
|
| 295 |
+
prior_aspect_ratio: Optional[tuple[int, int]] = None,
|
| 296 |
+
) -> list[tuple[int, int]]:
|
| 297 |
+
min_num, max_num = self.resolve_min_max_num(
|
| 298 |
+
max_dynamic_patch=max_dynamic_patch,
|
| 299 |
+
dynamic_image_size=dynamic_image_size,
|
| 300 |
+
use_thumbnail=use_thumbnail,
|
| 301 |
+
)
|
| 302 |
+
if prior_aspect_ratio: # hardcoded value for second pass of use_msac
|
| 303 |
+
min_num = 3
|
| 304 |
+
|
| 305 |
+
return get_h2ovl_target_ratios(
|
| 306 |
+
min_num,
|
| 307 |
+
max_num,
|
| 308 |
+
prior_aspect_ratio=prior_aspect_ratio,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
def get_num_image_tokens(
|
| 312 |
+
self,
|
| 313 |
+
*,
|
| 314 |
+
image_width: int,
|
| 315 |
+
image_height: int,
|
| 316 |
+
use_msac: Optional[bool] = None,
|
| 317 |
+
) -> int:
|
| 318 |
+
use_msac = (self.use_msac if use_msac is None else use_msac)
|
| 319 |
+
|
| 320 |
+
use_thumbnail = self.use_thumbnail
|
| 321 |
+
|
| 322 |
+
if use_msac:
|
| 323 |
+
target_ratios_1 = self.resolve_target_ratios(
|
| 324 |
+
use_thumbnail=False, # Applied in calculate_targets
|
| 325 |
+
)
|
| 326 |
+
num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
|
| 327 |
+
orig_width=image_width,
|
| 328 |
+
orig_height=image_height,
|
| 329 |
+
image_size=self.image_size,
|
| 330 |
+
target_ratios=target_ratios_1,
|
| 331 |
+
use_thumbnail=True,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
target_ratios_2 = self.resolve_target_ratios(
|
| 335 |
+
use_thumbnail=False, # Applied in calculate_targets
|
| 336 |
+
prior_aspect_ratio=aspect_ratio_1,
|
| 337 |
+
)
|
| 338 |
+
num_patches_2, _, _, _ = calculate_h2ovl_targets(
|
| 339 |
+
orig_width=image_width,
|
| 340 |
+
orig_height=image_height,
|
| 341 |
+
image_size=self.image_size,
|
| 342 |
+
target_ratios=target_ratios_2,
|
| 343 |
+
use_thumbnail=True,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
num_patches = num_patches_1 + num_patches_2 - 1
|
| 347 |
+
else:
|
| 348 |
+
target_ratios = self.resolve_target_ratios(
|
| 349 |
+
use_thumbnail=False, # Applied in calculate_targets
|
| 350 |
+
)
|
| 351 |
+
num_patches, _, _, _ = calculate_h2ovl_targets(
|
| 352 |
+
orig_width=image_width,
|
| 353 |
+
orig_height=image_height,
|
| 354 |
+
image_size=self.image_size,
|
| 355 |
+
target_ratios=target_ratios,
|
| 356 |
+
use_thumbnail=use_thumbnail,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
return num_patches * self.num_image_token
|
| 360 |
+
|
| 361 |
+
def _images_to_pixel_values_lst(
|
| 362 |
+
self,
|
| 363 |
+
images: list[Image.Image],
|
| 364 |
+
max_dynamic_patch: Optional[int] = None,
|
| 365 |
+
dynamic_image_size: Optional[bool] = None,
|
| 366 |
+
) -> list[torch.Tensor]:
|
| 367 |
+
use_msac = self.use_msac if len(images) == 1 else False
|
| 368 |
+
|
| 369 |
+
min_num, max_num = self.resolve_min_max_num(
|
| 370 |
+
max_dynamic_patch=max_dynamic_patch,
|
| 371 |
+
dynamic_image_size=dynamic_image_size,
|
| 372 |
+
use_thumbnail=False, # Applied in image_to_pixel_values
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
return [
|
| 376 |
+
image_to_pixel_values_h2ovl(
|
| 377 |
+
image,
|
| 378 |
+
input_size=self.image_size,
|
| 379 |
+
min_num=min_num,
|
| 380 |
+
max_num=max_num,
|
| 381 |
+
use_thumbnail=self.use_thumbnail,
|
| 382 |
+
use_msac=use_msac,
|
| 383 |
+
) for image in images
|
| 384 |
+
]
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
|
| 388 |
+
|
| 389 |
+
def get_hf_processor(
|
| 390 |
+
self,
|
| 391 |
+
*,
|
| 392 |
+
max_dynamic_patch: Optional[int] = None,
|
| 393 |
+
dynamic_image_size: Optional[bool] = None,
|
| 394 |
+
) -> H2OVLProcessor:
|
| 395 |
+
return H2OVLProcessor(
|
| 396 |
+
self.get_hf_config(),
|
| 397 |
+
self.get_tokenizer(),
|
| 398 |
+
max_dynamic_patch=max_dynamic_patch,
|
| 399 |
+
dynamic_image_size=dynamic_image_size,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
def get_mm_max_tokens_per_item(
|
| 403 |
+
self,
|
| 404 |
+
seq_len: int,
|
| 405 |
+
mm_counts: Mapping[str, int],
|
| 406 |
+
) -> Mapping[str, int]:
|
| 407 |
+
max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
|
| 408 |
+
if mm_counts.get("image", 0) <= 1:
|
| 409 |
+
max_tokens_per_image = max_tokens_one_image
|
| 410 |
+
else:
|
| 411 |
+
max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
|
| 412 |
+
|
| 413 |
+
return {"image": max_tokens_per_image}
|
| 414 |
+
|
| 415 |
+
def get_num_image_tokens(
|
| 416 |
+
self,
|
| 417 |
+
*,
|
| 418 |
+
image_width: int,
|
| 419 |
+
image_height: int,
|
| 420 |
+
processor: Optional[H2OVLProcessor],
|
| 421 |
+
use_msac: Optional[bool] = None,
|
| 422 |
+
) -> int:
|
| 423 |
+
if processor is None:
|
| 424 |
+
processor = self.get_hf_processor()
|
| 425 |
+
|
| 426 |
+
return processor.get_num_image_tokens(
|
| 427 |
+
image_width=image_width,
|
| 428 |
+
image_height=image_height,
|
| 429 |
+
use_msac=use_msac,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
|
| 433 |
+
target_width, target_height = self.get_image_size_with_most_features()
|
| 434 |
+
|
| 435 |
+
return self.get_num_image_tokens(
|
| 436 |
+
image_width=target_width,
|
| 437 |
+
image_height=target_height,
|
| 438 |
+
processor=None,
|
| 439 |
+
use_msac=use_msac,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
|
| 444 |
+
):
|
| 445 |
+
|
| 446 |
+
def __init__(self,
|
| 447 |
+
info: H2OVLProcessingInfo,
|
| 448 |
+
dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
|
| 449 |
+
*,
|
| 450 |
+
cache: Optional[ProcessingCache] = None,
|
| 451 |
+
enable_sanity_checks: bool = True) -> None:
|
| 452 |
+
super().__init__(
|
| 453 |
+
info,
|
| 454 |
+
dummy_inputs,
|
| 455 |
+
cache=cache,
|
| 456 |
+
enable_sanity_checks=enable_sanity_checks,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
if self.cache is not None:
|
| 460 |
+
# The processor output depends on the number of images passed,
|
| 461 |
+
# making it incompatible with processing cache which is supposed
|
| 462 |
+
# to be invariant of how many images are passed per prompt
|
| 463 |
+
self.cache = None
|
| 464 |
+
logger.warning_once(
|
| 465 |
+
f"{type(self).__name__} does not support processing cache.")
|
| 466 |
+
|
| 467 |
+
def _get_prompt_replacements(
|
| 468 |
+
self,
|
| 469 |
+
mm_items: MultiModalDataItems,
|
| 470 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 471 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 472 |
+
) -> list[PromptReplacement]:
|
| 473 |
+
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
| 474 |
+
|
| 475 |
+
if "image_num_patches" in out_mm_kwargs:
|
| 476 |
+
image_num_patches = out_mm_kwargs["image_num_patches"]
|
| 477 |
+
assert isinstance(image_num_patches, torch.Tensor)
|
| 478 |
+
image_num_patches = image_num_patches.tolist()
|
| 479 |
+
elif "image_embeds" in out_mm_kwargs:
|
| 480 |
+
# TODO: Use image size information in dictionary embedding inputs
|
| 481 |
+
# to compute num_patches (similar to Qwen2-VL)
|
| 482 |
+
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
|
| 483 |
+
else:
|
| 484 |
+
image_num_patches = []
|
| 485 |
+
|
| 486 |
+
num_images = len(image_num_patches)
|
| 487 |
+
|
| 488 |
+
def get_replacement_internvl(item_idx: int):
|
| 489 |
+
images = mm_items.get_items(
|
| 490 |
+
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
| 491 |
+
|
| 492 |
+
if isinstance(images, ImageEmbeddingItems):
|
| 493 |
+
feature_size = images.get_feature_size(item_idx)
|
| 494 |
+
else:
|
| 495 |
+
image_size = images.get_image_size(item_idx)
|
| 496 |
+
feature_size = self.info.get_num_image_tokens(
|
| 497 |
+
image_width=image_size.width,
|
| 498 |
+
image_height=image_size.height,
|
| 499 |
+
processor=hf_processor,
|
| 500 |
+
use_msac=None if num_images == 1 else False,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
num_patches = image_num_patches[item_idx]
|
| 504 |
+
if num_patches is not None:
|
| 505 |
+
assert isinstance(num_patches, int)
|
| 506 |
+
|
| 507 |
+
return PromptReplacementDetails(
|
| 508 |
+
full=hf_processor.get_image_repl_full(feature_size,
|
| 509 |
+
num_patches),
|
| 510 |
+
features=hf_processor.get_image_repl_features(
|
| 511 |
+
feature_size, num_patches),
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
return [
|
| 515 |
+
PromptReplacement(
|
| 516 |
+
modality="image",
|
| 517 |
+
target="<image>",
|
| 518 |
+
replacement=get_replacement_internvl,
|
| 519 |
+
)
|
| 520 |
+
]
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@MULTIMODAL_REGISTRY.register_processor(
|
| 524 |
+
H2OVLMultiModalProcessor,
|
| 525 |
+
info=H2OVLProcessingInfo,
|
| 526 |
+
dummy_inputs=InternVLDummyInputsBuilder)
|
| 527 |
+
class H2OVLChatModel(InternVLChatModel):
|
| 528 |
+
|
| 529 |
+
def _init_vision_model(
|
| 530 |
+
self,
|
| 531 |
+
config: PretrainedConfig,
|
| 532 |
+
quant_config: Optional[QuantizationConfig],
|
| 533 |
+
*,
|
| 534 |
+
is_mono: bool,
|
| 535 |
+
prefix: str,
|
| 536 |
+
):
|
| 537 |
+
if not is_mono:
|
| 538 |
+
vision_feature_layer = config.select_layer
|
| 539 |
+
if vision_feature_layer < 0:
|
| 540 |
+
num_hidden_layers = (config.vision_config.num_hidden_layers +
|
| 541 |
+
vision_feature_layer + 1)
|
| 542 |
+
else:
|
| 543 |
+
num_hidden_layers = vision_feature_layer + 1
|
| 544 |
+
|
| 545 |
+
return InternVisionModel(
|
| 546 |
+
config.vision_config,
|
| 547 |
+
quant_config=quant_config,
|
| 548 |
+
num_hidden_layers_override=num_hidden_layers,
|
| 549 |
+
prefix=prefix,
|
| 550 |
+
)
|
| 551 |
+
else:
|
| 552 |
+
msg = "Monolith mode is not applicable to H2OVL"
|
| 553 |
+
raise NotImplementedError(msg)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics3.py
ADDED
|
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
| 20 |
+
Tuple, TypedDict, Union)
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
|
| 26 |
+
Idefics3Processor)
|
| 27 |
+
|
| 28 |
+
from vllm.attention import AttentionMetadata
|
| 29 |
+
from vllm.config import VllmConfig
|
| 30 |
+
from vllm.logger import init_logger
|
| 31 |
+
from vllm.model_executor.layers.linear import ReplicatedLinear
|
| 32 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 33 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 34 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 35 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
| 36 |
+
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
| 37 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 38 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
| 39 |
+
from vllm.multimodal.inputs import NestedTensors
|
| 40 |
+
from vllm.multimodal.parse import ImageProcessorItems
|
| 41 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 42 |
+
BaseProcessingInfo,
|
| 43 |
+
MultiModalDataItems,
|
| 44 |
+
MultiModalFieldConfig,
|
| 45 |
+
PromptReplacement)
|
| 46 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 47 |
+
from vllm.sequence import IntermediateTensors
|
| 48 |
+
|
| 49 |
+
# yapf: disable
|
| 50 |
+
from .idefics2_vision_model import (
|
| 51 |
+
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
| 52 |
+
# yapf: enable
|
| 53 |
+
from .interfaces import SupportsLoRA, SupportsMultiModal
|
| 54 |
+
from .llama import LlamaModel
|
| 55 |
+
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
| 56 |
+
merge_multimodal_embeddings)
|
| 57 |
+
|
| 58 |
+
logger = init_logger(__name__)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Idefics3ImagePixelInputs(TypedDict):
|
| 62 |
+
type: Literal["pixel_values"]
|
| 63 |
+
data: torch.Tensor
|
| 64 |
+
"""
|
| 65 |
+
Shape: `(batch_size * num_images * num_patches,
|
| 66 |
+
num_channels, height, width)`
|
| 67 |
+
"""
|
| 68 |
+
pixel_attention_mask: Optional[torch.BoolTensor]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Idefics3ImageEmbeddingInputs(TypedDict):
|
| 72 |
+
type: Literal["image_embeds"]
|
| 73 |
+
data: torch.Tensor
|
| 74 |
+
"""
|
| 75 |
+
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
| 76 |
+
`hidden_size` must match the hidden size of language model backbone.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Idefics3ProcessingInfo(BaseProcessingInfo):
|
| 84 |
+
|
| 85 |
+
def get_hf_processor(
|
| 86 |
+
self,
|
| 87 |
+
*,
|
| 88 |
+
size: Optional[Dict[str, int]] = None) -> Idefics3Processor:
|
| 89 |
+
if size is not None:
|
| 90 |
+
return self.ctx.get_hf_processor(Idefics3Processor, size=size)
|
| 91 |
+
|
| 92 |
+
return self.ctx.get_hf_processor(Idefics3Processor)
|
| 93 |
+
|
| 94 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 95 |
+
return {"image": None}
|
| 96 |
+
|
| 97 |
+
def get_mm_max_tokens_per_item(
|
| 98 |
+
self,
|
| 99 |
+
seq_len: int,
|
| 100 |
+
mm_counts: Mapping[str, int],
|
| 101 |
+
) -> Mapping[str, int]:
|
| 102 |
+
hf_processor = self.get_hf_processor()
|
| 103 |
+
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
| 104 |
+
grid_w, grid_h = self._get_image_feature_grid_size(
|
| 105 |
+
image_width=image_processor.size['longest_edge'],
|
| 106 |
+
image_height=image_processor.size['longest_edge'],
|
| 107 |
+
)
|
| 108 |
+
num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
|
| 109 |
+
# Calculate Non-image-token length
|
| 110 |
+
# NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
|
| 111 |
+
# but not for Idefic3, so we need to tokenize them to get actual length.
|
| 112 |
+
tokenizer = self.get_tokenizer()
|
| 113 |
+
tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
|
| 114 |
+
glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
|
| 115 |
+
# linebreak and <fake_token_around_image> always cost 1 token
|
| 116 |
+
fake_token_len = lb_len = 1
|
| 117 |
+
non_image_token = (grid_w * grid_h) * (
|
| 118 |
+
tile_token_len + fake_token_len) + glob_token_len + (
|
| 119 |
+
grid_h + 1) * lb_len + fake_token_len
|
| 120 |
+
return {"image": num_image_token + non_image_token}
|
| 121 |
+
|
| 122 |
+
def _resize_output_size(self,
|
| 123 |
+
*,
|
| 124 |
+
height: int,
|
| 125 |
+
width: int,
|
| 126 |
+
max_len: Optional[int] = None,
|
| 127 |
+
min_len: Optional[int] = 1,
|
| 128 |
+
max_size: Optional[int] = None) -> tuple[int, int]:
|
| 129 |
+
# Set default value for max_len if not provided
|
| 130 |
+
max_len = max(height, width) if max_len is None else max_len
|
| 131 |
+
aspect_ratio = width / height
|
| 132 |
+
|
| 133 |
+
# Handle the maximum size constraint
|
| 134 |
+
if max_size is not None:
|
| 135 |
+
max_len = min(max_len, max_size)
|
| 136 |
+
|
| 137 |
+
# Adjust dimensions according to the aspect ratio
|
| 138 |
+
if width >= height:
|
| 139 |
+
width = max_len
|
| 140 |
+
height = int(width / aspect_ratio)
|
| 141 |
+
else:
|
| 142 |
+
height = max_len
|
| 143 |
+
width = int(height * aspect_ratio)
|
| 144 |
+
|
| 145 |
+
# Ensure both width and height are even (if needed)
|
| 146 |
+
height += height % 2
|
| 147 |
+
width += width % 2
|
| 148 |
+
|
| 149 |
+
# Ensure dimensions are not smaller than the minimum length
|
| 150 |
+
height = max(height, min_len)
|
| 151 |
+
width = max(width, min_len)
|
| 152 |
+
|
| 153 |
+
return height, width
|
| 154 |
+
|
| 155 |
+
def _get_resize_output_image_size(
|
| 156 |
+
self,
|
| 157 |
+
*,
|
| 158 |
+
image_width: int,
|
| 159 |
+
image_height: int,
|
| 160 |
+
resolution_max_side: int,
|
| 161 |
+
) -> tuple[int, int]:
|
| 162 |
+
hf_processor = self.get_hf_processor()
|
| 163 |
+
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
| 164 |
+
max_image_size = image_processor.size['longest_edge']
|
| 165 |
+
if resolution_max_side > max_image_size:
|
| 166 |
+
raise ValueError(
|
| 167 |
+
"`resolution_max_side` cannot be larger than `max_image_size`")
|
| 168 |
+
|
| 169 |
+
height, width = image_height, image_width
|
| 170 |
+
|
| 171 |
+
# Find the output size, when rescaling the longest edge to max_len and
|
| 172 |
+
# preserving the aspect ratio
|
| 173 |
+
height, width = self._resize_output_size(height=height,
|
| 174 |
+
width=width,
|
| 175 |
+
max_len=resolution_max_side)
|
| 176 |
+
return height, width
|
| 177 |
+
|
| 178 |
+
def _get_image_feature_grid_size(
|
| 179 |
+
self,
|
| 180 |
+
*,
|
| 181 |
+
image_width: int,
|
| 182 |
+
image_height: int,
|
| 183 |
+
size: Optional[dict[str, object]] = None,
|
| 184 |
+
) -> tuple[int, int]:
|
| 185 |
+
hf_processor = self.get_hf_processor(size=size)
|
| 186 |
+
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
| 187 |
+
max_image_size = image_processor.max_image_size['longest_edge']
|
| 188 |
+
size = image_processor.size['longest_edge']
|
| 189 |
+
assert size % max_image_size == 0, (
|
| 190 |
+
"`longest_edge` in image_processor's `size` must be divisible by "
|
| 191 |
+
"`longest_edge` in `max_image_size`, this may be caused by "
|
| 192 |
+
"incorrect mm_kwargs override.")
|
| 193 |
+
|
| 194 |
+
resized_height, resized_width = self._get_resize_output_image_size(
|
| 195 |
+
image_width=image_width,
|
| 196 |
+
image_height=image_height,
|
| 197 |
+
resolution_max_side=size,
|
| 198 |
+
)
|
| 199 |
+
if resized_height > max_image_size or resized_width > max_image_size:
|
| 200 |
+
grid_h = math.ceil(resized_height / max_image_size)
|
| 201 |
+
grid_w = math.ceil(resized_width / max_image_size)
|
| 202 |
+
else:
|
| 203 |
+
grid_h = grid_w = 0
|
| 204 |
+
return grid_w, grid_h
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
| 208 |
+
):
|
| 209 |
+
|
| 210 |
+
def get_dummy_processor_inputs(
|
| 211 |
+
self,
|
| 212 |
+
seq_len: int,
|
| 213 |
+
mm_counts: Mapping[str, int],
|
| 214 |
+
) -> ProcessorInputs:
|
| 215 |
+
num_images = mm_counts.get("image", 0)
|
| 216 |
+
hf_processor = self.info.get_hf_processor()
|
| 217 |
+
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
| 218 |
+
longest_edge = image_processor.max_image_size['longest_edge']
|
| 219 |
+
image_token: str = hf_processor.image_token.content
|
| 220 |
+
|
| 221 |
+
mm_data = {
|
| 222 |
+
"image":
|
| 223 |
+
self._get_dummy_images(width=longest_edge,
|
| 224 |
+
height=longest_edge,
|
| 225 |
+
num_images=num_images)
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
return ProcessorInputs(
|
| 229 |
+
prompt_text=image_token * num_images,
|
| 230 |
+
mm_data=mm_data,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Idefics3MultimodalProcessor(
|
| 235 |
+
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
|
| 236 |
+
|
| 237 |
+
def _call_hf_processor(
|
| 238 |
+
self,
|
| 239 |
+
prompt: str,
|
| 240 |
+
mm_data: Mapping[str, object],
|
| 241 |
+
mm_kwargs: Mapping[str, object],
|
| 242 |
+
) -> BatchFeature:
|
| 243 |
+
if mm_data:
|
| 244 |
+
processed_outputs = super()._call_hf_processor(
|
| 245 |
+
prompt, mm_data, mm_kwargs)
|
| 246 |
+
image_grids = [
|
| 247 |
+
self.info._get_image_feature_grid_size(
|
| 248 |
+
image_width=img.width,
|
| 249 |
+
image_height=img.height,
|
| 250 |
+
**mm_kwargs,
|
| 251 |
+
) for img in mm_data["images"]
|
| 252 |
+
]
|
| 253 |
+
image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
|
| 254 |
+
for key in ("pixel_values", "pixel_attention_mask"):
|
| 255 |
+
data = processed_outputs.pop(key)
|
| 256 |
+
data = data.flatten(0, 1).split(image_patches)
|
| 257 |
+
processed_outputs[key] = data
|
| 258 |
+
else:
|
| 259 |
+
tokenizer = self.info.get_tokenizer()
|
| 260 |
+
processed_outputs = tokenizer(prompt,
|
| 261 |
+
add_special_tokens=True,
|
| 262 |
+
return_tensors="pt")
|
| 263 |
+
return processed_outputs
|
| 264 |
+
|
| 265 |
+
def _get_mm_fields_config(
|
| 266 |
+
self,
|
| 267 |
+
hf_inputs: BatchFeature,
|
| 268 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 269 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 270 |
+
return dict(
|
| 271 |
+
pixel_values=MultiModalFieldConfig.batched("image"),
|
| 272 |
+
pixel_attention_mask=MultiModalFieldConfig.batched("image"),
|
| 273 |
+
image_embeds=MultiModalFieldConfig.batched("image"),
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
def _get_prompt_replacements(
|
| 277 |
+
self,
|
| 278 |
+
mm_items: MultiModalDataItems,
|
| 279 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 280 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 281 |
+
) -> list[PromptReplacement]:
|
| 282 |
+
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
| 283 |
+
|
| 284 |
+
image_token = hf_processor.image_token.content
|
| 285 |
+
fake_image_token = hf_processor.fake_image_token.content
|
| 286 |
+
global_img_token = hf_processor.global_image_tag
|
| 287 |
+
image_seq_len = hf_processor.image_seq_len
|
| 288 |
+
grid_placeholder = "<row_{n_h}_col_{n_w}>"
|
| 289 |
+
|
| 290 |
+
p_img = image_token * image_seq_len
|
| 291 |
+
global_img_placeholder = fake_image_token + global_img_token + p_img
|
| 292 |
+
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
|
| 293 |
+
|
| 294 |
+
def get_replacement_idefics3(item_idx: int) -> str:
|
| 295 |
+
images = mm_items.get_items("image", ImageProcessorItems)
|
| 296 |
+
|
| 297 |
+
image_size = images.get_image_size(item_idx)
|
| 298 |
+
grid_w, grid_h = self.info._get_image_feature_grid_size(
|
| 299 |
+
image_width=image_size.width,
|
| 300 |
+
image_height=image_size.height,
|
| 301 |
+
**hf_processor_mm_kwargs,
|
| 302 |
+
)
|
| 303 |
+
if grid_w == 0 and grid_h == 0:
|
| 304 |
+
image_placeholder = global_img_placeholder
|
| 305 |
+
else:
|
| 306 |
+
tiles_placeholder = list[str]()
|
| 307 |
+
for i in range(grid_h):
|
| 308 |
+
for j in range(grid_w):
|
| 309 |
+
placeholder_per_tile = tile_img_placeholder.format(
|
| 310 |
+
n_h=i + 1, n_w=j + 1)
|
| 311 |
+
tiles_placeholder.append(placeholder_per_tile)
|
| 312 |
+
# Add line break if it is the last tile in the row
|
| 313 |
+
if j == grid_w - 1:
|
| 314 |
+
tiles_placeholder.append("\n")
|
| 315 |
+
|
| 316 |
+
image_placeholder = "".join(
|
| 317 |
+
[*tiles_placeholder, "\n", global_img_placeholder])
|
| 318 |
+
return image_placeholder + fake_image_token
|
| 319 |
+
|
| 320 |
+
return [
|
| 321 |
+
PromptReplacement(
|
| 322 |
+
modality="image",
|
| 323 |
+
target=image_token,
|
| 324 |
+
replacement=get_replacement_idefics3,
|
| 325 |
+
)
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class Idefics3SimpleMLP(nn.Module):
|
| 330 |
+
|
| 331 |
+
def __init__(
|
| 332 |
+
self,
|
| 333 |
+
config: Idefics3Config,
|
| 334 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 335 |
+
prefix: str = "",
|
| 336 |
+
):
|
| 337 |
+
super().__init__()
|
| 338 |
+
input_size = config.vision_config.hidden_size * (config.scale_factor**
|
| 339 |
+
2)
|
| 340 |
+
output_size = config.text_config.hidden_size
|
| 341 |
+
self.proj = ReplicatedLinear(
|
| 342 |
+
input_size,
|
| 343 |
+
output_size,
|
| 344 |
+
bias=False,
|
| 345 |
+
quant_config=quant_config,
|
| 346 |
+
prefix=maybe_prefix(prefix, "proj"),
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 350 |
+
out, _ = self.proj(x)
|
| 351 |
+
return out
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class Idefics3Connector(nn.Module):
|
| 355 |
+
|
| 356 |
+
def __init__(
|
| 357 |
+
self,
|
| 358 |
+
config: Idefics3Config,
|
| 359 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 360 |
+
prefix: str = "",
|
| 361 |
+
):
|
| 362 |
+
super().__init__()
|
| 363 |
+
self.scale_factor = config.scale_factor
|
| 364 |
+
self.modality_projection = Idefics3SimpleMLP(
|
| 365 |
+
config,
|
| 366 |
+
quant_config,
|
| 367 |
+
prefix=maybe_prefix(prefix, "modality_projection"),
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
def pixel_shuffle(self,
|
| 371 |
+
x: torch.Tensor,
|
| 372 |
+
scale_factor: int = 2) -> torch.Tensor:
|
| 373 |
+
bsz, seq, embed_dim = x.size()
|
| 374 |
+
height = width = int(seq**0.5)
|
| 375 |
+
x = x.view(bsz, height, width, embed_dim)
|
| 376 |
+
x = x.view(bsz, height, int(width / scale_factor),
|
| 377 |
+
embed_dim * scale_factor)
|
| 378 |
+
x = x.permute(0, 2, 1, 3)
|
| 379 |
+
x = x.reshape(
|
| 380 |
+
bsz,
|
| 381 |
+
int(width / scale_factor),
|
| 382 |
+
int(height / scale_factor),
|
| 383 |
+
embed_dim * (scale_factor**2),
|
| 384 |
+
)
|
| 385 |
+
x = x.permute(0, 2, 1, 3)
|
| 386 |
+
x = x.reshape(bsz, int(seq / (scale_factor**2)),
|
| 387 |
+
embed_dim * (scale_factor**2))
|
| 388 |
+
return x
|
| 389 |
+
|
| 390 |
+
def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
|
| 391 |
+
image_hidden_states = self.pixel_shuffle(image_hidden_states,
|
| 392 |
+
self.scale_factor)
|
| 393 |
+
image_hidden_states = self.modality_projection(image_hidden_states)
|
| 394 |
+
return image_hidden_states
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class Idefics3Model(nn.Module):
|
| 398 |
+
|
| 399 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 400 |
+
super().__init__()
|
| 401 |
+
|
| 402 |
+
config: Idefics3Config = vllm_config.model_config.hf_config
|
| 403 |
+
quant_config = vllm_config.quant_config
|
| 404 |
+
|
| 405 |
+
self.config = config
|
| 406 |
+
self.padding_idx = self.config.text_config.pad_token_id
|
| 407 |
+
self.vocab_size = self.config.text_config.vocab_size
|
| 408 |
+
self.vision_model = Idefics3VisionTransformer(
|
| 409 |
+
config.vision_config,
|
| 410 |
+
quant_config=quant_config,
|
| 411 |
+
prefix=maybe_prefix(prefix, "vision_model"))
|
| 412 |
+
self.connector = Idefics3Connector(
|
| 413 |
+
config,
|
| 414 |
+
quant_config,
|
| 415 |
+
prefix=maybe_prefix(prefix, "connector"),
|
| 416 |
+
)
|
| 417 |
+
self.text_model = LlamaModel(
|
| 418 |
+
vllm_config=vllm_config.with_hf_config(config.text_config),
|
| 419 |
+
prefix=maybe_prefix(prefix, "text_model"),
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
self.image_seq_len = int(
|
| 423 |
+
((config.vision_config.image_size //
|
| 424 |
+
config.vision_config.patch_size)**2) / (config.scale_factor**2))
|
| 425 |
+
self.image_token_id = self.config.image_token_id
|
| 426 |
+
|
| 427 |
+
def _validate_pixel_values(
|
| 428 |
+
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
| 429 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
| 430 |
+
|
| 431 |
+
h = w = self.config.vision_config.image_size
|
| 432 |
+
expected_dims = (3, h, w)
|
| 433 |
+
|
| 434 |
+
def _validate_shape(d: torch.Tensor):
|
| 435 |
+
actual_dims = tuple(d.shape[1:])
|
| 436 |
+
|
| 437 |
+
if actual_dims != expected_dims:
|
| 438 |
+
expected_expr = ("num_patches", *map(str, expected_dims))
|
| 439 |
+
raise ValueError(
|
| 440 |
+
"The expected shape of pixel values per image per batch "
|
| 441 |
+
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
| 442 |
+
|
| 443 |
+
for d in data:
|
| 444 |
+
_validate_shape(d)
|
| 445 |
+
|
| 446 |
+
return data
|
| 447 |
+
|
| 448 |
+
def _parse_and_validate_image_input(
|
| 449 |
+
self, **kwargs: object) -> Optional[ImageInputs]:
|
| 450 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 451 |
+
image_embeds = kwargs.pop("image_embeds", None)
|
| 452 |
+
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
|
| 453 |
+
|
| 454 |
+
if pixel_values is None and image_embeds is None:
|
| 455 |
+
return None
|
| 456 |
+
|
| 457 |
+
if image_embeds is not None:
|
| 458 |
+
if not isinstance(image_embeds, (torch.Tensor, list)):
|
| 459 |
+
raise ValueError("Incorrect type of image embeddings. "
|
| 460 |
+
f"Got type: {type(image_embeds)}")
|
| 461 |
+
|
| 462 |
+
return Idefics3ImageEmbeddingInputs(
|
| 463 |
+
type="image_embeds",
|
| 464 |
+
data=flatten_bn(image_embeds, concat=True),
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
if pixel_values is not None:
|
| 468 |
+
if not isinstance(pixel_values, (torch.Tensor, list)):
|
| 469 |
+
raise ValueError("Incorrect type of pixel values. "
|
| 470 |
+
f"Got type: {type(pixel_values)}")
|
| 471 |
+
|
| 472 |
+
if isinstance(pixel_values, list):
|
| 473 |
+
pixel_values = torch.cat(pixel_values, dim=1)
|
| 474 |
+
pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
|
| 475 |
+
else:
|
| 476 |
+
pixel_values = flatten_bn(pixel_values)
|
| 477 |
+
pixel_attention_mask = flatten_bn(pixel_attention_mask)
|
| 478 |
+
|
| 479 |
+
return Idefics3ImagePixelInputs(
|
| 480 |
+
type="pixel_values",
|
| 481 |
+
data=self._validate_pixel_values(pixel_values),
|
| 482 |
+
pixel_attention_mask=pixel_attention_mask)
|
| 483 |
+
|
| 484 |
+
raise AssertionError("This line should be unreachable.")
|
| 485 |
+
|
| 486 |
+
def _image_pixels_to_features(
|
| 487 |
+
self,
|
| 488 |
+
pixel_values: torch.Tensor,
|
| 489 |
+
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
| 490 |
+
) -> NestedTensors:
|
| 491 |
+
# NOTE: we skip the step to select the vision feature layer since
|
| 492 |
+
# this is already done inside the vision tower
|
| 493 |
+
num_patches = [x.size(0) for x in pixel_values]
|
| 494 |
+
pixel_values = pixel_values.to(
|
| 495 |
+
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
|
| 496 |
+
) # fp16 compatibility
|
| 497 |
+
|
| 498 |
+
# Remove padding images - padding images are full 0.
|
| 499 |
+
nb_values_per_image = pixel_values.shape[1:].numel()
|
| 500 |
+
real_images_inds = (pixel_values == 0.0).sum(
|
| 501 |
+
dim=(-1, -2, -3)) != nb_values_per_image
|
| 502 |
+
pixel_values = pixel_values[real_images_inds].contiguous()
|
| 503 |
+
|
| 504 |
+
# Handle the vision attention mask
|
| 505 |
+
if pixel_attention_mask is None:
|
| 506 |
+
pixel_attention_mask = torch.ones(
|
| 507 |
+
size=(pixel_values.size(0), pixel_values.size(2),
|
| 508 |
+
pixel_values.size(3)),
|
| 509 |
+
dtype=torch.bool,
|
| 510 |
+
device=pixel_values.device,
|
| 511 |
+
)
|
| 512 |
+
else:
|
| 513 |
+
# Remove padding images from the mask
|
| 514 |
+
pixel_attention_mask = pixel_attention_mask[
|
| 515 |
+
real_images_inds].contiguous()
|
| 516 |
+
|
| 517 |
+
patch_size = self.config.vision_config.patch_size
|
| 518 |
+
patches_subgrid = pixel_attention_mask.unfold(dimension=1,
|
| 519 |
+
size=patch_size,
|
| 520 |
+
step=patch_size)
|
| 521 |
+
patches_subgrid = patches_subgrid.unfold(dimension=2,
|
| 522 |
+
size=patch_size,
|
| 523 |
+
step=patch_size)
|
| 524 |
+
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
| 525 |
+
|
| 526 |
+
# Get sequence from the vision encoder
|
| 527 |
+
image_hidden_states = self.vision_model(
|
| 528 |
+
pixel_values=pixel_values,
|
| 529 |
+
patch_attention_mask=patch_attention_mask,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
return image_hidden_states.split(num_patches)
|
| 533 |
+
|
| 534 |
+
def _process_image_pixels(
|
| 535 |
+
self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
|
| 536 |
+
assert self.vision_model is not None
|
| 537 |
+
|
| 538 |
+
pixel_values = inputs["data"]
|
| 539 |
+
pixel_attention_mask = inputs["pixel_attention_mask"]
|
| 540 |
+
|
| 541 |
+
return self._image_pixels_to_features(pixel_values,
|
| 542 |
+
pixel_attention_mask)
|
| 543 |
+
|
| 544 |
+
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
|
| 545 |
+
if image_input["type"] == "image_embeds":
|
| 546 |
+
return image_input["data"]
|
| 547 |
+
|
| 548 |
+
assert self.vision_model is not None
|
| 549 |
+
image_features = self._process_image_pixels(image_input)
|
| 550 |
+
num_patches = [x.size(0) for x in image_features]
|
| 551 |
+
image_features = torch.cat(image_features)
|
| 552 |
+
return self.connector(image_features).split(num_patches)
|
| 553 |
+
|
| 554 |
+
def get_input_embeddings(
|
| 555 |
+
self,
|
| 556 |
+
input_ids: torch.Tensor,
|
| 557 |
+
) -> torch.Tensor:
|
| 558 |
+
return self.text_model.get_input_embeddings(input_ids)
|
| 559 |
+
|
| 560 |
+
def forward(
|
| 561 |
+
self,
|
| 562 |
+
input_ids: torch.Tensor,
|
| 563 |
+
positions: torch.Tensor,
|
| 564 |
+
kv_caches: List[torch.Tensor],
|
| 565 |
+
attn_metadata: AttentionMetadata,
|
| 566 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 567 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 568 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 569 |
+
|
| 570 |
+
hidden_states = self.text_model(
|
| 571 |
+
input_ids,
|
| 572 |
+
positions,
|
| 573 |
+
kv_caches,
|
| 574 |
+
attn_metadata,
|
| 575 |
+
intermediate_tensors,
|
| 576 |
+
inputs_embeds=inputs_embeds,
|
| 577 |
+
)
|
| 578 |
+
return hidden_states
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
@MULTIMODAL_REGISTRY.register_processor(
|
| 582 |
+
Idefics3MultimodalProcessor,
|
| 583 |
+
info=Idefics3ProcessingInfo,
|
| 584 |
+
dummy_inputs=Idefics3DummyInputsBuilder)
|
| 585 |
+
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
| 586 |
+
SupportsLoRA):
|
| 587 |
+
packed_modules_mapping = {
|
| 588 |
+
"qkv_proj": [
|
| 589 |
+
"q_proj",
|
| 590 |
+
"k_proj",
|
| 591 |
+
"v_proj",
|
| 592 |
+
],
|
| 593 |
+
"gate_up_proj": [
|
| 594 |
+
"gate_proj",
|
| 595 |
+
"up_proj",
|
| 596 |
+
],
|
| 597 |
+
}
|
| 598 |
+
# LoRA specific attributes
|
| 599 |
+
supported_lora_modules = [
|
| 600 |
+
# vision_model
|
| 601 |
+
"fc1",
|
| 602 |
+
"fc2",
|
| 603 |
+
"out_proj",
|
| 604 |
+
# text_model
|
| 605 |
+
"qkv_proj", # same name with vision encoder
|
| 606 |
+
"o_proj",
|
| 607 |
+
"gate_up_proj",
|
| 608 |
+
"down_proj",
|
| 609 |
+
]
|
| 610 |
+
|
| 611 |
+
embedding_modules = {}
|
| 612 |
+
embedding_padding_modules = []
|
| 613 |
+
|
| 614 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 615 |
+
super().__init__()
|
| 616 |
+
|
| 617 |
+
config = vllm_config.model_config.hf_config
|
| 618 |
+
quant_config = vllm_config.quant_config
|
| 619 |
+
multimodal_config = vllm_config.model_config.multimodal_config
|
| 620 |
+
|
| 621 |
+
self.config = config
|
| 622 |
+
self.multimodal_config = multimodal_config
|
| 623 |
+
|
| 624 |
+
self.model = Idefics3Model(vllm_config=vllm_config,
|
| 625 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 626 |
+
self.image_token_id = self.config.image_token_id
|
| 627 |
+
|
| 628 |
+
self.lm_head = ParallelLMHead(
|
| 629 |
+
config.text_config.vocab_size,
|
| 630 |
+
config.text_config.hidden_size,
|
| 631 |
+
quant_config=quant_config,
|
| 632 |
+
)
|
| 633 |
+
if self.config.text_config.tie_word_embeddings:
|
| 634 |
+
self.lm_head.weight = self.model.text_model.wte.weight
|
| 635 |
+
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
|
| 636 |
+
self.sampler = get_sampler()
|
| 637 |
+
|
| 638 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
| 639 |
+
image_input = self.model._parse_and_validate_image_input(**kwargs)
|
| 640 |
+
if image_input is None:
|
| 641 |
+
return None
|
| 642 |
+
vision_embeddings = self.model._process_image_input(image_input)
|
| 643 |
+
return vision_embeddings
|
| 644 |
+
|
| 645 |
+
def get_input_embeddings(
|
| 646 |
+
self,
|
| 647 |
+
input_ids: torch.Tensor,
|
| 648 |
+
multimodal_embeddings: Optional[NestedTensors] = None,
|
| 649 |
+
) -> torch.Tensor:
|
| 650 |
+
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
| 651 |
+
if multimodal_embeddings is not None:
|
| 652 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 653 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 654 |
+
self.config.image_token_id)
|
| 655 |
+
return inputs_embeds
|
| 656 |
+
|
| 657 |
+
def forward(
|
| 658 |
+
self,
|
| 659 |
+
input_ids: torch.Tensor,
|
| 660 |
+
positions: torch.Tensor,
|
| 661 |
+
kv_caches: List[torch.Tensor],
|
| 662 |
+
attn_metadata: AttentionMetadata,
|
| 663 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 664 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 665 |
+
**kwargs: object,
|
| 666 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 667 |
+
if intermediate_tensors is not None:
|
| 668 |
+
inputs_embeds = None
|
| 669 |
+
|
| 670 |
+
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
| 671 |
+
# condition is for v0 compatibility.
|
| 672 |
+
elif inputs_embeds is None:
|
| 673 |
+
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
| 674 |
+
inputs_embeds = self.get_input_embeddings(input_ids,
|
| 675 |
+
vision_embeddings)
|
| 676 |
+
input_ids = None
|
| 677 |
+
|
| 678 |
+
hidden_states = self.model.text_model(input_ids,
|
| 679 |
+
positions,
|
| 680 |
+
kv_caches,
|
| 681 |
+
attn_metadata,
|
| 682 |
+
intermediate_tensors,
|
| 683 |
+
inputs_embeds=inputs_embeds)
|
| 684 |
+
|
| 685 |
+
return hidden_states
|
| 686 |
+
|
| 687 |
+
def compute_logits(self, hidden_states: torch.Tensor,
|
| 688 |
+
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
| 689 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 690 |
+
sampling_metadata)
|
| 691 |
+
return logits
|
| 692 |
+
|
| 693 |
+
def sample(
|
| 694 |
+
self,
|
| 695 |
+
logits: torch.Tensor,
|
| 696 |
+
sampling_metadata: SamplingMetadata,
|
| 697 |
+
) -> Optional[SamplerOutput]:
|
| 698 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 699 |
+
return next_tokens
|
| 700 |
+
|
| 701 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 702 |
+
torch.Tensor]]) -> Set[str]:
|
| 703 |
+
loader = AutoWeightsLoader(self)
|
| 704 |
+
return loader.load_weights(weights)
|
| 705 |
+
|
| 706 |
+
def get_mm_mapping(self) -> MultiModelKeys:
|
| 707 |
+
"""
|
| 708 |
+
Get the module prefix in multimodal models
|
| 709 |
+
"""
|
| 710 |
+
return MultiModelKeys.from_string_field(
|
| 711 |
+
language_model="model.text_model",
|
| 712 |
+
connector="model.connector",
|
| 713 |
+
tower_model="model.vision_model")
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from transformers import PretrainedConfig
|
| 9 |
+
|
| 10 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 11 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 12 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 13 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 14 |
+
get_tensor_model_parallel_world_size,
|
| 15 |
+
split_tensor_along_last_dim,
|
| 16 |
+
tensor_model_parallel_all_gather)
|
| 17 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 18 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 19 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 20 |
+
QKVParallelLinear,
|
| 21 |
+
RowParallelLinear)
|
| 22 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 23 |
+
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
| 24 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 25 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 26 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 27 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 28 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 29 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 30 |
+
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
| 31 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 32 |
+
from vllm.sequence import IntermediateTensors, PoolerOutput
|
| 33 |
+
|
| 34 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 35 |
+
from .utils import (is_pp_missing_parameter,
|
| 36 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 37 |
+
maybe_prefix)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class InternLM2MLP(nn.Module):
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
hidden_size: int,
|
| 45 |
+
intermediate_size: int,
|
| 46 |
+
hidden_act: str,
|
| 47 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 48 |
+
prefix: str = "",
|
| 49 |
+
) -> None:
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 52 |
+
hidden_size,
|
| 53 |
+
[intermediate_size] * 2,
|
| 54 |
+
bias=False,
|
| 55 |
+
quant_config=quant_config,
|
| 56 |
+
prefix=f"{prefix}.gate_up_proj",
|
| 57 |
+
)
|
| 58 |
+
self.w2 = RowParallelLinear(
|
| 59 |
+
intermediate_size,
|
| 60 |
+
hidden_size,
|
| 61 |
+
bias=False,
|
| 62 |
+
quant_config=quant_config,
|
| 63 |
+
prefix=f"{prefix}.w2",
|
| 64 |
+
)
|
| 65 |
+
if hidden_act != "silu":
|
| 66 |
+
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
| 67 |
+
"Only silu is supported for now.")
|
| 68 |
+
self.act_fn = SiluAndMul()
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
gate_up, _ = self.gate_up_proj(x)
|
| 72 |
+
x = self.act_fn(gate_up)
|
| 73 |
+
x, _ = self.w2(x)
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class InternLM2Attention(nn.Module):
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
hidden_size: int,
|
| 82 |
+
num_heads: int,
|
| 83 |
+
num_kv_heads: int,
|
| 84 |
+
rope_theta: float = 10000,
|
| 85 |
+
rope_scaling: Optional[Dict[str, Any]] = None,
|
| 86 |
+
max_position_embeddings: int = 8192,
|
| 87 |
+
cache_config: Optional[CacheConfig] = None,
|
| 88 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 89 |
+
prefix: str = "",
|
| 90 |
+
) -> None:
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.hidden_size = hidden_size
|
| 93 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 94 |
+
self.tp_rank = get_tensor_model_parallel_rank()
|
| 95 |
+
self.total_num_heads = num_heads
|
| 96 |
+
assert self.total_num_heads % self.tp_size == 0
|
| 97 |
+
self.num_heads = self.total_num_heads // self.tp_size
|
| 98 |
+
self.total_num_kv_heads = num_kv_heads
|
| 99 |
+
if self.total_num_kv_heads >= self.tp_size:
|
| 100 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 101 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 102 |
+
assert self.total_num_kv_heads % self.tp_size == 0
|
| 103 |
+
else:
|
| 104 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 105 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 106 |
+
assert self.tp_size % self.total_num_kv_heads == 0
|
| 107 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
| 108 |
+
self.head_dim = hidden_size // self.total_num_heads
|
| 109 |
+
self.q_size = self.num_heads * self.head_dim
|
| 110 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 111 |
+
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
|
| 112 |
+
self.scaling = self.head_dim**-0.5
|
| 113 |
+
self.rope_theta = rope_theta
|
| 114 |
+
self.max_position_embeddings = max_position_embeddings
|
| 115 |
+
|
| 116 |
+
self.wqkv = QKVParallelLinear(
|
| 117 |
+
hidden_size,
|
| 118 |
+
self.head_dim,
|
| 119 |
+
self.total_num_heads,
|
| 120 |
+
self.total_num_kv_heads,
|
| 121 |
+
bias=False,
|
| 122 |
+
quant_config=quant_config,
|
| 123 |
+
prefix=f"{prefix}.wqkv",
|
| 124 |
+
)
|
| 125 |
+
self.wo = RowParallelLinear(
|
| 126 |
+
self.total_num_heads * self.head_dim,
|
| 127 |
+
hidden_size,
|
| 128 |
+
bias=False,
|
| 129 |
+
quant_config=quant_config,
|
| 130 |
+
prefix=f"{prefix}.wo",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self.rotary_emb = get_rope(
|
| 134 |
+
self.head_dim,
|
| 135 |
+
rotary_dim=self.head_dim,
|
| 136 |
+
max_position=max_position_embeddings,
|
| 137 |
+
base=rope_theta,
|
| 138 |
+
rope_scaling=rope_scaling,
|
| 139 |
+
)
|
| 140 |
+
self.attn = Attention(
|
| 141 |
+
self.num_heads,
|
| 142 |
+
self.head_dim,
|
| 143 |
+
self.scaling,
|
| 144 |
+
num_kv_heads=self.num_kv_heads,
|
| 145 |
+
cache_config=cache_config,
|
| 146 |
+
quant_config=quant_config,
|
| 147 |
+
prefix=f"{prefix}.attn",
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def split_qkv(self, qkv: torch.Tensor):
|
| 151 |
+
seq_len = qkv.shape[0]
|
| 152 |
+
if self.tp_size > 1:
|
| 153 |
+
qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
|
| 154 |
+
qkv = tensor_model_parallel_all_gather(qkv)
|
| 155 |
+
qkv = torch.split(qkv, qkv_map, dim=-1)
|
| 156 |
+
qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
|
| 157 |
+
qkv = torch.cat(qkv, dim=-1)
|
| 158 |
+
|
| 159 |
+
qkv = qkv.view(seq_len, self.total_num_kv_heads,
|
| 160 |
+
self.key_value_groups + 2, self.head_dim)
|
| 161 |
+
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
|
| 162 |
+
q = q.reshape(seq_len, self.q_size * self.tp_size)
|
| 163 |
+
k = k.reshape(seq_len, self.kv_size * self.tp_size)
|
| 164 |
+
v = v.reshape(seq_len, self.kv_size * self.tp_size)
|
| 165 |
+
|
| 166 |
+
if self.tp_size > 1:
|
| 167 |
+
splitter = partial(split_tensor_along_last_dim,
|
| 168 |
+
num_partitions=self.tp_size)
|
| 169 |
+
q = splitter(q)[self.tp_rank]
|
| 170 |
+
k = splitter(k)[self.tp_rank]
|
| 171 |
+
v = splitter(v)[self.tp_rank]
|
| 172 |
+
return q, k, v
|
| 173 |
+
|
| 174 |
+
def forward(
|
| 175 |
+
self,
|
| 176 |
+
positions: torch.Tensor,
|
| 177 |
+
hidden_states: torch.Tensor,
|
| 178 |
+
kv_cache: torch.Tensor,
|
| 179 |
+
attn_metadata: AttentionMetadata,
|
| 180 |
+
) -> torch.Tensor:
|
| 181 |
+
qkv, _ = self.wqkv(hidden_states)
|
| 182 |
+
q, k, v = self.split_qkv(qkv)
|
| 183 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 184 |
+
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
| 185 |
+
output, _ = self.wo(attn_output)
|
| 186 |
+
return output
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class InternLMDecoderLayer(nn.Module):
|
| 190 |
+
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
config: PretrainedConfig,
|
| 194 |
+
cache_config: Optional[CacheConfig] = None,
|
| 195 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 196 |
+
prefix: str = "",
|
| 197 |
+
) -> None:
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.hidden_size = config.hidden_size
|
| 200 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 201 |
+
rope_scaling = getattr(config, "rope_scaling", None)
|
| 202 |
+
max_position_embeddings = getattr(config, "max_position_embeddings",
|
| 203 |
+
8192)
|
| 204 |
+
self.attention = InternLM2Attention(
|
| 205 |
+
hidden_size=self.hidden_size,
|
| 206 |
+
num_heads=config.num_attention_heads,
|
| 207 |
+
num_kv_heads=config.num_key_value_heads,
|
| 208 |
+
rope_theta=rope_theta,
|
| 209 |
+
rope_scaling=rope_scaling,
|
| 210 |
+
max_position_embeddings=max_position_embeddings,
|
| 211 |
+
cache_config=cache_config,
|
| 212 |
+
quant_config=quant_config,
|
| 213 |
+
prefix=f"{prefix}.attention",
|
| 214 |
+
)
|
| 215 |
+
self.feed_forward = InternLM2MLP(
|
| 216 |
+
hidden_size=self.hidden_size,
|
| 217 |
+
intermediate_size=config.intermediate_size,
|
| 218 |
+
hidden_act=config.hidden_act,
|
| 219 |
+
quant_config=quant_config,
|
| 220 |
+
prefix=f"{prefix}.feed_forward",
|
| 221 |
+
)
|
| 222 |
+
self.attention_norm = RMSNorm(config.hidden_size,
|
| 223 |
+
eps=config.rms_norm_eps)
|
| 224 |
+
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 225 |
+
|
| 226 |
+
def forward(
|
| 227 |
+
self,
|
| 228 |
+
positions: torch.Tensor,
|
| 229 |
+
hidden_states: torch.Tensor,
|
| 230 |
+
kv_cache: torch.Tensor,
|
| 231 |
+
attn_metadata: AttentionMetadata,
|
| 232 |
+
residual: Optional[torch.Tensor],
|
| 233 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 234 |
+
# Self Attention
|
| 235 |
+
if residual is None:
|
| 236 |
+
residual = hidden_states
|
| 237 |
+
hidden_states = self.attention_norm(hidden_states)
|
| 238 |
+
else:
|
| 239 |
+
hidden_states, residual = self.attention_norm(
|
| 240 |
+
hidden_states, residual)
|
| 241 |
+
hidden_states = self.attention(
|
| 242 |
+
positions=positions,
|
| 243 |
+
hidden_states=hidden_states,
|
| 244 |
+
kv_cache=kv_cache,
|
| 245 |
+
attn_metadata=attn_metadata,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Fully Connected
|
| 249 |
+
hidden_states, residual = self.ffn_norm(hidden_states, residual)
|
| 250 |
+
hidden_states = self.feed_forward(hidden_states)
|
| 251 |
+
return hidden_states, residual
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@support_torch_compile
|
| 255 |
+
class InternLM2Model(nn.Module):
|
| 256 |
+
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
*,
|
| 260 |
+
vllm_config: VllmConfig,
|
| 261 |
+
prefix: str = "",
|
| 262 |
+
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
|
| 263 |
+
super().__init__()
|
| 264 |
+
|
| 265 |
+
config = vllm_config.model_config.hf_config
|
| 266 |
+
cache_config = vllm_config.cache_config
|
| 267 |
+
quant_config = vllm_config.quant_config
|
| 268 |
+
|
| 269 |
+
self.config = config
|
| 270 |
+
self.padding_idx = config.pad_token_id
|
| 271 |
+
self.vocab_size = config.vocab_size
|
| 272 |
+
self.tok_embeddings = VocabParallelEmbedding(
|
| 273 |
+
config.vocab_size,
|
| 274 |
+
config.hidden_size,
|
| 275 |
+
)
|
| 276 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 277 |
+
config.num_hidden_layers,
|
| 278 |
+
lambda prefix: layer_type(
|
| 279 |
+
config, cache_config, quant_config, prefix=prefix),
|
| 280 |
+
prefix=f"{prefix}.layers")
|
| 281 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 282 |
+
self.make_empty_intermediate_tensors = (
|
| 283 |
+
make_empty_intermediate_tensors_factory(
|
| 284 |
+
["hidden_states", "residual"], config.hidden_size))
|
| 285 |
+
|
| 286 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 287 |
+
return self.tok_embeddings(input_ids)
|
| 288 |
+
|
| 289 |
+
def forward(
|
| 290 |
+
self,
|
| 291 |
+
input_ids: torch.Tensor,
|
| 292 |
+
positions: torch.Tensor,
|
| 293 |
+
kv_caches: List[torch.Tensor],
|
| 294 |
+
attn_metadata: AttentionMetadata,
|
| 295 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 296 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 297 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 298 |
+
if get_pp_group().is_first_rank:
|
| 299 |
+
if inputs_embeds is not None:
|
| 300 |
+
hidden_states = inputs_embeds
|
| 301 |
+
else:
|
| 302 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 303 |
+
residual = None
|
| 304 |
+
else:
|
| 305 |
+
assert intermediate_tensors is not None
|
| 306 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 307 |
+
residual = intermediate_tensors["residual"]
|
| 308 |
+
for i in range(self.start_layer, self.end_layer):
|
| 309 |
+
layer = self.layers[i]
|
| 310 |
+
hidden_states, residual = layer(
|
| 311 |
+
positions,
|
| 312 |
+
hidden_states,
|
| 313 |
+
kv_caches[i - self.start_layer],
|
| 314 |
+
attn_metadata,
|
| 315 |
+
residual,
|
| 316 |
+
)
|
| 317 |
+
if not get_pp_group().is_last_rank:
|
| 318 |
+
return IntermediateTensors({
|
| 319 |
+
"hidden_states": hidden_states,
|
| 320 |
+
"residual": residual
|
| 321 |
+
})
|
| 322 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 323 |
+
return hidden_states
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
| 327 |
+
packed_modules_mapping = {
|
| 328 |
+
"wqkv": ["wqkv"],
|
| 329 |
+
"gate_up_proj": ["w1", "w3"],
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
# LoRA specific attributes
|
| 333 |
+
supported_lora_modules = [
|
| 334 |
+
"wqkv",
|
| 335 |
+
"wo",
|
| 336 |
+
"gate_up_proj",
|
| 337 |
+
"w2",
|
| 338 |
+
]
|
| 339 |
+
embedding_modules = {}
|
| 340 |
+
embedding_padding_modules = []
|
| 341 |
+
|
| 342 |
+
def __init__(self,
|
| 343 |
+
*,
|
| 344 |
+
vllm_config: VllmConfig,
|
| 345 |
+
prefix: str = "",
|
| 346 |
+
model_type: Type[InternLM2Model] = InternLM2Model):
|
| 347 |
+
super().__init__()
|
| 348 |
+
config = vllm_config.model_config.hf_config
|
| 349 |
+
quant_config = vllm_config.quant_config
|
| 350 |
+
lora_config = vllm_config.lora_config
|
| 351 |
+
|
| 352 |
+
self.config = config
|
| 353 |
+
self.quant_config = quant_config
|
| 354 |
+
self.lora_config = lora_config
|
| 355 |
+
|
| 356 |
+
self.model = model_type(vllm_config=vllm_config,
|
| 357 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 358 |
+
self.output = ParallelLMHead(config.vocab_size,
|
| 359 |
+
config.hidden_size,
|
| 360 |
+
quant_config=quant_config,
|
| 361 |
+
prefix=maybe_prefix(prefix, "output"))
|
| 362 |
+
if self.config.tie_word_embeddings:
|
| 363 |
+
self.output.weight = self.model.tok_embeddings.weight
|
| 364 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 365 |
+
self.sampler = get_sampler()
|
| 366 |
+
self.make_empty_intermediate_tensors = (
|
| 367 |
+
self.model.make_empty_intermediate_tensors)
|
| 368 |
+
|
| 369 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 370 |
+
return self.model.get_input_embeddings(input_ids)
|
| 371 |
+
|
| 372 |
+
def forward(
|
| 373 |
+
self,
|
| 374 |
+
input_ids: torch.Tensor,
|
| 375 |
+
positions: torch.Tensor,
|
| 376 |
+
kv_caches: List[torch.Tensor],
|
| 377 |
+
attn_metadata: AttentionMetadata,
|
| 378 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 379 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 380 |
+
) -> torch.Tensor:
|
| 381 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 382 |
+
attn_metadata, intermediate_tensors,
|
| 383 |
+
inputs_embeds)
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
def compute_logits(
|
| 387 |
+
self,
|
| 388 |
+
hidden_states: torch.Tensor,
|
| 389 |
+
sampling_metadata: SamplingMetadata,
|
| 390 |
+
) -> Optional[torch.Tensor]:
|
| 391 |
+
logits = self.logits_processor(self.output, hidden_states,
|
| 392 |
+
sampling_metadata)
|
| 393 |
+
return logits
|
| 394 |
+
|
| 395 |
+
def sample(
|
| 396 |
+
self,
|
| 397 |
+
logits: torch.Tensor,
|
| 398 |
+
sampling_metadata: SamplingMetadata,
|
| 399 |
+
) -> Optional[SamplerOutput]:
|
| 400 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 401 |
+
return next_tokens
|
| 402 |
+
|
| 403 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 404 |
+
torch.Tensor]]) -> Set[str]:
|
| 405 |
+
stacked_params_mapping = [
|
| 406 |
+
# (param_name, shard_name, shard_id)
|
| 407 |
+
("gate_up_proj", "w1", 0),
|
| 408 |
+
("gate_up_proj", "w3", 1),
|
| 409 |
+
]
|
| 410 |
+
params_dict = dict(self.named_parameters())
|
| 411 |
+
loaded_params: Set[str] = set()
|
| 412 |
+
for name, loaded_weight in weights:
|
| 413 |
+
if "rotary_emb.inv_freq" in name:
|
| 414 |
+
continue
|
| 415 |
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
| 416 |
+
if weight_name not in name:
|
| 417 |
+
continue
|
| 418 |
+
name = name.replace(weight_name, param_name)
|
| 419 |
+
# Skip loading extra bias for GPTQ models.
|
| 420 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 421 |
+
continue
|
| 422 |
+
if is_pp_missing_parameter(name, self):
|
| 423 |
+
continue
|
| 424 |
+
param = params_dict[name]
|
| 425 |
+
weight_loader = param.weight_loader
|
| 426 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 427 |
+
break
|
| 428 |
+
else:
|
| 429 |
+
# Skip loading extra bias for GPTQ models.
|
| 430 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 431 |
+
continue
|
| 432 |
+
if is_pp_missing_parameter(name, self):
|
| 433 |
+
continue
|
| 434 |
+
param = params_dict[name]
|
| 435 |
+
weight_loader = getattr(param, "weight_loader",
|
| 436 |
+
default_weight_loader)
|
| 437 |
+
weight_loader(param, loaded_weight)
|
| 438 |
+
loaded_params.add(name)
|
| 439 |
+
return loaded_params
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
| 443 |
+
|
| 444 |
+
def __init__(
|
| 445 |
+
self,
|
| 446 |
+
*,
|
| 447 |
+
vllm_config: VllmConfig,
|
| 448 |
+
prefix: str = "",
|
| 449 |
+
model_type: Type[InternLM2Model] = InternLM2Model,
|
| 450 |
+
):
|
| 451 |
+
super().__init__(vllm_config=vllm_config,
|
| 452 |
+
prefix=prefix,
|
| 453 |
+
model_type=model_type)
|
| 454 |
+
|
| 455 |
+
for attr in ("output", "logits_processor", "sampler"):
|
| 456 |
+
delattr(self, attr)
|
| 457 |
+
|
| 458 |
+
config = vllm_config.model_config.hf_config
|
| 459 |
+
self.v_head = RowParallelLinear(
|
| 460 |
+
config.hidden_size,
|
| 461 |
+
1,
|
| 462 |
+
bias=False,
|
| 463 |
+
input_is_parallel=False,
|
| 464 |
+
prefix=maybe_prefix(prefix, "v_head"),
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
pooler_config = vllm_config.model_config.pooler_config
|
| 468 |
+
self._pooler = Pooler.from_config_with_defaults(
|
| 469 |
+
pooler_config,
|
| 470 |
+
pooling_type=PoolingType.ALL,
|
| 471 |
+
normalize=False,
|
| 472 |
+
softmax=False,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
def forward(
|
| 476 |
+
self,
|
| 477 |
+
input_ids: torch.Tensor,
|
| 478 |
+
positions: torch.Tensor,
|
| 479 |
+
kv_caches: List[torch.Tensor],
|
| 480 |
+
attn_metadata: AttentionMetadata,
|
| 481 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 482 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 483 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 484 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 485 |
+
attn_metadata, intermediate_tensors,
|
| 486 |
+
inputs_embeds)
|
| 487 |
+
logits, _ = self.v_head(hidden_states)
|
| 488 |
+
return logits
|
| 489 |
+
|
| 490 |
+
def pooler(
|
| 491 |
+
self,
|
| 492 |
+
hidden_states: torch.Tensor,
|
| 493 |
+
pooling_metadata: PoolingMetadata,
|
| 494 |
+
) -> Optional[PoolerOutput]:
|
| 495 |
+
return self._pooler(hidden_states, pooling_metadata)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/internvl.py
ADDED
|
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# InternVL
|
| 6 |
+
# Copyright (c) 2023 OpenGVLab
|
| 7 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from functools import cached_property
|
| 11 |
+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
| 12 |
+
TypedDict, TypeVar, Union)
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torchvision.transforms as T
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from transformers import BatchFeature, PretrainedConfig, TensorType
|
| 19 |
+
|
| 20 |
+
from vllm.attention import AttentionMetadata
|
| 21 |
+
from vllm.config import VllmConfig
|
| 22 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 23 |
+
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
| 24 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 25 |
+
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
| 26 |
+
InternVisionPatchModel)
|
| 27 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 28 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 29 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 30 |
+
NestedTensors)
|
| 31 |
+
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
| 32 |
+
ImageSize, MultiModalDataItems)
|
| 33 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 34 |
+
BaseProcessingInfo, PromptReplacement,
|
| 35 |
+
PromptReplacementDetails)
|
| 36 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 37 |
+
from vllm.sequence import IntermediateTensors
|
| 38 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 39 |
+
|
| 40 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 41 |
+
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
| 42 |
+
maybe_prefix, merge_multimodal_embeddings)
|
| 43 |
+
|
| 44 |
+
# Special tokens that delimit image content in the prompt.
IMG_START = '<img>'
IMG_END = '</img>'
# Placeholder token repeated once per image feature token.
IMG_CONTEXT = '<IMG_CONTEXT>'

# Standard ImageNet normalization statistics (RGB channel order).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class InternVLImagePixelInputs(TypedDict):
    # Discriminator used by _process_image_input to pick the pixel path.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """
    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    """
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class InternVLImageEmbeddingInputs(TypedDict):
    # Discriminator used by _process_image_input to pick the embedding path.
    type: Literal["image_embeds"]
    data: NestedTensors
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Either raw pixel patches or precomputed image embeddings.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
| 81 |
+
def build_transform(input_size: int):
    """Build the preprocessing pipeline for a single image tile.

    Converts to RGB if needed, resizes to a square of side ``input_size``
    with bicubic interpolation, converts to a tensor, and applies
    ImageNet normalization.
    """

    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    return T.Compose([
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
| 93 |
+
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Return the (cols, rows) grid from ``target_ratios`` whose aspect
    ratio is closest to ``aspect_ratio``.

    Ties are broken in favor of the larger grid when the source image has
    enough area (more than half the area the larger grid would cover).
    """
    best = (1, 1)
    smallest_diff = float('inf')
    area = width * height

    for candidate in target_ratios:
        diff = abs(aspect_ratio - candidate[0] / candidate[1])
        if diff < smallest_diff:
            smallest_diff = diff
            best = candidate
        elif (diff == smallest_diff
              and area > 0.5 * image_size * image_size * candidate[0] *
              candidate[1]):
            # Same closeness: prefer the later candidate only if the image
            # is big enough to justify more tiles.
            best = candidate

    return best
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective (min, max) number of dynamic patches.

    When dynamic tiling is disabled the maximum collapses to 1; when a
    thumbnail tile is used alongside multiple patches, the maximum grows
    by one to account for it.
    """
    if not dynamic_image_size:
        max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate candidate (cols, rows) tiling grids whose tile count lies
    in [min_num, max_num], ordered by total tile count."""
    seen: set = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if min_num <= i * j <= max_num:
                    seen.add((i, j))
    return sorted(seen, key=lambda pair: pair[0] * pair[1])
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Choose the tiling grid for an image and derive its resize targets.

    Returns ``(blocks, target_width, target_height)`` where ``blocks`` is
    the number of tiles (plus one for the thumbnail when enabled and more
    than one tile is produced).
    """
    # Pick the grid whose aspect ratio best matches the source image.
    grid_cols, grid_rows = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    target_width = image_size * grid_cols
    target_height = image_size * grid_rows

    blocks = grid_cols * grid_rows
    # A thumbnail tile is appended whenever the image is split at all.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
| 174 |
+
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split an image into square tiles along the best-fitting grid.

    Optionally appends a whole-image thumbnail tile when the image was
    split into more than one tile.
    """
    orig_width, orig_height = image.size

    # Tile count before considering the thumbnail.
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized = image.resize((target_width, target_height))
    cols = target_width // image_size

    tiles: list[Image.Image] = []
    for idx in range(blocks):
        row, col = divmod(idx, cols)
        left = col * image_size
        top = row * image_size
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
| 214 |
+
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile an image dynamically and convert the tiles to a stacked tensor
    of normalized pixel values, one tensor row per tile."""
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=get_internvl_target_ratios(min_num, max_num),
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        # Explicit keyword overrides fall back to the model config.
        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        # LLM tokens produced per image tile after pixel-shuffle downsampling.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch: int = config.min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Vocabulary id of the image-context placeholder token."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Text standing in for the image features only (no delimiters)."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Full text replacement for one image, delimiters included."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve (min, max) dynamic patch counts, preferring explicit
        arguments over the values stored on this processor."""
        min_dynamic_patch = self.min_dynamic_patch
        max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
                             is None else max_dynamic_patch)
        dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
                              is None else dynamic_image_size)
        use_thumbnail = (self.use_thumbnail
                         if use_thumbnail is None else use_thumbnail)

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Candidate tiling grids for the resolved (min, max) patch counts."""
        min_num, max_num = self.resolve_min_max_num(
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Total feature tokens an image of this size expands into."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each PIL image to its (num_patches, C, H, W) tensor."""
        min_num, max_num = self.resolve_min_max_num(
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize text and preprocess images into a single BatchFeature.

        Each '<image>' placeholder in the text is expanded (one per image,
        in order) into the full image token replacement before tokenizing.
        """
        # Normalize both arguments to lists.
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": list(map(len, pixel_values_lst)),
            }

            # Expand one '<image>' placeholder per image, left to right;
            # count=1 keeps later placeholders for later images.
            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl_full(feature_size,
                                                      num_patches)
                text = [t.replace('<image>', image_repl, 1) for t in text]

        text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class InternVLProcessor(BaseInternVLProcessor):
    """Concrete processor for standard InternVL checkpoints: image features
    are represented by repeated <IMG_CONTEXT> tokens wrapped in
    <img>...</img> delimiters."""

    @property
    def image_token_id(self) -> int:
        # Resolve the <IMG_CONTEXT> placeholder to its vocabulary id.
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """One <IMG_CONTEXT> token per image feature."""
        return IMG_CONTEXT * feature_size

    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """The feature placeholder wrapped in image delimiters."""
        inner = self.get_image_repl_features(feature_size, num_patches)
        return f"{IMG_START}{inner}{IMG_END}"
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Shared processing info for InternVL-family models."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> BaseInternVLProcessor:
        """Build the (custom) HF-style processor for this model."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No hard limit on the number of images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Worst-case token count contributed by a single image."""
        return {"image": self.get_max_image_tokens()}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Feature-token count for an image of the given size; builds a
        processor on demand when none is supplied."""
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_max_image_tokens(self) -> int:
        """Token count at the most feature-dense supported image size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Scan every candidate tile grid and return the image size that
        yields the largest number of feature tokens."""
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            # Candidate image exactly filling a wr x hr grid of tiles.
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
# TypeVar over the concrete processing-info class used by the generic
# dummy-inputs builder and multimodal processor below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs: worst-case-sized dummy images together with
    a prompt carrying one '<image>' placeholder per image."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        # Use the image size that produces the most feature tokens so
        # profiling covers the worst case.
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data={"image": dummy_images},
        )
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor wiring InternVL's custom HF-style processor
    into vLLM's prompt-replacement machinery."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then attach the image placeholder token id
        to the outputs."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
        image_data = mm_data.get("images", [])
        assert isinstance(image_data, list)

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto the images in
        the batch (flat-per-patch vs batched vs shared)."""
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            # The flat patch tensor is split per image by its patch count.
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Build the rule replacing each '<image>' in the prompt with its
        per-image token expansion."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        # Recover the per-image patch counts; embedding-only inputs carry
        # no patch information.
        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            # Feature size comes either directly from precomputed
            # embeddings or from the image's dimensions.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return PromptReplacementDetails(
                full=hf_processor.get_image_repl_full(feature_size,
                                                      num_patches),
                features=hf_processor.get_image_repl_features(
                    feature_size, num_patches),
            )

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for vanilla InternVL checkpoints."""

    def get_hf_processor(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> InternVLProcessor:
        """Instantiate the custom processor from the model config and
        tokenizer, forwarding any tiling overrides."""
        config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return InternVLProcessor(
            config,
            tokenizer,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
@MULTIMODAL_REGISTRY.register_processor(
|
| 657 |
+
InternVLMultiModalProcessor,
|
| 658 |
+
info=InternVLProcessingInfo,
|
| 659 |
+
dummy_inputs=InternVLDummyInputsBuilder)
|
| 660 |
+
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
| 661 |
+
|
| 662 |
+
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Assemble the vision tower, the mlp1 projector, and the language
        model from the vLLM config."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # LLM tokens per image tile after pixel-shuffle downsampling.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # "Mono" variants (InternLM2VE) use a patch embedder instead of a
        # full vision transformer.
        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Filled in lazily once image inputs are first parsed.
        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
|
| 702 |
+
|
| 703 |
+
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Work around OpenGVLab AWQ checkpoints that lack
        `modules_to_not_convert` by excluding the vision tower."""
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config",
                                       None)
            # NOTE(review): only patches when the exclusion list is empty AND
            # the text config carries its own quantization config —
            # presumably this combination identifies the affected
            # checkpoints; confirm against the OpenGVLab AWQ releases.
            if (not quant_config.modules_to_not_convert) and \
                    (llm_quant_config is not None):
                quant_config.modules_to_not_convert.append("vision_model")
|
| 714 |
+
|
| 715 |
+
    @cached_property
    def sampler(self):
        """Reuse the language model's sampler when it defines one;
        otherwise fall back to the default sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()
|
| 721 |
+
|
| 722 |
+
    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision encoder: a (possibly truncated) InternViT for
        standard models, or just the patch embedder for mono models."""
        if not is_mono:
            vision_feature_layer = config.select_layer
            # Translate the (possibly negative) feature-layer index into
            # the number of ViT layers that must be instantiated.
            if vision_feature_layer < 0:
                num_hidden_layers = config.vision_config.num_hidden_layers \
                    + vision_feature_layer + 1
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)
|
| 746 |
+
|
| 747 |
+
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
| 748 |
+
vit_hidden_size = config.vision_config.hidden_size
|
| 749 |
+
llm_hidden_size = config.text_config.hidden_size
|
| 750 |
+
|
| 751 |
+
return nn.Sequential(
|
| 752 |
+
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
|
| 753 |
+
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
|
| 754 |
+
llm_hidden_size),
|
| 755 |
+
nn.GELU(),
|
| 756 |
+
nn.Linear(llm_hidden_size, llm_hidden_size),
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Space-to-depth downsampling: shrink the spatial grid by
        ``scale_factor`` while growing the channel dimension to match.

        NOTE(review): the axis comments suggest x is laid out (N, W, H, C),
        as produced by extract_feature's reshape — confirm if reusing.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            # v1 checkpoints keep the transposed layout as-is.
            pass
        else:
            # Later versions swap the spatial axes back.
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
|
| 772 |
+
|
| 773 |
+
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel tiles into LLM-space embeddings:
        ViT -> pixel shuffle -> mlp1 projection."""
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the leading (CLS) token; keep only patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]

        # Patch tokens form a square grid of side h == w.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        # Flatten the spatial grid back into a token sequence.
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        # Project into the language model's hidden size.
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
|
| 785 |
+
|
| 786 |
+
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
| 787 |
+
|
| 788 |
+
h = w = self.config.vision_config.image_size
|
| 789 |
+
expected_dims = (3, h, w)
|
| 790 |
+
|
| 791 |
+
def _validate_shape(d: torch.Tensor):
|
| 792 |
+
actual_dims = tuple(d.shape)
|
| 793 |
+
|
| 794 |
+
if actual_dims != expected_dims:
|
| 795 |
+
expected_expr = str(expected_dims)
|
| 796 |
+
raise ValueError(
|
| 797 |
+
"The expected shape of pixel values per image per batch "
|
| 798 |
+
f" per patch is {expected_expr}. "
|
| 799 |
+
f"You supplied {tuple(d.shape)}.")
|
| 800 |
+
|
| 801 |
+
for d in data:
|
| 802 |
+
_validate_shape(d)
|
| 803 |
+
|
| 804 |
+
return data
|
| 805 |
+
|
| 806 |
+
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Convert raw multimodal kwargs into a validated image-input dict.

        Returns None when the request carries no image data at all.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        # Precomputed embeddings bypass the vision tower entirely.
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # NOTE(review): raises KeyError if "image_token_id" is absent —
        # presumably the processor always attaches it alongside pixel
        # inputs; confirm against _call_hf_processor.
        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        # Remember which token id marks image positions for embedding merge.
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if pixel_values_flat is not None:
            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values_flat)}")

            assert isinstance(image_num_patches, (torch.Tensor, list))

            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values_flat, concat=True)),
                patches_per_image=flatten_bn(image_num_patches,
                                             concat=True).tolist())

        raise AssertionError("This line should be unreachable.")
|
| 844 |
+
|
| 845 |
+
def _process_image_input(
    self,
    image_input: InternVLImageInputs,
) -> tuple[torch.Tensor, ...]:
    """Turn a parsed image input into per-image embedding tensors."""
    # Precomputed embeddings need no further processing.
    if image_input["type"] == "image_embeds":
        return image_input["data"]

    assert self.vision_model is not None

    embeds = self.extract_feature(image_input["data"])
    patches_per_image = image_input["patches_per_image"]
    hidden = self.config.text_config.hidden_size

    # Single-image batch: keep everything in one (1, N, H) tensor.
    if len(patches_per_image) == 1:
        return embeds.view(-1, hidden).unsqueeze(0)

    # Multiple images: split the flattened embeddings back into one
    # tensor per image, each sized by that image's patch count.
    tokens_per_patch = embeds.shape[1]
    flat = embeds.view(-1, hidden)
    split_sizes = [n * tokens_per_patch for n in patches_per_image]
    return flat.split(split_sizes)
|
| 874 |
+
|
| 875 |
+
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
| 876 |
+
if self.is_mono:
|
| 877 |
+
self.visual_token_mask = (
|
| 878 |
+
input_ids == self.img_context_token_id).reshape(-1, 1)
|
| 879 |
+
else:
|
| 880 |
+
self.visual_token_mask = None
|
| 881 |
+
|
| 882 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    """Parse image kwargs and return their embeddings, or ``None``."""
    parsed = self._parse_and_validate_image_input(**kwargs)
    if parsed is None:
        return None
    return self._process_image_input(parsed)
|
| 888 |
+
|
| 889 |
+
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed token ids, splicing image embeddings over context tokens."""
    embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is None:
        return embeds
    assert self.img_context_token_id is not None
    # Mono-architecture models also need the visual-token mask set.
    self._set_visual_token_mask(input_ids)
    return merge_multimodal_embeddings(input_ids, embeds,
                                       multimodal_embeddings,
                                       self.img_context_token_id)
|
| 902 |
+
|
| 903 |
+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
    """Build multimodal embeddings (if needed) and run the backbone."""
    if intermediate_tensors is not None:
        # Non-first pipeline rank: inputs come from the previous stage.
        input_ids = None
        inputs_embeds = None
    elif inputs_embeds is None:
        # NOTE: In v1, inputs_embeds is always generated at model runner;
        # this branch is for v0 compatibility.
        mm_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
        input_ids = None

    model_inputs = dict(
        input_ids=input_ids,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )

    # Only set for mono-architecture models; consumed exactly once.
    if self.visual_token_mask is not None:
        model_inputs["visual_token_mask"] = self.visual_token_mask
        self.visual_token_mask = None

    return self.language_model.model(**model_inputs)
|
| 943 |
+
|
| 944 |
+
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logit computation to the wrapped language model."""
    logits = self.language_model.compute_logits(hidden_states,
                                                sampling_metadata)
    return logits
|
| 951 |
+
|
| 952 |
+
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate token sampling to the wrapped language model."""
    next_tokens = self.language_model.sample(logits, sampling_metadata)
    return next_tokens
|
| 958 |
+
|
| 959 |
+
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint weights through the generic auto-loader."""
    return AutoWeightsLoader(self).load_weights(weights)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/jamba.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Inference-only Jamba model."""
|
| 3 |
+
from typing import Iterable, List, Optional, Set, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import JambaConfig
|
| 8 |
+
|
| 9 |
+
from vllm.attention.backends.abstract import AttentionMetadata
|
| 10 |
+
from vllm.attention.layer import Attention
|
| 11 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 12 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 13 |
+
from vllm.distributed.parallel_state import get_pp_group
|
| 14 |
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
| 15 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 16 |
+
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
| 17 |
+
ReplicatedLinear,
|
| 18 |
+
RowParallelLinear)
|
| 19 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 20 |
+
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
| 21 |
+
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
| 22 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 23 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 24 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 25 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 26 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 27 |
+
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
| 28 |
+
MambaCacheParams)
|
| 29 |
+
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
| 30 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 31 |
+
from vllm.sequence import IntermediateTensors, PoolerOutput
|
| 32 |
+
from vllm.utils import LayerBlockType
|
| 33 |
+
|
| 34 |
+
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
| 35 |
+
from .utils import (is_pp_missing_parameter,
|
| 36 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 37 |
+
maybe_prefix)
|
| 38 |
+
|
| 39 |
+
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class JambaMoE(nn.Module):
    """Jamba sparse mixture-of-experts block.

    With a single expert (see :class:`JambaMLP`) it degenerates to a
    dense MLP with constant routing weights.
    """

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Explicit arguments take precedence over the model config.
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if self.num_total_experts > 1:
            # The router stays unquantized and replicated on every rank.
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(self.num_total_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                tp_size=tp_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=False,
                                use_grouped_topk=False,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the experts token-wise; input shape is preserved."""
        input_shape = hidden_states.shape
        flat = hidden_states.view(-1, self.hidden_size)
        if self.num_total_experts > 1:
            # (batch * sequence_length, n_experts) routing scores.
            logits, _ = self.router(flat)
        else:
            # Single expert: constant unit routing weights.
            logits = torch.ones((flat.shape[0], 1),
                                device=flat.device,
                                dtype=flat.dtype)
        out = self.experts(flat, logits)
        return out.view(input_shape)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class JambaMLP(JambaMoE):
    """Dense feed-forward layer expressed as a one-expert MoE.

    A single expert with top-1 routing is mathematically equivalent to
    a plain MLP, which lets dense and sparse layers share one code path.
    """

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         num_experts=1,
                         top_k=1,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class JambaMambaDecoderLayer(nn.Module):
    """Jamba decoder layer whose sequence mixer is a Mamba SSM."""

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
                 **kwargs) -> None:
        super().__init__()
        self.config = config
        self.is_lora_enabled = is_lora_enabled
        self.mamba = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
        )

        # Layers configured with more than one expert get the sparse FFN.
        n_experts = config.layers_num_experts[layer_idx]
        ffn_cls = JambaMoE if n_experts > 1 else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Pre-norm residual block: Mamba mixer, then feed-forward."""
        if residual is None:
            # First layer of this pipeline stage: plain normalization.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add + norm path.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, attn_metadata,
                                   mamba_cache_params)
        # Feed-forward sub-block with its own fused add + norm.
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class JambaAttentionDecoderLayer(nn.Module):
    """Jamba decoder layer whose sequence mixer is standard attention."""

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # More KV heads than ranks: shard them evenly across GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: replicate each head.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

        # Layers configured with more than one expert get the sparse FFN.
        n_experts = config.layers_num_experts[layer_idx]
        ffn_cls = JambaMoE if n_experts > 1 else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """QKV projection, attention, and output projection.

        NOTE: ``positions`` is accepted for interface parity but is not
        used in this body (no rotary embedding is applied here).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        out, _ = self.o_proj(attn_out)
        return out

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Pre-norm residual block: self-attention, then feed-forward."""
        if residual is None:
            # First layer of this pipeline stage: plain normalization.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add + norm path.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Feed-forward sub-block with its own fused add + norm.
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
ALL_DECODER_LAYER_TYPES = {
|
| 266 |
+
"attention": JambaAttentionDecoderLayer,
|
| 267 |
+
"mamba": JambaMambaDecoderLayer
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class JambaModel(nn.Module):
    """Hybrid attention/Mamba decoder stack with pipeline-parallel support."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # Reserve extra vocabulary slots when LoRA adapters add tokens.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

        def get_layer(prefix: str):
            # The layer kind (attention vs mamba) is keyed by the layer
            # index encoded at the end of the parameter prefix.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_cls = ALL_DECODER_LAYER_TYPES[
                config.layers_block_type[layer_idx]]
            return layer_cls(config,
                             layer_idx,
                             cache_config,
                             quant_config=quant_config,
                             prefix=prefix,
                             **extra_kwargs)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer stack for this pipeline stage."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later ranks resume from the previous stage's tensors.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Attention layers and Mamba layers draw from separate cache
        # pools, so each keeps its own running index.
        attn_idx = 0
        mamba_idx = 0
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            kv_cache = None
            layer_cache_params = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[attn_idx]
                attn_idx += 1
            if isinstance(layer, JambaMambaDecoderLayer):
                layer_cache_params = mamba_cache_params.at_layer_idx(
                    mamba_idx)
                mamba_idx += 1

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=layer_cache_params)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
| 372 |
+
IsHybrid):
|
| 373 |
+
packed_modules_mapping = {
|
| 374 |
+
"qkv_proj": [
|
| 375 |
+
"q_proj",
|
| 376 |
+
"k_proj",
|
| 377 |
+
"v_proj",
|
| 378 |
+
],
|
| 379 |
+
"in_proj": ["in_proj"],
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
# LoRA specific attributes
|
| 383 |
+
supported_lora_modules = [
|
| 384 |
+
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
|
| 385 |
+
"down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
|
| 386 |
+
]
|
| 387 |
+
embedding_modules = {
|
| 388 |
+
"embed_tokens": "input_embeddings",
|
| 389 |
+
"lm_head": "output_embeddings",
|
| 390 |
+
}
|
| 391 |
+
embedding_padding_modules = ["lm_head"]
|
| 392 |
+
|
| 393 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the Jamba causal-LM wrapper around :class:`JambaModel`."""
    config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    assert not cache_config.enable_prefix_caching, \
        "Jamba currently does not support prefix caching"

    super().__init__()
    self.config = config
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.scheduler_config = scheduler_config
    self.model = JambaModel(vllm_config=vllm_config,
                            prefix=maybe_prefix(prefix, "model"))

    self.unpadded_vocab_size = config.vocab_size
    if lora_config:
        self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
    self.lm_head = ParallelLMHead(
        self.unpadded_vocab_size,
        config.hidden_size,
        org_num_embeddings=config.vocab_size,
        # LoRA requires larger padding for kernel compatibility.
        padding_size=DEFAULT_VOCAB_PADDING_SIZE
        if not lora_config else lora_config.lora_vocab_padding_size,
    )
    # Mamba state cache; allocated lazily on the first forward pass.
    self.mamba_cache: Optional[MambaCacheManager] = None

    self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                            config.vocab_size)
    self.sampler = get_sampler()

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

    # Size the Mamba cache batch dimension: bounded by CUDA-graph
    # capture size when graphs are enabled, fixed bound otherwise.
    if self.scheduler_config is not None and \
        not self.model_config.enforce_eager:
        max_capture = vllm_config.compilation_config.max_capture_size
        if self.scheduler_config.max_num_seqs > max_capture:
            self.max_batch_size = max_capture
        else:
            self.max_batch_size = vllm_config.pad_for_cudagraph(
                self.scheduler_config.max_num_seqs)
    else:
        self.max_batch_size = 8192 + 2
|
| 440 |
+
|
| 441 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate embedding lookup to the inner decoder stack."""
    embedder = self.model.get_input_embeddings
    return embedder(input_ids)
|
| 443 |
+
|
| 444 |
+
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[KVCache],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs):
    """Run the decoder stack, lazily allocating the Mamba state cache."""
    if self.mamba_cache is None:
        # Allocate once, sized by the number of Mamba layers on this
        # rank and the largest batch we expect to see.
        n_mamba = self.model_config.get_num_layers_by_block_type(
            self.vllm_config.parallel_config, LayerBlockType.mamba)
        self.mamba_cache = MambaCacheManager(
            self.lm_head.weight.dtype, n_mamba, self.max_batch_size,
            *self._get_mamba_cache_shape())

    cache_tensors, state_indices = self.mamba_cache.current_run_tensors(
        input_ids, attn_metadata, **kwargs)
    cache_params = MambaCacheParams(cache_tensors[0], cache_tensors[1],
                                    state_indices)
    return self.model(input_ids, positions, kv_caches, attn_metadata,
                      cache_params, intermediate_tensors, inputs_embeds)
|
| 470 |
+
|
| 471 |
+
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Forward CUDA-graph input staging to the Mamba cache manager."""
    cache = self.mamba_cache
    return cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
|
| 474 |
+
|
| 475 |
+
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    """Expose the cache manager's capture inputs for CUDA graphs."""
    cache = self.mamba_cache
    return cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
| 477 |
+
|
| 478 |
+
def _get_mamba_cache_shape(
        self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Per-layer (conv state, SSM state) shapes, sharded over TP ranks."""
    world_size = get_tensor_model_parallel_world_size()
    # Inner channel count after the Mamba expansion, split across ranks.
    inner = self.config.mamba_expand * self.config.hidden_size // world_size
    conv_shape = (inner, self.config.mamba_d_conv - 1)
    ssm_shape = (inner, self.config.mamba_d_state)
    return conv_shape, ssm_shape
|
| 491 |
+
|
| 492 |
+
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Project hidden states to vocabulary logits via the LM head."""
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
|
| 500 |
+
|
| 501 |
+
def sample(
    self,
    logits: Optional[torch.Tensor],
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Sample next tokens from the computed logits."""
    return self.sampler(logits, sampling_metadata)
|
| 508 |
+
|
| 509 |
+
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this model's parameters.

    Handles three cases per weight, in priority order:
    1. q/k/v shards fused into the stacked ``qkv_proj`` parameter;
    2. MoE expert weights dispatched via FusedMoE's expert mapping;
    3. everything else through the parameter's own (or default) loader.

    Returns the set of parameter names that were actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        # Checkpoint stores the Mamba decay parameter as A_log; the model
        # parameter is named A.
        if "A_log" in name:
            name = name.replace("A_log", "A")

        if ".self_attn." in name:
            name = name.replace(".self_attn", "")

        if "feed_forward" in name and not _is_moe_layer(name):
            ## map MLP layers to expert with ID=0
            name = name.replace("feed_forward", "feed_forward.experts.0")

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Expert weights are handled by the expert mapping below.
            if 'experts' in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.

            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not a stacked q/k/v shard: try the MoE expert mapping.
            for (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
            ) in expert_params_mapping:
                if weight_name not in name:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                break
            else:
                # Plain parameter: load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _is_moe_layer(name: str):
|
| 596 |
+
return any(
|
| 597 |
+
[experts_name in name for experts_name in [
|
| 598 |
+
"experts",
|
| 599 |
+
"router",
|
| 600 |
+
]])
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class JambaForSequenceClassification(JambaForCausalLM):
    """Jamba backbone with a linear scoring head for sequence classification.

    Reuses the causal-LM weights and adds a ``score`` projection whose
    output is pooled on the last token.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        num_labels: int = config.num_labels
        score_bias: bool = getattr(config, 'score_bias', False)
        # Classification head mapping hidden states to per-label scores.
        self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)

        pooler_config = vllm_config.model_config.pooler_config
        # Pool the last token's score; no normalization or softmax, so raw
        # scores are returned.
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=False)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score hidden states (promoted to fp32) and pool the result."""
        hidden_states = hidden_states.float()
        logits = self.score(hidden_states)
        return self._pooler(logits, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load backbone weights, then keep the score head in fp32."""
        # TODO: The reward weights themselves have float32 accuracy data, we
        # would like to load them in fp32 to get that extra precision.
        super().load_weights(weights)
        self.score = self.score.float()
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/llama.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
| 25 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers import LlamaConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 33 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 34 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 35 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 36 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 37 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 38 |
+
QKVParallelLinear,
|
| 39 |
+
RowParallelLinear)
|
| 40 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 41 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 42 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 43 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 44 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 45 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 46 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 47 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 48 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 49 |
+
from vllm.sequence import IntermediateTensors
|
| 50 |
+
|
| 51 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 52 |
+
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
| 53 |
+
is_pp_missing_parameter,
|
| 54 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 55 |
+
maybe_prefix)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class LlamaMLP(nn.Module):
    """Llama feed-forward block: a gated SiLU MLP with tensor-parallel GEMMs.

    gate_proj and up_proj are fused into one column-parallel projection;
    the activation computes silu(gate) * up before the row-parallel
    down-projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fused gate/up projection: one GEMM producing both halves.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return down_proj(silu(gate_proj(x)) * up_proj(x))."""
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class LlamaAttention(nn.Module):
    """Multi-head attention for Llama with tensor-parallel projections,
    rotary embeddings, GQA support, and optional per-layer sliding window.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        # Layer index is recovered from the module prefix (".layers.N.").
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Fused q/k/v projection (column-parallel).
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF llama checkpoints store rotary weights in GPT-J (interleaved)
        # layout rather than NeoX layout.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        # Sliding-window attention: either one window size for every layer,
        # or a per-layer pattern cycled by layer index.
        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: pre-norm attention followed by a
    pre-norm MLP, with the residual stream carried between fused
    add-then-norm operations.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Propagate the pre-extension context length into the rope scaling
        # dict when the config provides it (long-context checkpoints).
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention + MLP; returns (hidden_states, residual) so the
        residual add can be fused into the next layer's norm.
        """
        # Self Attention
        if residual is None:
            # First layer on this rank: seed the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add-then-norm: RMSNorm(hidden + residual) plus the
            # updated residual are returned together.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama transformer backbone: embeddings, decoder layers, final norm.

    Pipeline-parallel aware: only ranks that own a stage instantiate the
    embedding / norm modules; other ranks get PPMissingLayer placeholders
    and exchange IntermediateTensors instead.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        # Reserve extra vocab slots for LoRA adapters, if configured.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The last PP rank also needs the embedding table when it is tied
        # to the LM head.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        # make_layers creates only this rank's slice [start_layer, end_layer).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids into hidden states."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        Returns final hidden states on the last PP rank, otherwise the
        IntermediateTensors to hand to the next rank.
        """
        if get_pp_group().is_first_rank:
            # Precomputed embeddings take precedence over token ids.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # kv_caches is indexed relative to this rank's first layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, fusing q/k/v and gate/up shards.

        Returns the set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Some checkpoints store the scale as a 1-element tensor;
                # reduce it to a scalar.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Llama backbone plus LM head, logits processor, and sampler for
    next-token prediction. Supports LoRA and pipeline parallelism.
    """

    # Fused parameters and their constituent checkpoint shards.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        # Only the last pipeline-parallel rank owns the LM head and
        # logits processor.
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Share weights with the input embedding table.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        # Hook point for subclasses that substitute a different backbone.
        return LlamaModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; see LlamaModel.forward for PP semantics."""
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states onto the vocabulary via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Draw the next tokens from *logits*."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights, remapping Mistral-format names on the fly.

        When embeddings are tied, lm_head weights are skipped (they alias
        embed_tokens).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate a Mistral consolidated checkpoint name/tensor into the
        HF layout this model expects. No-op for already-HF names.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Convert rotary weight layout from Mistral's interleaved
            # ordering to the HF ordering.
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        # Rename each path component; guard against double replacement when
        # the target name already appears in the string.
        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/llava.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
from functools import cached_property
|
| 5 |
+
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
| 6 |
+
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from packaging.version import Version
|
| 11 |
+
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
| 12 |
+
PixtralVisionConfig, PretrainedConfig,
|
| 13 |
+
SiglipVisionConfig)
|
| 14 |
+
from transformers import __version__ as TRANSFORMERS_VERSION
|
| 15 |
+
from transformers.models.llava import LlavaProcessor
|
| 16 |
+
from transformers.models.pixtral import PixtralProcessor
|
| 17 |
+
|
| 18 |
+
from vllm.attention import AttentionMetadata
|
| 19 |
+
from vllm.config import VllmConfig
|
| 20 |
+
from vllm.inputs import InputProcessingContext
|
| 21 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 22 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 23 |
+
RowParallelLinear)
|
| 24 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 25 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 26 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 27 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 28 |
+
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
| 29 |
+
MultiModalInputs, MultiModalKwargs,
|
| 30 |
+
NestedTensors)
|
| 31 |
+
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
| 32 |
+
ImageSize, MultiModalDataItems)
|
| 33 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 34 |
+
BaseProcessingInfo, ProcessingCache,
|
| 35 |
+
PromptReplacement)
|
| 36 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 37 |
+
from vllm.sequence import IntermediateTensors
|
| 38 |
+
|
| 39 |
+
from .clip import CLIPVisionModel
|
| 40 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 41 |
+
from .pixtral import (PixtralHFVisionModel,
|
| 42 |
+
get_pixtral_hf_image_feature_grid_size)
|
| 43 |
+
from .siglip import SiglipVisionModel
|
| 44 |
+
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
| 45 |
+
maybe_prefix, merge_multimodal_embeddings)
|
| 46 |
+
from .vision import get_vision_encoder_info
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LlavaImagePixelInputs(TypedDict):
|
| 50 |
+
type: Literal["pixel_values"]
|
| 51 |
+
data: Union[torch.Tensor, List[torch.Tensor]]
|
| 52 |
+
"""
|
| 53 |
+
Shape: `(batch_size * num_images, num_channels, height, width)`
|
| 54 |
+
|
| 55 |
+
Note that `height` or `width` may be different per batch and image,
|
| 56 |
+
in which case the data is passed as a list instead of a batched tensor.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LlavaImageEmbeddingInputs(TypedDict):
|
| 61 |
+
type: Literal["image_embeds"]
|
| 62 |
+
data: torch.Tensor
|
| 63 |
+
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
| 64 |
+
|
| 65 |
+
`hidden_size` must match the hidden size of language model backbone.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LlavaMultiModalProjector(nn.Module):
|
| 73 |
+
|
| 74 |
+
def __init__(self,
|
| 75 |
+
vision_hidden_size: int,
|
| 76 |
+
text_hidden_size: int,
|
| 77 |
+
projector_hidden_act: str,
|
| 78 |
+
multimodal_projector_bias: bool,
|
| 79 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 80 |
+
prefix: str = ""):
|
| 81 |
+
super().__init__()
|
| 82 |
+
|
| 83 |
+
self.linear_1 = ColumnParallelLinear(vision_hidden_size,
|
| 84 |
+
text_hidden_size,
|
| 85 |
+
bias=multimodal_projector_bias,
|
| 86 |
+
quant_config=quant_config,
|
| 87 |
+
prefix=f"{prefix}.linear_1")
|
| 88 |
+
self.act = get_act_fn(projector_hidden_act)
|
| 89 |
+
self.linear_2 = RowParallelLinear(text_hidden_size,
|
| 90 |
+
text_hidden_size,
|
| 91 |
+
bias=multimodal_projector_bias,
|
| 92 |
+
quant_config=quant_config,
|
| 93 |
+
prefix=f"{prefix}.linear_2")
|
| 94 |
+
|
| 95 |
+
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
| 96 |
+
hidden_states, _ = self.linear_1(image_features)
|
| 97 |
+
hidden_states = self.act(hidden_states)
|
| 98 |
+
hidden_states, _ = self.linear_2(hidden_states)
|
| 99 |
+
return hidden_states
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class LlavaLikeConfig(Protocol):
|
| 103 |
+
vision_config: Final[PretrainedConfig]
|
| 104 |
+
image_token_index: Final[int]
|
| 105 |
+
vision_feature_select_strategy: Final[str]
|
| 106 |
+
vision_feature_layer: Final[Union[int, list[int]]]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class LlavaLikeProcessor(Protocol):
|
| 110 |
+
image_token: Final[str]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class BaseLlavaProcessingInfo(BaseProcessingInfo):
|
| 114 |
+
|
| 115 |
+
def get_hf_config(self) -> LlavaLikeConfig:
|
| 116 |
+
return self.ctx.get_hf_config(LlavaConfig)
|
| 117 |
+
|
| 118 |
+
def get_vision_encoder_info(self):
|
| 119 |
+
return get_vision_encoder_info(self.get_hf_config())
|
| 120 |
+
|
| 121 |
+
@abstractmethod
|
| 122 |
+
def get_hf_processor(self) -> LlavaLikeProcessor:
|
| 123 |
+
raise NotImplementedError
|
| 124 |
+
|
| 125 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 126 |
+
return {"image": None}
|
| 127 |
+
|
| 128 |
+
def get_mm_max_tokens_per_item(
|
| 129 |
+
self,
|
| 130 |
+
seq_len: int,
|
| 131 |
+
mm_counts: Mapping[str, int],
|
| 132 |
+
) -> Mapping[str, int]:
|
| 133 |
+
return {"image": self.get_max_image_tokens()}
|
| 134 |
+
|
| 135 |
+
def _apply_feature_select_strategy(
|
| 136 |
+
self,
|
| 137 |
+
strategy: str,
|
| 138 |
+
encoder_num_image_tokens: int,
|
| 139 |
+
) -> int:
|
| 140 |
+
if strategy == "default":
|
| 141 |
+
return encoder_num_image_tokens - 1
|
| 142 |
+
if strategy == "full":
|
| 143 |
+
return encoder_num_image_tokens
|
| 144 |
+
|
| 145 |
+
msg = f"Unexpected feature select strategy: {strategy!r}"
|
| 146 |
+
raise NotImplementedError(msg)
|
| 147 |
+
|
| 148 |
+
def get_num_image_tokens(
|
| 149 |
+
self,
|
| 150 |
+
*,
|
| 151 |
+
image_width: int,
|
| 152 |
+
image_height: int,
|
| 153 |
+
) -> int:
|
| 154 |
+
hf_config = self.get_hf_config()
|
| 155 |
+
vision_encoder_info = self.get_vision_encoder_info()
|
| 156 |
+
|
| 157 |
+
return self._apply_feature_select_strategy(
|
| 158 |
+
hf_config.vision_feature_select_strategy,
|
| 159 |
+
vision_encoder_info.get_num_image_tokens(
|
| 160 |
+
image_width=image_width,
|
| 161 |
+
image_height=image_height,
|
| 162 |
+
),
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def get_image_size_with_most_features(self) -> ImageSize:
|
| 166 |
+
vision_encoder_info = self.get_vision_encoder_info()
|
| 167 |
+
width = height = vision_encoder_info.get_image_size()
|
| 168 |
+
return ImageSize(width=width, height=height)
|
| 169 |
+
|
| 170 |
+
def get_max_image_tokens(self) -> int:
|
| 171 |
+
target_width, target_height = self.get_image_size_with_most_features()
|
| 172 |
+
|
| 173 |
+
return self.get_num_image_tokens(
|
| 174 |
+
image_width=target_width,
|
| 175 |
+
image_height=target_height,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
| 183 |
+
|
| 184 |
+
def get_dummy_processor_inputs(
|
| 185 |
+
self,
|
| 186 |
+
seq_len: int,
|
| 187 |
+
mm_counts: Mapping[str, int],
|
| 188 |
+
) -> ProcessorInputs:
|
| 189 |
+
num_images = mm_counts.get("image", 0)
|
| 190 |
+
|
| 191 |
+
processor = self.info.get_hf_processor()
|
| 192 |
+
image_token = processor.image_token
|
| 193 |
+
target_width, target_height = \
|
| 194 |
+
self.info.get_image_size_with_most_features()
|
| 195 |
+
|
| 196 |
+
mm_data = {
|
| 197 |
+
"image":
|
| 198 |
+
self._get_dummy_images(width=target_width,
|
| 199 |
+
height=target_height,
|
| 200 |
+
num_images=num_images)
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
return ProcessorInputs(
|
| 204 |
+
prompt_text=image_token * num_images,
|
| 205 |
+
mm_data=mm_data,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
|
| 210 |
+
|
| 211 |
+
def get_hf_processor(self):
|
| 212 |
+
return self.ctx.get_hf_processor(LlavaProcessor)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
| 216 |
+
|
| 217 |
+
# Copied from BaseMultiModalProcessor
|
| 218 |
+
@abstractmethod
|
| 219 |
+
def _get_mm_fields_config(
|
| 220 |
+
self,
|
| 221 |
+
hf_inputs: BatchFeature,
|
| 222 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 223 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 224 |
+
raise NotImplementedError
|
| 225 |
+
|
| 226 |
+
def _get_prompt_replacements(
|
| 227 |
+
self,
|
| 228 |
+
mm_items: MultiModalDataItems,
|
| 229 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 230 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 231 |
+
) -> list[PromptReplacement]:
|
| 232 |
+
hf_config = self.info.get_hf_config()
|
| 233 |
+
image_token_id = hf_config.image_token_index
|
| 234 |
+
|
| 235 |
+
def get_replacement(item_idx: int):
|
| 236 |
+
images = mm_items.get_items(
|
| 237 |
+
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
| 238 |
+
|
| 239 |
+
if isinstance(images, ImageEmbeddingItems):
|
| 240 |
+
num_image_tokens = images.get_feature_size(item_idx)
|
| 241 |
+
else:
|
| 242 |
+
image_size = images.get_image_size(item_idx)
|
| 243 |
+
num_image_tokens = self.info.get_num_image_tokens(
|
| 244 |
+
image_width=image_size.width,
|
| 245 |
+
image_height=image_size.height,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
return [image_token_id] * num_image_tokens
|
| 249 |
+
|
| 250 |
+
return [
|
| 251 |
+
PromptReplacement(
|
| 252 |
+
modality="image",
|
| 253 |
+
target=[image_token_id],
|
| 254 |
+
replacement=get_replacement,
|
| 255 |
+
),
|
| 256 |
+
]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class LlavaMultiModalProcessor(
|
| 260 |
+
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
|
| 261 |
+
|
| 262 |
+
def _get_mm_fields_config(
|
| 263 |
+
self,
|
| 264 |
+
hf_inputs: BatchFeature,
|
| 265 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 266 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 267 |
+
return dict(
|
| 268 |
+
pixel_values=MultiModalFieldConfig.batched("image"),
|
| 269 |
+
image_embeds=MultiModalFieldConfig.batched("image"),
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
|
| 274 |
+
|
| 275 |
+
def get_hf_processor(self):
|
| 276 |
+
return self.ctx.get_hf_processor(PixtralProcessor)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class PixtralHFMultiModalProcessor(
|
| 280 |
+
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
|
| 281 |
+
|
| 282 |
+
def _call_hf_processor(
|
| 283 |
+
self,
|
| 284 |
+
prompt: str,
|
| 285 |
+
mm_data: Mapping[str, object],
|
| 286 |
+
mm_kwargs: Mapping[str, object],
|
| 287 |
+
) -> BatchFeature:
|
| 288 |
+
processed_outputs = super()._call_hf_processor(
|
| 289 |
+
prompt=prompt,
|
| 290 |
+
mm_data=mm_data,
|
| 291 |
+
mm_kwargs=mm_kwargs,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
pixel_values = processed_outputs.get("pixel_values")
|
| 295 |
+
if pixel_values is not None:
|
| 296 |
+
# Before/after https://github.com/huggingface/transformers/pull/35122
|
| 297 |
+
if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
|
| 298 |
+
images = mm_data["images"]
|
| 299 |
+
assert isinstance(images, list)
|
| 300 |
+
|
| 301 |
+
# Original output: (1, num_images, C, H, W)
|
| 302 |
+
# New output: (num_images, C, H, W)
|
| 303 |
+
assert (isinstance(pixel_values, list)
|
| 304 |
+
and len(pixel_values) == 1)
|
| 305 |
+
assert (isinstance(pixel_values[0], list)
|
| 306 |
+
and len(pixel_values[0]) == len(images))
|
| 307 |
+
|
| 308 |
+
processed_outputs["pixel_values"] = pixel_values[0]
|
| 309 |
+
else:
|
| 310 |
+
# Avoid padding since we need the output for each image to be
|
| 311 |
+
# independent of other images for the cache to work correctly
|
| 312 |
+
image_sizes = processed_outputs["image_sizes"]
|
| 313 |
+
assert len(pixel_values) == len(image_sizes)
|
| 314 |
+
|
| 315 |
+
processed_outputs["pixel_values"] = [
|
| 316 |
+
p[:, :h, :w]
|
| 317 |
+
for p, (h, w) in zip(pixel_values, image_sizes)
|
| 318 |
+
]
|
| 319 |
+
|
| 320 |
+
return processed_outputs
|
| 321 |
+
|
| 322 |
+
def _get_mm_fields_config(
|
| 323 |
+
self,
|
| 324 |
+
hf_inputs: BatchFeature,
|
| 325 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 326 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 327 |
+
return dict(
|
| 328 |
+
pixel_values=MultiModalFieldConfig.batched("image"),
|
| 329 |
+
image_embeds=MultiModalFieldConfig.batched("image"),
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
def _get_prompt_replacements(
|
| 333 |
+
self,
|
| 334 |
+
mm_items: MultiModalDataItems,
|
| 335 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 336 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 337 |
+
) -> list[PromptReplacement]:
|
| 338 |
+
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
| 339 |
+
hf_config = self.info.get_hf_config()
|
| 340 |
+
tokenizer = self.info.get_tokenizer()
|
| 341 |
+
vocab = tokenizer.get_vocab()
|
| 342 |
+
|
| 343 |
+
image_break_id = vocab[processor.image_break_token]
|
| 344 |
+
image_token_id = hf_config.image_token_index
|
| 345 |
+
image_end_id = vocab[processor.image_end_token]
|
| 346 |
+
|
| 347 |
+
vision_config = hf_config.vision_config
|
| 348 |
+
assert isinstance(vision_config, PixtralVisionConfig)
|
| 349 |
+
|
| 350 |
+
def get_replacement(item_idx: int):
|
| 351 |
+
images = mm_items.get_items("image", ImageProcessorItems)
|
| 352 |
+
image_size = images.get_image_size(item_idx)
|
| 353 |
+
|
| 354 |
+
ncols, nrows = get_pixtral_hf_image_feature_grid_size(
|
| 355 |
+
vision_config,
|
| 356 |
+
image_width=image_size.width,
|
| 357 |
+
image_height=image_size.height,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
| 361 |
+
tokens[-1] = image_end_id
|
| 362 |
+
|
| 363 |
+
return tokens
|
| 364 |
+
|
| 365 |
+
return [
|
| 366 |
+
PromptReplacement(
|
| 367 |
+
modality="image",
|
| 368 |
+
target=[image_token_id],
|
| 369 |
+
replacement=get_replacement,
|
| 370 |
+
),
|
| 371 |
+
]
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _build_llava_or_pixtral_hf_info(
|
| 375 |
+
ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
|
| 376 |
+
hf_config = ctx.get_hf_config(LlavaConfig)
|
| 377 |
+
|
| 378 |
+
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
| 379 |
+
return PixtralHFProcessingInfo(ctx)
|
| 380 |
+
|
| 381 |
+
return LlavaProcessingInfo(ctx)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _build_llava_or_pixtral_hf_processor(
|
| 385 |
+
info: _I,
|
| 386 |
+
dummy_inputs: BaseDummyInputsBuilder[_I],
|
| 387 |
+
*,
|
| 388 |
+
cache: Optional[ProcessingCache] = None,
|
| 389 |
+
enable_sanity_checks: bool = True,
|
| 390 |
+
) -> BaseMultiModalProcessor:
|
| 391 |
+
if isinstance(info, PixtralHFProcessingInfo):
|
| 392 |
+
return PixtralHFMultiModalProcessor(
|
| 393 |
+
info,
|
| 394 |
+
dummy_inputs, # type: ignore
|
| 395 |
+
cache=cache,
|
| 396 |
+
enable_sanity_checks=enable_sanity_checks,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
if isinstance(info, LlavaProcessingInfo):
|
| 400 |
+
return LlavaMultiModalProcessor(
|
| 401 |
+
info,
|
| 402 |
+
dummy_inputs, # type: ignore
|
| 403 |
+
cache=cache,
|
| 404 |
+
enable_sanity_checks=enable_sanity_checks,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
raise NotImplementedError(type(info))
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
| 411 |
+
"""Determine the number of hidden layers to initialize up to in the
|
| 412 |
+
visual encoder.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
hf_config: Model config with vision feature layer(s).
|
| 416 |
+
"""
|
| 417 |
+
feature_layers = hf_config.vision_feature_layer
|
| 418 |
+
num_hidden_layers = hf_config.vision_config.num_hidden_layers
|
| 419 |
+
# If we have one feature layer, initialize up to that layer
|
| 420 |
+
if isinstance(feature_layers, int):
|
| 421 |
+
return _get_layer_index(feature_layers, num_hidden_layers)
|
| 422 |
+
# If we have multiple feature layers, initialize up to the deepest one
|
| 423 |
+
elif isinstance(feature_layers, (list, tuple)):
|
| 424 |
+
return max(
|
| 425 |
+
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
|
| 426 |
+
raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
|
| 427 |
+
" is not supported")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
| 431 |
+
"""Given an signed vision feature layer, get the number of hidden layers
|
| 432 |
+
needed to leverage it.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
feature_layer_index: Index of a required layer in the visual encoder.
|
| 436 |
+
num_hidden_layers: The total number of hidden layers in the visual
|
| 437 |
+
encoder.
|
| 438 |
+
"""
|
| 439 |
+
if feature_layer_index < 0:
|
| 440 |
+
return num_hidden_layers + feature_layer_index + 1
|
| 441 |
+
return feature_layer_index + 1
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def init_vision_tower_for_llava(
|
| 445 |
+
hf_config: LlavaLikeConfig,
|
| 446 |
+
quant_config: Optional[QuantizationConfig],
|
| 447 |
+
*,
|
| 448 |
+
require_post_norm: Optional[bool] = None,
|
| 449 |
+
prefix: str = "",
|
| 450 |
+
):
|
| 451 |
+
vision_config = hf_config.vision_config
|
| 452 |
+
|
| 453 |
+
# Initialize the vision tower only up to the deepest required feature layer
|
| 454 |
+
num_hidden_layers = _get_num_hidden_layers(hf_config)
|
| 455 |
+
|
| 456 |
+
if isinstance(vision_config, CLIPVisionConfig):
|
| 457 |
+
return CLIPVisionModel(
|
| 458 |
+
vision_config,
|
| 459 |
+
quant_config=quant_config,
|
| 460 |
+
num_hidden_layers_override=num_hidden_layers,
|
| 461 |
+
require_post_norm=require_post_norm,
|
| 462 |
+
prefix=prefix,
|
| 463 |
+
)
|
| 464 |
+
elif isinstance(vision_config, SiglipVisionConfig):
|
| 465 |
+
return SiglipVisionModel(
|
| 466 |
+
vision_config,
|
| 467 |
+
quant_config=quant_config,
|
| 468 |
+
num_hidden_layers_override=num_hidden_layers,
|
| 469 |
+
require_post_norm=require_post_norm,
|
| 470 |
+
prefix=prefix,
|
| 471 |
+
)
|
| 472 |
+
elif isinstance(vision_config, PixtralVisionConfig):
|
| 473 |
+
return PixtralHFVisionModel(
|
| 474 |
+
vision_config,
|
| 475 |
+
quant_config=quant_config,
|
| 476 |
+
num_hidden_layers_override=num_hidden_layers,
|
| 477 |
+
require_post_norm=require_post_norm,
|
| 478 |
+
prefix=prefix,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
msg = f"Unsupported vision config: {type(vision_config)}"
|
| 482 |
+
raise NotImplementedError(msg)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
|
| 486 |
+
info=_build_llava_or_pixtral_hf_info,
|
| 487 |
+
dummy_inputs=LlavaDummyInputsBuilder)
|
| 488 |
+
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
| 489 |
+
|
| 490 |
+
packed_modules_mapping = {
|
| 491 |
+
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
| 492 |
+
"gate_up_proj": ["gate_proj", "up_proj"]
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
| 496 |
+
super().__init__()
|
| 497 |
+
|
| 498 |
+
config = vllm_config.model_config.hf_config
|
| 499 |
+
quant_config = vllm_config.quant_config
|
| 500 |
+
multimodal_config = vllm_config.model_config.multimodal_config
|
| 501 |
+
|
| 502 |
+
self.config = config
|
| 503 |
+
self.multimodal_config = multimodal_config
|
| 504 |
+
|
| 505 |
+
# NOTE: These are special cases for Pixtral-12B in the HF-format
|
| 506 |
+
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
|
| 507 |
+
if (config.text_config.architectures is None
|
| 508 |
+
and config.text_config.model_type == "mistral"):
|
| 509 |
+
config.text_config.architectures = ["MistralForCausalLM"]
|
| 510 |
+
if (config.projector_hidden_act is None
|
| 511 |
+
and config.vision_config.hidden_act == "gelu"):
|
| 512 |
+
config.projector_hidden_act = "gelu"
|
| 513 |
+
|
| 514 |
+
# TODO: Optionally initializes this for supporting embeddings.
|
| 515 |
+
self.vision_tower = init_vision_tower_for_llava(
|
| 516 |
+
config,
|
| 517 |
+
quant_config,
|
| 518 |
+
require_post_norm=False,
|
| 519 |
+
prefix=maybe_prefix(prefix, "vision_tower"))
|
| 520 |
+
self.multi_modal_projector = LlavaMultiModalProjector(
|
| 521 |
+
vision_hidden_size=config.vision_config.hidden_size,
|
| 522 |
+
text_hidden_size=config.text_config.hidden_size,
|
| 523 |
+
projector_hidden_act=config.projector_hidden_act,
|
| 524 |
+
multimodal_projector_bias=config.multimodal_projector_bias,
|
| 525 |
+
quant_config=quant_config,
|
| 526 |
+
prefix=maybe_prefix(prefix, "multi_modal_projector"))
|
| 527 |
+
|
| 528 |
+
self.language_model = init_vllm_registered_model(
|
| 529 |
+
vllm_config=vllm_config,
|
| 530 |
+
hf_config=config.text_config,
|
| 531 |
+
prefix=maybe_prefix(prefix, "language_model"),
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
self.make_empty_intermediate_tensors = (
|
| 535 |
+
self.language_model.make_empty_intermediate_tensors)
|
| 536 |
+
|
| 537 |
+
@cached_property
|
| 538 |
+
def sampler(self):
|
| 539 |
+
if hasattr(self.language_model, "sampler"):
|
| 540 |
+
return self.language_model.sampler
|
| 541 |
+
|
| 542 |
+
return get_sampler()
|
| 543 |
+
|
| 544 |
+
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
| 545 |
+
h = w = self.config.vision_config.image_size
|
| 546 |
+
expected_dims = (3, h, w)
|
| 547 |
+
actual_dims = tuple(data.shape[1:])
|
| 548 |
+
|
| 549 |
+
if actual_dims != expected_dims:
|
| 550 |
+
expected_expr = ("batch_size", *map(str, expected_dims))
|
| 551 |
+
raise ValueError(
|
| 552 |
+
f"The expected shape of pixel values is {expected_expr}. "
|
| 553 |
+
f"You supplied {tuple(data.shape)}.")
|
| 554 |
+
|
| 555 |
+
return data
|
| 556 |
+
|
| 557 |
+
def _parse_and_validate_image_input(
|
| 558 |
+
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
| 559 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 560 |
+
image_embeds = kwargs.pop("image_embeds", None)
|
| 561 |
+
|
| 562 |
+
if pixel_values is None and image_embeds is None:
|
| 563 |
+
return None
|
| 564 |
+
|
| 565 |
+
if pixel_values is not None:
|
| 566 |
+
if not isinstance(pixel_values, (torch.Tensor, list)):
|
| 567 |
+
raise ValueError("Incorrect type of pixel values. "
|
| 568 |
+
f"Got type: {type(pixel_values)}")
|
| 569 |
+
|
| 570 |
+
if self.config.vision_config.model_type == "pixtral":
|
| 571 |
+
return LlavaImagePixelInputs(
|
| 572 |
+
type="pixel_values",
|
| 573 |
+
data=flatten_bn(pixel_values),
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
return LlavaImagePixelInputs(
|
| 577 |
+
type="pixel_values",
|
| 578 |
+
data=self._validate_pixel_values(
|
| 579 |
+
flatten_bn(pixel_values, concat=True)),
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
if image_embeds is not None:
|
| 583 |
+
if not isinstance(image_embeds, (torch.Tensor, list)):
|
| 584 |
+
raise ValueError("Incorrect type of image embeddings. "
|
| 585 |
+
f"Got type: {type(image_embeds)}")
|
| 586 |
+
|
| 587 |
+
return LlavaImageEmbeddingInputs(
|
| 588 |
+
type="image_embeds",
|
| 589 |
+
data=flatten_bn(image_embeds, concat=True),
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
raise AssertionError("This line should be unreachable.")
|
| 593 |
+
|
| 594 |
+
def _select_image_features(self, image_features: torch.Tensor, *,
|
| 595 |
+
strategy: str) -> torch.Tensor:
|
| 596 |
+
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
|
| 597 |
+
if strategy == "default":
|
| 598 |
+
return image_features[:, 1:]
|
| 599 |
+
elif strategy == "full":
|
| 600 |
+
return image_features
|
| 601 |
+
|
| 602 |
+
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
| 603 |
+
|
| 604 |
+
def _image_pixels_to_features(
|
| 605 |
+
self,
|
| 606 |
+
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
|
| 607 |
+
PixtralHFVisionModel],
|
| 608 |
+
pixel_values: torch.Tensor,
|
| 609 |
+
) -> torch.Tensor:
|
| 610 |
+
|
| 611 |
+
# NOTE: we skip the step to select the vision feature layer since
|
| 612 |
+
# this is already done inside the vision tower
|
| 613 |
+
image_features = vision_tower(pixel_values)
|
| 614 |
+
|
| 615 |
+
return self._select_image_features(
|
| 616 |
+
image_features,
|
| 617 |
+
strategy=self.config.vision_feature_select_strategy,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
def _process_image_pixels(self,
                          inputs: LlavaImagePixelInputs) -> torch.Tensor:
    """Run the vision tower over validated pixel-value inputs."""
    # Pixel inputs are only produced when a vision tower is configured.
    assert self.vision_tower is not None
    return self._image_pixels_to_features(self.vision_tower, inputs["data"])
|
| 627 |
+
|
| 628 |
+
def _process_image_input(self,
                         image_input: LlavaImageInputs) -> torch.Tensor:
    """Convert parsed image inputs into projector-space embeddings.

    Precomputed image embeddings are returned as-is; pixel values are
    encoded by the vision tower and then projected.
    """
    if image_input["type"] == "image_embeds":
        return image_input["data"]

    assert self.vision_tower is not None
    return self.multi_modal_projector(
        self._process_image_pixels(image_input))
|
| 637 |
+
|
| 638 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    """Parse multimodal kwargs and return vision embeddings, or ``None``
    when the request carries no image input."""
    parsed = self._parse_and_validate_image_input(**kwargs)
    if parsed is None:
        return None
    return self._process_image_input(parsed)
|
| 644 |
+
|
| 645 |
+
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed token ids, splicing any vision embeddings into the
    positions marked by ``config.image_token_index``."""
    embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is None:
        return embeds
    return merge_multimodal_embeddings(
        input_ids, embeds, multimodal_embeddings,
        self.config.image_token_index)
|
| 656 |
+
|
| 657 |
+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run forward pass for LLaVA-1.5.

    `input_ids` arrives with the image placeholder already expanded:
    the input processor repeats the image token (e.g. `32000`) so that
    the prompt reserves exactly as many positions as the visual encoder
    emits image embeddings — 576 (24 * 24) tokens for LLaVA-1.5. This
    keeps `positions` and `attn_metadata` consistent with `input_ids`
    and reserves KV-cache space for the image embeddings.

    Concretely, the text prompt
    `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
    tokenizes with a single `32000`; the processor prepends 575 more so
    that, including the original, 576 image tokens reach the language
    model — matching the visual encoder's output length.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        pixel_values: The pixels in each input image.

    See also:
        :class:`LlavaImageInputs`
    """
    if intermediate_tensors is not None:
        # A later pipeline stage receives hidden states via
        # intermediate_tensors; no embeddings are computed here.
        inputs_embeds = None
    elif inputs_embeds is None:
        # v0 compatibility: in v1 inputs_embeds is always generated at
        # the model runner, so this branch is never taken there.
        mm_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
        input_ids = None

    return self.language_model.model(input_ids,
                                     positions,
                                     kv_caches,
                                     attn_metadata,
                                     intermediate_tensors,
                                     inputs_embeds=inputs_embeds)
|
| 721 |
+
|
| 722 |
+
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logits computation to the wrapped language model."""
    return self.language_model.compute_logits(
        hidden_states, sampling_metadata)
|
| 729 |
+
|
| 730 |
+
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate token sampling to the wrapped language model."""
    return self.language_model.sample(logits, sampling_metadata)
|
| 736 |
+
|
| 737 |
+
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint weights through the generic auto-loader and
    return the set of parameter names that were loaded."""
    return AutoWeightsLoader(self).load_weights(weights)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis; builds the LLaVA HF processor with a
    version-appropriate vision-feature-selection strategy."""

    def get_hf_processor(self):
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        # transformers < 4.48 sets num_additional_image_tokens = 0 but
        # treats it as 1; passing None for the selection strategy offsets
        # that bug. Fixed upstream by
        # https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            select_strategy = None
        else:
            select_strategy = config.vision_feature_select_strategy

        return self.ctx.get_hf_processor(
            LlavaProcessor,
            patch_size=encoder_info.get_patch_size(),
            vision_feature_select_strategy=select_strategy,
        )
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor variant that additionally applies Mantis-style
    image prompt wrapping on top of the base LLaVA replacements."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Build multimodal inputs, then rewrite the expanded image-token
        runs into the Mantis ``(image N: <Image>...</Image>)`` format."""
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # The token count is assumed not to depend on the image size.
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        items = self._to_mm_items(mm_data)
        item_counts = items.get_all_counts()
        mm_kwargs = base_result["mm_kwargs"]

        # Reimplements the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_repls = self._bind_and_group_repls([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        # Rewrite the already-expanded prompt with the Mantis wrapping.
        prompt_ids, prompt, _ = self._apply_prompt_replacements(
            base_result["prompt_token_ids"],
            mantis_repls,
            item_counts,
        )

        # Placeholder ranges are located with the base LLaVA replacements.
        orig_repls = self._bind_and_group_repls(
            self._get_prompt_replacements(
                items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            item_counts,
        )
        self._validate_mm_placeholders(placeholders, item_counts)

        placeholder_ranges = {
            modality: [p.to_range() for p in found]
            for modality, found in placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=placeholder_ranges,
        )
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
# To use this model, please use
|
| 840 |
+
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
| 841 |
+
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: identical to LLaVA at the module level, differing
    only in the multimodal processor registered above. Selected via
    `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`.
    """
    pass
|