Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__init__.py +24 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/baichuan.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fuyu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gemma.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gritlm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/idefics2_vision_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/idefics3.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/interfaces_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/internlm2_ve.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/internvl.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/mamba.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/mixtral_quant.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/nemotron.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/nvlm_d.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/olmo2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/orion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/phimoe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_audio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_moe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_vl.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/registry.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/ultravox.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/vision.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/whisper.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/adapters.py +250 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/aria.py +663 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py +493 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/blip.py +334 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/clip.py +545 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/commandr.py +488 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/dbrx.py +496 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/decilm.py +124 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek_v2.py +817 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek_vl2.py +650 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/exaone.py +578 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/fairseq2_llama.py +153 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py +463 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/glm.py +22 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_j.py +358 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_neox.py +352 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/granite.py +520 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/gritlm.py +250 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics2_vision_model.py +346 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/interfaces.py +443 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/interfaces_base.py +175 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/intern_vit.py +476 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2_ve.py +156 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/jais.py +397 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/llava_next_video.py +500 -0
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
| 4 |
+
SupportsPP, has_inner_state, supports_lora,
|
| 5 |
+
supports_multimodal, supports_pp)
|
| 6 |
+
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
|
| 7 |
+
is_pooling_model, is_text_generation_model)
|
| 8 |
+
from .registry import ModelRegistry
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"ModelRegistry",
|
| 12 |
+
"VllmModelForPooling",
|
| 13 |
+
"is_pooling_model",
|
| 14 |
+
"VllmModelForTextGeneration",
|
| 15 |
+
"is_text_generation_model",
|
| 16 |
+
"HasInnerState",
|
| 17 |
+
"has_inner_state",
|
| 18 |
+
"SupportsLoRA",
|
| 19 |
+
"supports_lora",
|
| 20 |
+
"SupportsMultiModal",
|
| 21 |
+
"supports_multimodal",
|
| 22 |
+
"SupportsPP",
|
| 23 |
+
"supports_pp",
|
| 24 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/baichuan.cpython-311.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fuyu.cpython-311.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gemma.cpython-311.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gritlm.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/idefics2_vision_model.cpython-311.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/idefics3.cpython-311.pyc
ADDED
|
Binary file (33.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/interfaces_base.cpython-311.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/internlm2_ve.cpython-311.pyc
ADDED
|
Binary file (7.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/internvl.cpython-311.pyc
ADDED
|
Binary file (42.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/mamba.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/mixtral_quant.cpython-311.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/nemotron.cpython-311.pyc
ADDED
|
Binary file (23.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/nvlm_d.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/olmo2.cpython-311.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/orion.cpython-311.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/phimoe.cpython-311.pyc
ADDED
|
Binary file (26.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_audio.cpython-311.pyc
ADDED
|
Binary file (22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_moe.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/qwen2_vl.cpython-311.pyc
ADDED
|
Binary file (65.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (25.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/ultravox.cpython-311.pyc
ADDED
|
Binary file (28.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/vision.cpython-311.pyc
ADDED
|
Binary file (7.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/whisper.cpython-311.pyc
ADDED
|
Binary file (36.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/adapters.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from vllm.model_executor.layers.pooler import PoolingType
|
| 13 |
+
|
| 14 |
+
_T = TypeVar("_T", bound=type[nn.Module])
|
| 15 |
+
|
| 16 |
+
_GENERATE_SUFFIXES = [
|
| 17 |
+
"ForCausalLM",
|
| 18 |
+
"ForConditionalGeneration",
|
| 19 |
+
"ChatModel",
|
| 20 |
+
"LMHeadModel",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
| 25 |
+
model_name = orig_model_name
|
| 26 |
+
|
| 27 |
+
for generate_suffix in _GENERATE_SUFFIXES:
|
| 28 |
+
model_name = model_name.removesuffix(generate_suffix)
|
| 29 |
+
|
| 30 |
+
return model_name + pooling_suffix
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _create_pooling_model_cls(
|
| 34 |
+
orig_cls: _T,
|
| 35 |
+
*,
|
| 36 |
+
default_pooling_type: "PoolingType",
|
| 37 |
+
default_normalize: bool,
|
| 38 |
+
default_softmax: bool,
|
| 39 |
+
) -> _T:
|
| 40 |
+
# Lazy import
|
| 41 |
+
from vllm.config import VllmConfig
|
| 42 |
+
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
|
| 43 |
+
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
| 44 |
+
|
| 45 |
+
from .utils import AutoWeightsLoader, WeightsMapper
|
| 46 |
+
|
| 47 |
+
class ModelForPooling(orig_cls, VllmModelForPooling):
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
*,
|
| 52 |
+
vllm_config: "VllmConfig",
|
| 53 |
+
prefix: str = "",
|
| 54 |
+
**kwargs: Any,
|
| 55 |
+
) -> None:
|
| 56 |
+
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
| 57 |
+
|
| 58 |
+
# These are not used in pooling models
|
| 59 |
+
for attr in ("lm_head", "logits_processor"):
|
| 60 |
+
if hasattr(self, attr):
|
| 61 |
+
delattr(self, attr)
|
| 62 |
+
|
| 63 |
+
pooler_config = vllm_config.model_config.pooler_config
|
| 64 |
+
assert pooler_config is not None
|
| 65 |
+
|
| 66 |
+
# If the model already defines a pooler instance, don't overwrite it
|
| 67 |
+
if not getattr(self, "_pooler", None):
|
| 68 |
+
self._pooler = Pooler.from_config_with_defaults(
|
| 69 |
+
pooler_config,
|
| 70 |
+
pooling_type=default_pooling_type,
|
| 71 |
+
normalize=default_normalize,
|
| 72 |
+
softmax=default_softmax,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def pooler(
|
| 76 |
+
self,
|
| 77 |
+
hidden_states: torch.Tensor,
|
| 78 |
+
pooling_metadata: PoolingMetadata,
|
| 79 |
+
) -> PoolerOutput:
|
| 80 |
+
return self._pooler(hidden_states, pooling_metadata)
|
| 81 |
+
|
| 82 |
+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
| 83 |
+
# TODO: Support uninitialized params tracking
|
| 84 |
+
|
| 85 |
+
# We have deleted this attribute, so don't load it
|
| 86 |
+
weights = ((name, data) for name, data in weights
|
| 87 |
+
if not name.startswith("lm_head."))
|
| 88 |
+
|
| 89 |
+
# If `*ForCausalLM` defines `load_weights` on the inner model
|
| 90 |
+
# and there are no other inner modules with parameters,
|
| 91 |
+
# we support loading from both `*Model` and `*ForCausalLM`
|
| 92 |
+
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
|
| 93 |
+
# Whether only `self.model` contains parameters
|
| 94 |
+
model_is_only_param = all(
|
| 95 |
+
name == "model" or next(child.parameters(), None) is None
|
| 96 |
+
for name, child in self.named_children())
|
| 97 |
+
|
| 98 |
+
if model_is_only_param:
|
| 99 |
+
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
| 100 |
+
weights = mapper.apply(weights)
|
| 101 |
+
|
| 102 |
+
self.model.load_weights(weights)
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
# For most other models
|
| 106 |
+
if hasattr(orig_cls, "load_weights"):
|
| 107 |
+
orig_cls.load_weights(self, weights) # type: ignore
|
| 108 |
+
# Fallback
|
| 109 |
+
else:
|
| 110 |
+
loader = AutoWeightsLoader(self)
|
| 111 |
+
loader.load_weights(weights)
|
| 112 |
+
|
| 113 |
+
return ModelForPooling # type: ignore
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def as_embedding_model(cls: _T) -> _T:
|
| 117 |
+
"""
|
| 118 |
+
Subclass an existing vLLM model to support embeddings.
|
| 119 |
+
|
| 120 |
+
By default, the embeddings of the whole prompt are extracted from the
|
| 121 |
+
normalized hidden state corresponding to the last token.
|
| 122 |
+
|
| 123 |
+
Note:
|
| 124 |
+
We assume that no extra layers are added to the original model;
|
| 125 |
+
please implement your own model if this is not the case.
|
| 126 |
+
"""
|
| 127 |
+
# Avoid modifying existing embedding models
|
| 128 |
+
if is_pooling_model(cls):
|
| 129 |
+
return cls
|
| 130 |
+
|
| 131 |
+
# Lazy import
|
| 132 |
+
from vllm.model_executor.layers.pooler import PoolingType
|
| 133 |
+
|
| 134 |
+
ModelForEmbedding = _create_pooling_model_cls(
|
| 135 |
+
cls,
|
| 136 |
+
default_pooling_type=PoolingType.LAST,
|
| 137 |
+
default_normalize=True,
|
| 138 |
+
default_softmax=False,
|
| 139 |
+
)
|
| 140 |
+
ModelForEmbedding.__name__ = \
|
| 141 |
+
_get_pooling_model_name(cls.__name__, "ForEmbedding")
|
| 142 |
+
|
| 143 |
+
return ModelForEmbedding # type: ignore
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def as_classification_model(cls: _T) -> _T:
|
| 147 |
+
"""
|
| 148 |
+
Subclass an existing vLLM model to support classification.
|
| 149 |
+
|
| 150 |
+
By default, the class probabilities are extracted from the softmaxed
|
| 151 |
+
hidden state corresponding to the last token.
|
| 152 |
+
|
| 153 |
+
Note:
|
| 154 |
+
We assume that the classification head is a single linear layer
|
| 155 |
+
stored as the attribute `score` of the top-level model;
|
| 156 |
+
please implement your own model if this is not the case.
|
| 157 |
+
"""
|
| 158 |
+
# Avoid modifying existing classification models
|
| 159 |
+
if is_pooling_model(cls):
|
| 160 |
+
return cls
|
| 161 |
+
|
| 162 |
+
# Lazy import
|
| 163 |
+
from vllm.attention import AttentionMetadata
|
| 164 |
+
from vllm.config import VllmConfig
|
| 165 |
+
from vllm.model_executor.layers.linear import RowParallelLinear
|
| 166 |
+
from vllm.model_executor.layers.pooler import PoolingType
|
| 167 |
+
from vllm.sequence import IntermediateTensors
|
| 168 |
+
|
| 169 |
+
from .utils import maybe_prefix
|
| 170 |
+
|
| 171 |
+
ModelForPooling = _create_pooling_model_cls(
|
| 172 |
+
cls,
|
| 173 |
+
default_pooling_type=PoolingType.LAST,
|
| 174 |
+
default_normalize=False,
|
| 175 |
+
default_softmax=True,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
class ModelForClassification(ModelForPooling):
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
*,
|
| 183 |
+
vllm_config: "VllmConfig",
|
| 184 |
+
prefix: str = "",
|
| 185 |
+
**kwargs: Any,
|
| 186 |
+
) -> None:
|
| 187 |
+
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
| 188 |
+
|
| 189 |
+
config = vllm_config.model_config.hf_config
|
| 190 |
+
quant_config = vllm_config.quant_config
|
| 191 |
+
|
| 192 |
+
self.score = RowParallelLinear(config.hidden_size,
|
| 193 |
+
config.num_labels,
|
| 194 |
+
quant_config=quant_config,
|
| 195 |
+
input_is_parallel=False,
|
| 196 |
+
bias=False,
|
| 197 |
+
prefix=maybe_prefix(
|
| 198 |
+
prefix, "score"))
|
| 199 |
+
|
| 200 |
+
def forward(
|
| 201 |
+
self,
|
| 202 |
+
input_ids: torch.Tensor,
|
| 203 |
+
positions: torch.Tensor,
|
| 204 |
+
kv_caches: list[torch.Tensor],
|
| 205 |
+
attn_metadata: AttentionMetadata,
|
| 206 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 207 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 208 |
+
) -> torch.Tensor:
|
| 209 |
+
hidden_states = super().forward(input_ids, positions, kv_caches,
|
| 210 |
+
attn_metadata,
|
| 211 |
+
intermediate_tensors,
|
| 212 |
+
inputs_embeds)
|
| 213 |
+
logits, _ = self.score(hidden_states)
|
| 214 |
+
return logits
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
ModelForClassification.__name__ = \
|
| 218 |
+
_get_pooling_model_name(cls.__name__, "ForClassification")
|
| 219 |
+
|
| 220 |
+
return ModelForClassification # type: ignore
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def as_reward_model(cls: _T) -> _T:
|
| 224 |
+
"""
|
| 225 |
+
Subclass an existing vLLM model to support reward modeling.
|
| 226 |
+
|
| 227 |
+
By default, we return the hidden states of each token directly.
|
| 228 |
+
|
| 229 |
+
Note:
|
| 230 |
+
We assume that no extra layers are added to the original model;
|
| 231 |
+
please implement your own model if this is not the case.
|
| 232 |
+
"""
|
| 233 |
+
# Avoid modifying existing reward models
|
| 234 |
+
if is_pooling_model(cls):
|
| 235 |
+
return cls
|
| 236 |
+
|
| 237 |
+
# Lazy import
|
| 238 |
+
from vllm.model_executor.layers.pooler import PoolingType
|
| 239 |
+
|
| 240 |
+
ModelForReward = _create_pooling_model_cls(
|
| 241 |
+
cls,
|
| 242 |
+
default_pooling_type=PoolingType.ALL,
|
| 243 |
+
default_normalize=False,
|
| 244 |
+
default_softmax=False,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
ModelForReward.__name__ = \
|
| 248 |
+
_get_pooling_model_name(cls.__name__, "ForReward")
|
| 249 |
+
|
| 250 |
+
return ModelForReward # type: ignore
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/aria.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
| 4 |
+
Union)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
| 9 |
+
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
| 10 |
+
from transformers.models.aria.processing_aria import AriaProcessor
|
| 11 |
+
|
| 12 |
+
from vllm.attention import AttentionMetadata
|
| 13 |
+
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
| 14 |
+
from vllm.distributed import get_tensor_model_parallel_rank
|
| 15 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 16 |
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
| 17 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 18 |
+
RowParallelLinear)
|
| 19 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 20 |
+
from vllm.model_executor.layers.sampler import (SamplerOutput,
|
| 21 |
+
SamplingMetadata, get_sampler)
|
| 22 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
| 23 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 24 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 25 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 26 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 27 |
+
NestedTensors)
|
| 28 |
+
from vllm.multimodal.parse import MultiModalDataItems
|
| 29 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 30 |
+
BaseProcessingInfo, PromptReplacement)
|
| 31 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 32 |
+
from vllm.sequence import IntermediateTensors
|
| 33 |
+
|
| 34 |
+
# yapf: disable
|
| 35 |
+
from .idefics2_vision_model import Idefics2VisionConfig
|
| 36 |
+
from .idefics2_vision_model import (
|
| 37 |
+
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
| 38 |
+
# yapf: enable
|
| 39 |
+
from .interfaces import SupportsMultiModal
|
| 40 |
+
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
| 41 |
+
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
| 42 |
+
is_pp_missing_parameter, maybe_prefix,
|
| 43 |
+
merge_multimodal_embeddings)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AriaImagePixelInputs(TypedDict):
|
| 47 |
+
pixel_values: torch.Tensor
|
| 48 |
+
pixel_mask: Optional[torch.Tensor]
|
| 49 |
+
"""
|
| 50 |
+
Shape:
|
| 51 |
+
pixel_values: `(batch_size * num_images, num_channels, height, width)`
|
| 52 |
+
pixel_mask: `(batch_size * num_images, height, width)`
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class AriaVisionTransformer(Idefics3VisionTransformer):
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
config: Idefics2VisionConfig,
|
| 61 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 62 |
+
prefix: str = "",
|
| 63 |
+
) -> None:
|
| 64 |
+
super().__init__(config, quant_config, prefix)
|
| 65 |
+
# Unlike Idefics3VisionTransformer which uses LayerNorm after the
|
| 66 |
+
# final layer, Aria omits this normalization, so we replace it with an
|
| 67 |
+
# Identity layer
|
| 68 |
+
self.post_layernorm = nn.Identity()
|
| 69 |
+
|
| 70 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 71 |
+
torch.Tensor]]) -> Set[str]:
|
| 72 |
+
stacked_params_mapping = [
|
| 73 |
+
# (param_name, shard_name, shard_id)
|
| 74 |
+
("qkv_proj", "q_proj", "q"),
|
| 75 |
+
("qkv_proj", "k_proj", "k"),
|
| 76 |
+
("qkv_proj", "v_proj", "v"),
|
| 77 |
+
]
|
| 78 |
+
params_dict = dict(self.named_parameters())
|
| 79 |
+
loaded_params: Set[str] = set()
|
| 80 |
+
for name, loaded_weight in weights:
|
| 81 |
+
|
| 82 |
+
# NOTE: post_layernorm is not used in Aria
|
| 83 |
+
if "post_layernorm" in name:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
for param_name, weight_name, shard_id in stacked_params_mapping:
|
| 87 |
+
if weight_name not in name:
|
| 88 |
+
continue
|
| 89 |
+
name = name.replace(weight_name, param_name)
|
| 90 |
+
param = params_dict[name]
|
| 91 |
+
weight_loader = param.weight_loader
|
| 92 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 93 |
+
break
|
| 94 |
+
else:
|
| 95 |
+
param = params_dict[name]
|
| 96 |
+
weight_loader = getattr(param, "weight_loader",
|
| 97 |
+
default_weight_loader)
|
| 98 |
+
weight_loader(param, loaded_weight)
|
| 99 |
+
loaded_params.add(name)
|
| 100 |
+
return loaded_params
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class AriaProjectorMLP(nn.Module):
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
in_features: int,
|
| 108 |
+
hidden_features: int,
|
| 109 |
+
output_dim: int,
|
| 110 |
+
) -> None:
|
| 111 |
+
super().__init__()
|
| 112 |
+
|
| 113 |
+
self.linear_in = ColumnParallelLinear(in_features,
|
| 114 |
+
hidden_features,
|
| 115 |
+
bias=False)
|
| 116 |
+
self.linear_out = RowParallelLinear(hidden_features,
|
| 117 |
+
output_dim,
|
| 118 |
+
bias=False)
|
| 119 |
+
self.act = get_act_fn("gelu_new")
|
| 120 |
+
|
| 121 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
hidden_states, _ = self.linear_in(hidden_states)
|
| 123 |
+
hidden_states = self.act(hidden_states)
|
| 124 |
+
hidden_states, _ = self.linear_out(hidden_states)
|
| 125 |
+
return hidden_states
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class AriaProjector(nn.Module):
|
| 129 |
+
"""
|
| 130 |
+
A projection module with one cross attention layer and one FFN layer, which
|
| 131 |
+
projects ViT's outputs into MoE's inputs.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
patch_to_query_dict (dict): Maps patch numbers to their corresponding
|
| 135 |
+
query numbers,
|
| 136 |
+
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
|
| 137 |
+
based on image resolution.
|
| 138 |
+
embed_dim (int): Embedding dimension.
|
| 139 |
+
num_heads (int): Number of attention heads.
|
| 140 |
+
kv_dim (int): Dimension of key and value.
|
| 141 |
+
ff_dim (int): Hidden dimension of the feed-forward network.
|
| 142 |
+
output_dim (int): Output dimension.
|
| 143 |
+
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
|
| 144 |
+
|
| 145 |
+
Outputs:
|
| 146 |
+
A tensor with the shape of (batch_size, query_number, output_dim)
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, config: AriaConfig) -> None:
|
| 150 |
+
super().__init__()
|
| 151 |
+
|
| 152 |
+
self.patch_to_query_dict = config.projector_patch_to_query_dict
|
| 153 |
+
self.in_features = config.vision_config.hidden_size
|
| 154 |
+
self.num_heads = config.vision_config.num_attention_heads
|
| 155 |
+
self.kv_dim = config.vision_config.hidden_size
|
| 156 |
+
self.hidden_features = config.text_config.hidden_size
|
| 157 |
+
self.output_dim = config.text_config.hidden_size
|
| 158 |
+
|
| 159 |
+
self.query = nn.Parameter(
|
| 160 |
+
torch.empty(config.max_value_projector_patch_to_query_dict,
|
| 161 |
+
self.in_features))
|
| 162 |
+
|
| 163 |
+
self.cross_attn = AriaCrossAttention(config)
|
| 164 |
+
|
| 165 |
+
self.layer_norm = nn.LayerNorm(self.in_features)
|
| 166 |
+
self.feed_forward = AriaProjectorMLP(self.in_features,
|
| 167 |
+
self.hidden_features,
|
| 168 |
+
self.output_dim)
|
| 169 |
+
|
| 170 |
+
def forward(
|
| 171 |
+
self,
|
| 172 |
+
x: torch.Tensor,
|
| 173 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 174 |
+
) -> torch.Tensor:
|
| 175 |
+
batch_size, num_patches = x.shape[0], x.shape[1]
|
| 176 |
+
|
| 177 |
+
if num_patches not in self.patch_to_query_dict:
|
| 178 |
+
raise KeyError(f"Number of patches {num_patches} not found in "
|
| 179 |
+
"patch_to_query_dict amongst possible values "
|
| 180 |
+
f"{self.patch_to_query_dict.keys()}.")
|
| 181 |
+
|
| 182 |
+
query_num = self.patch_to_query_dict[num_patches]
|
| 183 |
+
|
| 184 |
+
queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
|
| 185 |
+
|
| 186 |
+
if attn_mask is not None:
|
| 187 |
+
attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
|
| 188 |
+
attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
|
| 189 |
+
|
| 190 |
+
attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
|
| 191 |
+
|
| 192 |
+
out = self.feed_forward(self.layer_norm(attention_out))
|
| 193 |
+
|
| 194 |
+
return out
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class AriaFusedMoE(FusedMoE):
|
| 198 |
+
|
| 199 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
| 200 |
+
shard_id: str) -> None:
|
| 201 |
+
# Override the weight_loader to handle the expert weights in the Aria
|
| 202 |
+
# model, which are already packed with experts, and merge the gate and
|
| 203 |
+
# up weights for each expert.
|
| 204 |
+
# Note: Loading expert weights with quantization is not supported
|
| 205 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 206 |
+
if shard_id == 'w13':
|
| 207 |
+
# the shape of loaded_weight is
|
| 208 |
+
# (num_experts, hidden_size, 2 * moe_intermediate_size)
|
| 209 |
+
if self.tp_size > 1:
|
| 210 |
+
up, gate = loaded_weight.chunk(2, dim=-1)
|
| 211 |
+
up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank]
|
| 212 |
+
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank]
|
| 213 |
+
up_and_gate = torch.cat([up_current_rank, gate_current_rank],
|
| 214 |
+
dim=-1).transpose(1, 2)
|
| 215 |
+
param.data.copy_(up_and_gate)
|
| 216 |
+
else:
|
| 217 |
+
param.data.copy_(loaded_weight.transpose(1, 2))
|
| 218 |
+
elif shard_id == 'w2':
|
| 219 |
+
# the shape of loaded_weight is
|
| 220 |
+
# (num_experts, moe_intermediate_size, hidden_size)
|
| 221 |
+
if self.tp_size > 1:
|
| 222 |
+
down_current_rank = loaded_weight.chunk(self.tp_size,
|
| 223 |
+
dim=1)[tp_rank]
|
| 224 |
+
param.data.copy_(down_current_rank.transpose(1, 2))
|
| 225 |
+
else:
|
| 226 |
+
param.data.copy_(loaded_weight.transpose(1, 2))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class AriaTextMoELayer(nn.Module):
|
| 230 |
+
"""
|
| 231 |
+
Mixture of Experts (MoE) Layer for the AriaMoE model.
|
| 232 |
+
|
| 233 |
+
This layer implements the MoE mechanism, which routes input tokens to
|
| 234 |
+
different experts based on a routing algorithm, processes them through the
|
| 235 |
+
experts, and then combines the outputs.
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
config: AriaTextConfig,
|
| 241 |
+
quant_config: Optional[QuantizationConfig],
|
| 242 |
+
) -> None:
|
| 243 |
+
super().__init__()
|
| 244 |
+
self.config = config
|
| 245 |
+
|
| 246 |
+
self.router_weight = nn.Parameter(
|
| 247 |
+
torch.empty(
|
| 248 |
+
(self.config.moe_num_experts, self.config.hidden_size)))
|
| 249 |
+
|
| 250 |
+
self.experts = AriaFusedMoE(
|
| 251 |
+
num_experts=config.moe_num_experts,
|
| 252 |
+
top_k=config.moe_topk,
|
| 253 |
+
hidden_size=config.hidden_size,
|
| 254 |
+
intermediate_size=config.intermediate_size,
|
| 255 |
+
quant_config=quant_config,
|
| 256 |
+
reduce_results=True,
|
| 257 |
+
)
|
| 258 |
+
self.shared_experts = LlamaMLP(
|
| 259 |
+
config.hidden_size,
|
| 260 |
+
config.intermediate_size * config.moe_num_shared_experts,
|
| 261 |
+
"silu",
|
| 262 |
+
quant_config=quant_config,
|
| 263 |
+
bias=config.mlp_bias,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 267 |
+
"""
|
| 268 |
+
Forward pass of the MoE Layer.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
hidden_states (torch.Tensor): Input tensor of shape (batch_size,
|
| 272 |
+
sequence_length, hidden_size).
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
torch.Tensor: Output tensor after passing through the MoE layer.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
router_output = torch.nn.functional.linear(hidden_states,
|
| 279 |
+
self.router_weight)
|
| 280 |
+
|
| 281 |
+
hidden_states_copy = hidden_states.clone()
|
| 282 |
+
# NOTE: hidden_states will be modified inplace by `FusedMoE`
|
| 283 |
+
sparse_expert_output = self.experts(hidden_states, router_output)
|
| 284 |
+
shared_expert_output = self.shared_experts(hidden_states_copy)
|
| 285 |
+
|
| 286 |
+
return sparse_expert_output + shared_expert_output
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class AriaTextDecoderLayer(LlamaDecoderLayer):
|
| 290 |
+
"""
|
| 291 |
+
Custom Decoder Layer for the AriaMoE model which modifies the standard
|
| 292 |
+
`LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
|
| 293 |
+
Experts (MoE) Layer.
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
config: AriaTextConfig,
|
| 299 |
+
cache_config: Optional[CacheConfig] = None,
|
| 300 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 301 |
+
prefix: str = "",
|
| 302 |
+
) -> None:
|
| 303 |
+
super().__init__(config, cache_config, quant_config, prefix)
|
| 304 |
+
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class AriaTextModel(LlamaModel):
|
| 308 |
+
"""
|
| 309 |
+
Custom LlamaModel for the AriaMoE model which modifies the standard
|
| 310 |
+
LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 314 |
+
super().__init__(vllm_config=vllm_config,
|
| 315 |
+
prefix=prefix,
|
| 316 |
+
layer_type=AriaTextDecoderLayer)
|
| 317 |
+
|
| 318 |
+
# Adapted from LlamaModel.load_weights with the modification of adding
|
| 319 |
+
# the expert weights mapping to `stacked_params_mapping`
|
| 320 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 321 |
+
torch.Tensor]]) -> Set[str]:
|
| 322 |
+
stacked_params_mapping = [
|
| 323 |
+
# (param_name, shard_name, shard_id)
|
| 324 |
+
(".qkv_proj", ".q_proj", "q"),
|
| 325 |
+
(".qkv_proj", ".k_proj", "k"),
|
| 326 |
+
(".qkv_proj", ".v_proj", "v"),
|
| 327 |
+
(".gate_up_proj", ".gate_proj", 0),
|
| 328 |
+
(".gate_up_proj", ".up_proj", 1),
|
| 329 |
+
("experts.w13_weight", "experts.fc1.weight", 'w13'),
|
| 330 |
+
("experts.w2_weight", "experts.fc2.weight", 'w2'),
|
| 331 |
+
]
|
| 332 |
+
params_dict = dict(self.named_parameters())
|
| 333 |
+
loaded_params: Set[str] = set()
|
| 334 |
+
for name, loaded_weight in weights:
|
| 335 |
+
if "rotary_emb.inv_freq" in name:
|
| 336 |
+
continue
|
| 337 |
+
if ("rotary_emb.cos_cached" in name
|
| 338 |
+
or "rotary_emb.sin_cached" in name):
|
| 339 |
+
# Models trained using ColossalAI may include these tensors in
|
| 340 |
+
# the checkpoint. Skip them.
|
| 341 |
+
continue
|
| 342 |
+
if (self.quant_config is not None and
|
| 343 |
+
(scale_name := self.quant_config.get_cache_scale(name))):
|
| 344 |
+
# Loading kv cache quantization scales
|
| 345 |
+
param = params_dict[scale_name]
|
| 346 |
+
weight_loader = getattr(param, "weight_loader",
|
| 347 |
+
default_weight_loader)
|
| 348 |
+
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
| 349 |
+
loaded_weight[0])
|
| 350 |
+
weight_loader(param, loaded_weight)
|
| 351 |
+
loaded_params.add(scale_name)
|
| 352 |
+
continue
|
| 353 |
+
for param_name, weight_name, shard_id in stacked_params_mapping:
|
| 354 |
+
if weight_name not in name:
|
| 355 |
+
continue
|
| 356 |
+
name = name.replace(weight_name, param_name)
|
| 357 |
+
# Skip loading extra bias for GPTQ models.
|
| 358 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
if is_pp_missing_parameter(name, self):
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
param = params_dict[name]
|
| 365 |
+
weight_loader = param.weight_loader
|
| 366 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 367 |
+
break
|
| 368 |
+
else:
|
| 369 |
+
# Skip loading extra bias for GPTQ models.
|
| 370 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 371 |
+
continue
|
| 372 |
+
# Remapping the name of FP8 kv-scale.
|
| 373 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 374 |
+
if name is None:
|
| 375 |
+
continue
|
| 376 |
+
|
| 377 |
+
if is_pp_missing_parameter(name, self):
|
| 378 |
+
continue
|
| 379 |
+
|
| 380 |
+
param = params_dict[name]
|
| 381 |
+
weight_loader = getattr(param, "weight_loader",
|
| 382 |
+
default_weight_loader)
|
| 383 |
+
weight_loader(param, loaded_weight)
|
| 384 |
+
loaded_params.add(name)
|
| 385 |
+
return loaded_params
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class AriaProcessingInfo(BaseProcessingInfo):
|
| 389 |
+
|
| 390 |
+
def get_hf_config(self):
|
| 391 |
+
return self.ctx.get_hf_config(AriaConfig)
|
| 392 |
+
|
| 393 |
+
def get_vision_config(self):
|
| 394 |
+
return self.get_hf_config().vision_config
|
| 395 |
+
|
| 396 |
+
def get_hf_processor(self):
|
| 397 |
+
return self.ctx.get_hf_processor(AriaProcessor)
|
| 398 |
+
|
| 399 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 400 |
+
return {"image": None}
|
| 401 |
+
|
| 402 |
+
def get_mm_max_tokens_per_item(
|
| 403 |
+
self,
|
| 404 |
+
seq_len: int,
|
| 405 |
+
mm_counts: Mapping[str, int],
|
| 406 |
+
) -> Mapping[str, int]:
|
| 407 |
+
return {"image": self.get_num_image_tokens()}
|
| 408 |
+
|
| 409 |
+
def get_num_image_tokens(self) -> int:
|
| 410 |
+
hf_config = self.get_hf_config()
|
| 411 |
+
return max(hf_config.projector_patch_to_query_dict.values())
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
|
| 415 |
+
|
| 416 |
+
def get_dummy_processor_inputs(
|
| 417 |
+
self,
|
| 418 |
+
seq_len: int,
|
| 419 |
+
mm_counts: Mapping[str, int],
|
| 420 |
+
) -> ProcessorInputs:
|
| 421 |
+
vision_config = self.info.get_vision_config()
|
| 422 |
+
|
| 423 |
+
max_image_size = vision_config.image_size
|
| 424 |
+
num_images = mm_counts.get("image", 0)
|
| 425 |
+
|
| 426 |
+
mm_data = {
|
| 427 |
+
"image":
|
| 428 |
+
self._get_dummy_images(width=max_image_size,
|
| 429 |
+
height=max_image_size,
|
| 430 |
+
num_images=num_images)
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
hf_processor = self.info.get_hf_processor()
|
| 434 |
+
image_token: str = hf_processor.tokenizer.image_token # type: ignore
|
| 435 |
+
|
| 436 |
+
return ProcessorInputs(
|
| 437 |
+
prompt_text=image_token * num_images,
|
| 438 |
+
mm_data=mm_data,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
|
| 443 |
+
|
| 444 |
+
def _get_mm_fields_config(
|
| 445 |
+
self,
|
| 446 |
+
hf_inputs: BatchFeature,
|
| 447 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 448 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 449 |
+
return dict(
|
| 450 |
+
pixel_values=MultiModalFieldConfig.batched("image"),
|
| 451 |
+
pixel_mask=MultiModalFieldConfig.batched("image"),
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
def _get_prompt_replacements(
|
| 455 |
+
self,
|
| 456 |
+
mm_items: MultiModalDataItems,
|
| 457 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 458 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 459 |
+
) -> list[PromptReplacement]:
|
| 460 |
+
hf_config = self.info.get_hf_config()
|
| 461 |
+
image_token_id = hf_config.image_token_index
|
| 462 |
+
|
| 463 |
+
num_image_tokens = self.info.get_num_image_tokens()
|
| 464 |
+
|
| 465 |
+
return [
|
| 466 |
+
PromptReplacement(
|
| 467 |
+
modality="image",
|
| 468 |
+
target=[image_token_id],
|
| 469 |
+
replacement=[image_token_id] * num_image_tokens,
|
| 470 |
+
)
|
| 471 |
+
]
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
|
| 475 |
+
info=AriaProcessingInfo,
|
| 476 |
+
dummy_inputs=AriaDummyInputsBuilder)
|
| 477 |
+
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
| 478 |
+
"""
|
| 479 |
+
Aria model for conditional generation tasks.
|
| 480 |
+
|
| 481 |
+
This model combines a vision tower, a multi-modal projector, and a language
|
| 482 |
+
model to perform tasks that involve both image and text inputs.
|
| 483 |
+
"""
|
| 484 |
+
hf_to_vllm_mapper = WeightsMapper(
|
| 485 |
+
orig_to_new_prefix={
|
| 486 |
+
"language_model.model": "language_model",
|
| 487 |
+
"language_model.lm_head": "lm_head",
|
| 488 |
+
},
|
| 489 |
+
orig_to_new_suffix={
|
| 490 |
+
"router.weight": "router_weight",
|
| 491 |
+
},
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
def __init__(
|
| 495 |
+
self,
|
| 496 |
+
vllm_config: VllmConfig,
|
| 497 |
+
prefix: str = "",
|
| 498 |
+
):
|
| 499 |
+
super().__init__()
|
| 500 |
+
config = vllm_config.model_config.hf_config
|
| 501 |
+
quant_config = vllm_config.quant_config
|
| 502 |
+
|
| 503 |
+
self.config = config
|
| 504 |
+
self.vision_tower = AriaVisionTransformer(
|
| 505 |
+
config.vision_config,
|
| 506 |
+
quant_config,
|
| 507 |
+
prefix=f"{prefix}.vision_tower",
|
| 508 |
+
)
|
| 509 |
+
self.multi_modal_projector = AriaProjector(config)
|
| 510 |
+
self.vocab_size = config.text_config.vocab_size
|
| 511 |
+
self.language_model = AriaTextModel(
|
| 512 |
+
vllm_config=vllm_config.with_hf_config(config.text_config),
|
| 513 |
+
prefix=maybe_prefix(prefix, "language_model.model"),
|
| 514 |
+
)
|
| 515 |
+
self.pad_token_id = (self.config.pad_token_id
|
| 516 |
+
if self.config.pad_token_id is not None else -1)
|
| 517 |
+
self.unpadded_vocab_size = config.text_config.vocab_size
|
| 518 |
+
self.lm_head = ParallelLMHead(
|
| 519 |
+
self.unpadded_vocab_size,
|
| 520 |
+
config.text_config.hidden_size,
|
| 521 |
+
org_num_embeddings=self.language_model.org_vocab_size,
|
| 522 |
+
quant_config=quant_config,
|
| 523 |
+
)
|
| 524 |
+
logit_scale = getattr(config, "logit_scale", 1.0)
|
| 525 |
+
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
| 526 |
+
self.vocab_size, logit_scale)
|
| 527 |
+
self.sampler = get_sampler()
|
| 528 |
+
|
| 529 |
+
def _validate_image_sizes(
|
| 530 |
+
self, images: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 531 |
+
if not all(img.shape == images[0].shape for img in images):
|
| 532 |
+
raise ValueError("All images must be the same size")
|
| 533 |
+
return images
|
| 534 |
+
|
| 535 |
+
def _parse_and_validate_image_input(
|
| 536 |
+
self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
|
| 537 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 538 |
+
pixel_mask = kwargs.pop("pixel_mask", None)
|
| 539 |
+
|
| 540 |
+
if pixel_values is None:
|
| 541 |
+
return None
|
| 542 |
+
|
| 543 |
+
if not isinstance(pixel_values, (torch.Tensor, list)):
|
| 544 |
+
raise ValueError("Incorrect type of pixel values. "
|
| 545 |
+
f"Got type: {type(pixel_values)}")
|
| 546 |
+
|
| 547 |
+
pixel_values = self._validate_image_sizes(pixel_values)
|
| 548 |
+
pixel_values = flatten_bn(pixel_values, concat=True)
|
| 549 |
+
|
| 550 |
+
if pixel_mask is not None:
|
| 551 |
+
if not isinstance(pixel_mask, (torch.Tensor, list)):
|
| 552 |
+
raise ValueError("Incorrect type of pixel mask. "
|
| 553 |
+
f"Got type: {type(pixel_mask)}")
|
| 554 |
+
|
| 555 |
+
pixel_mask = flatten_bn(pixel_mask, concat=True)
|
| 556 |
+
|
| 557 |
+
return AriaImagePixelInputs(
|
| 558 |
+
pixel_values=pixel_values,
|
| 559 |
+
pixel_mask=pixel_mask,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
def _create_patch_attention_mask(
|
| 563 |
+
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
|
| 564 |
+
if pixel_mask is None:
|
| 565 |
+
return None
|
| 566 |
+
|
| 567 |
+
patches_subgrid = pixel_mask.unfold(
|
| 568 |
+
dimension=1,
|
| 569 |
+
size=self.vision_tower.config.patch_size,
|
| 570 |
+
step=self.vision_tower.config.patch_size,
|
| 571 |
+
).unfold(
|
| 572 |
+
dimension=2,
|
| 573 |
+
size=self.vision_tower.config.patch_size,
|
| 574 |
+
step=self.vision_tower.config.patch_size,
|
| 575 |
+
)
|
| 576 |
+
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
| 577 |
+
|
| 578 |
+
def _process_image_input(
|
| 579 |
+
self, image_input: AriaImagePixelInputs
|
| 580 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 581 |
+
assert self.vision_tower is not None
|
| 582 |
+
|
| 583 |
+
pixel_values = image_input['pixel_values']
|
| 584 |
+
pixel_mask = image_input['pixel_mask']
|
| 585 |
+
|
| 586 |
+
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
| 587 |
+
|
| 588 |
+
image_outputs = self.vision_tower(
|
| 589 |
+
pixel_values=pixel_values,
|
| 590 |
+
patch_attention_mask=patch_attention_mask,
|
| 591 |
+
)
|
| 592 |
+
image_attn_mask = None
|
| 593 |
+
if patch_attention_mask is not None:
|
| 594 |
+
flattened_mask = patch_attention_mask.flatten(1)
|
| 595 |
+
image_attn_mask = torch.logical_not(flattened_mask)
|
| 596 |
+
|
| 597 |
+
return self.multi_modal_projector(image_outputs, image_attn_mask)
|
| 598 |
+
|
| 599 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
| 600 |
+
image_input = self._parse_and_validate_image_input(**kwargs)
|
| 601 |
+
if image_input is None:
|
| 602 |
+
return None
|
| 603 |
+
multimodal_embeddings = self._process_image_input(image_input)
|
| 604 |
+
return multimodal_embeddings
|
| 605 |
+
|
| 606 |
+
def get_input_embeddings(
|
| 607 |
+
self,
|
| 608 |
+
input_ids: torch.Tensor,
|
| 609 |
+
multimodal_embeddings: Optional[NestedTensors] = None,
|
| 610 |
+
) -> torch.Tensor:
|
| 611 |
+
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
| 612 |
+
if multimodal_embeddings is not None:
|
| 613 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 614 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 615 |
+
self.config.image_token_index)
|
| 616 |
+
return inputs_embeds
|
| 617 |
+
|
| 618 |
+
def forward(
|
| 619 |
+
self,
|
| 620 |
+
input_ids: torch.Tensor,
|
| 621 |
+
positions: torch.Tensor,
|
| 622 |
+
kv_caches: List[torch.Tensor],
|
| 623 |
+
attn_metadata: AttentionMetadata,
|
| 624 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 625 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 626 |
+
**kwargs: object,
|
| 627 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 628 |
+
if inputs_embeds is None:
|
| 629 |
+
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
| 630 |
+
# always pass the input via `inputs_embeds`
|
| 631 |
+
# to make sure the computation graph is consistent
|
| 632 |
+
inputs_embeds = self.get_input_embeddings(input_ids,
|
| 633 |
+
multimodal_embeddings)
|
| 634 |
+
input_ids = None
|
| 635 |
+
|
| 636 |
+
hidden_states = self.language_model(
|
| 637 |
+
input_ids,
|
| 638 |
+
positions,
|
| 639 |
+
kv_caches,
|
| 640 |
+
attn_metadata,
|
| 641 |
+
intermediate_tensors,
|
| 642 |
+
inputs_embeds=inputs_embeds,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
return hidden_states
|
| 646 |
+
|
| 647 |
+
def compute_logits(self, hidden_states: torch.Tensor,
|
| 648 |
+
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
| 649 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 650 |
+
sampling_metadata)
|
| 651 |
+
return logits
|
| 652 |
+
|
| 653 |
+
def sample(
|
| 654 |
+
self,
|
| 655 |
+
logits: torch.Tensor,
|
| 656 |
+
sampling_metadata: SamplingMetadata,
|
| 657 |
+
) -> Optional[SamplerOutput]:
|
| 658 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 659 |
+
return next_tokens
|
| 660 |
+
|
| 661 |
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
| 662 |
+
loader = AutoWeightsLoader(self)
|
| 663 |
+
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
| 22 |
+
import math
|
| 23 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import nn
|
| 27 |
+
from transformers import PretrainedConfig
|
| 28 |
+
|
| 29 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 30 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 31 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 32 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 33 |
+
get_tensor_model_parallel_world_size)
|
| 34 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 35 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 36 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 37 |
+
QKVParallelLinear,
|
| 38 |
+
RowParallelLinear)
|
| 39 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 40 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 41 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 42 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 43 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 44 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 45 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 46 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 47 |
+
from vllm.sequence import IntermediateTensors
|
| 48 |
+
|
| 49 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 50 |
+
from .utils import (is_pp_missing_parameter,
|
| 51 |
+
make_empty_intermediate_tensors_factory, make_layers)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
| 55 |
+
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
| 56 |
+
base = torch.tensor(
|
| 57 |
+
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
| 58 |
+
dtype=torch.float32,
|
| 59 |
+
)
|
| 60 |
+
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
| 61 |
+
slopes = torch.pow(base, powers)
|
| 62 |
+
|
| 63 |
+
if closest_power_of_2 != total_num_heads:
|
| 64 |
+
extra_base = torch.tensor(
|
| 65 |
+
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
| 66 |
+
dtype=torch.float32,
|
| 67 |
+
)
|
| 68 |
+
num_remaining_heads = min(closest_power_of_2,
|
| 69 |
+
total_num_heads - closest_power_of_2)
|
| 70 |
+
extra_powers = torch.arange(start=1,
|
| 71 |
+
end=1 + 2 * num_remaining_heads,
|
| 72 |
+
step=2,
|
| 73 |
+
dtype=torch.int32)
|
| 74 |
+
slopes = torch.cat(
|
| 75 |
+
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 76 |
+
return slopes
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class BaiChuanMLP(nn.Module):
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
hidden_size: int,
|
| 84 |
+
intermediate_size: int,
|
| 85 |
+
hidden_act: str,
|
| 86 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 87 |
+
):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 90 |
+
hidden_size, [intermediate_size] * 2,
|
| 91 |
+
bias=False,
|
| 92 |
+
quant_config=quant_config)
|
| 93 |
+
self.down_proj = RowParallelLinear(intermediate_size,
|
| 94 |
+
hidden_size,
|
| 95 |
+
bias=False,
|
| 96 |
+
quant_config=quant_config)
|
| 97 |
+
if hidden_act != "silu":
|
| 98 |
+
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
| 99 |
+
"Only silu is supported for now.")
|
| 100 |
+
self.act_fn = SiluAndMul()
|
| 101 |
+
|
| 102 |
+
def forward(self, x):
|
| 103 |
+
gate_up, _ = self.gate_up_proj(x)
|
| 104 |
+
x = self.act_fn(gate_up)
|
| 105 |
+
x, _ = self.down_proj(x)
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class BaiChuanAttention(nn.Module):
|
| 110 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 111 |
+
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
hidden_size: int,
|
| 115 |
+
num_heads: int,
|
| 116 |
+
position_embedding: str,
|
| 117 |
+
rope_theta: float = 10000,
|
| 118 |
+
max_position_embeddings: int = 8192,
|
| 119 |
+
cache_config: Optional[CacheConfig] = None,
|
| 120 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 121 |
+
prefix: str = "",
|
| 122 |
+
):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.hidden_size = hidden_size
|
| 125 |
+
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
| 126 |
+
)
|
| 127 |
+
self.total_num_heads = num_heads
|
| 128 |
+
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
| 129 |
+
self.num_heads = (self.total_num_heads //
|
| 130 |
+
tensor_model_parallel_world_size)
|
| 131 |
+
self.head_dim = hidden_size // self.total_num_heads
|
| 132 |
+
self.postion_embedding = position_embedding
|
| 133 |
+
self.rope_theta = rope_theta
|
| 134 |
+
self.max_position_embeddings = max_position_embeddings
|
| 135 |
+
|
| 136 |
+
# pylint: disable=invalid-name
|
| 137 |
+
self.W_pack = QKVParallelLinear(
|
| 138 |
+
hidden_size,
|
| 139 |
+
self.head_dim,
|
| 140 |
+
self.total_num_heads,
|
| 141 |
+
self.total_num_heads,
|
| 142 |
+
bias=False,
|
| 143 |
+
quant_config=quant_config,
|
| 144 |
+
)
|
| 145 |
+
self.o_proj = RowParallelLinear(
|
| 146 |
+
self.total_num_heads * self.head_dim,
|
| 147 |
+
hidden_size,
|
| 148 |
+
bias=False,
|
| 149 |
+
quant_config=quant_config,
|
| 150 |
+
)
|
| 151 |
+
# Create the alibi slopes and slice them.
|
| 152 |
+
if self.postion_embedding == "ALIBI":
|
| 153 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 154 |
+
head_start = tp_rank * self.num_heads
|
| 155 |
+
head_end = (tp_rank + 1) * self.num_heads
|
| 156 |
+
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
| 157 |
+
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
| 158 |
+
|
| 159 |
+
scaling = self.head_dim**-0.5
|
| 160 |
+
self.attn = Attention(self.num_heads,
|
| 161 |
+
self.head_dim,
|
| 162 |
+
scaling,
|
| 163 |
+
alibi_slopes=alibi_slopes,
|
| 164 |
+
quant_config=quant_config,
|
| 165 |
+
prefix=f"{prefix}.attn")
|
| 166 |
+
else:
|
| 167 |
+
self.rotary_emb = get_rope(
|
| 168 |
+
self.head_dim,
|
| 169 |
+
rotary_dim=self.head_dim,
|
| 170 |
+
max_position=self.max_position_embeddings,
|
| 171 |
+
base=self.rope_theta,
|
| 172 |
+
)
|
| 173 |
+
self.scaling = self.head_dim**-0.5
|
| 174 |
+
self.attn = Attention(self.num_heads,
|
| 175 |
+
self.head_dim,
|
| 176 |
+
self.scaling,
|
| 177 |
+
cache_config=cache_config,
|
| 178 |
+
quant_config=quant_config,
|
| 179 |
+
prefix=f"{prefix}.attn")
|
| 180 |
+
|
| 181 |
+
def forward(
|
| 182 |
+
self,
|
| 183 |
+
positions: torch.Tensor,
|
| 184 |
+
hidden_states: torch.Tensor,
|
| 185 |
+
kv_cache: torch.Tensor,
|
| 186 |
+
attn_metadata: AttentionMetadata,
|
| 187 |
+
) -> torch.Tensor:
|
| 188 |
+
qkv, _ = self.W_pack(hidden_states)
|
| 189 |
+
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
| 190 |
+
if self.postion_embedding != "ALIBI":
|
| 191 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 192 |
+
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
| 193 |
+
output, _ = self.o_proj(attn_output)
|
| 194 |
+
return output
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class BaiChuanDecoderLayer(nn.Module):
    """One Baichuan transformer layer.

    Structure: RMSNorm -> self-attention -> RMSNorm -> MLP, using vLLM's
    fused add-and-norm residual convention (RMSNorm takes and returns the
    running residual).
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Fall back to canonical defaults when the HF config omits them.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self attention.
        if residual is None:
            # First layer in this pipeline stage: seed the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add-and-norm: returns (normed, updated_residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully connected.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@support_torch_compile
class BaiChuanModel(nn.Module):
    """Baichuan transformer backbone.

    Token embeddings + decoder layer stack + final RMSNorm, with
    pipeline-parallel partitioning of the layers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers assigns a contiguous [start_layer, end_layer) slice of
        # the stack to this pipeline-parallel rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config,
                                                position_embedding,
                                                cache_config,
                                                quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First PP rank computes (or receives) the input embeddings.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later PP ranks resume from the previous stage's tensors.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed per-stage, not per absolute layer.
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Shared causal-LM wrapper for all Baichuan variants.

    Adds the LM head, logits processing, sampling and checkpoint loading
    on top of :class:`BaiChuanModel`.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share weights between input embeddings and the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                # Rotary inverse frequencies are recomputed, not loaded.
                continue
            if name == "lm_head.weight":
                # Unlike Baichuan, Baichuan2 normalizes the head weights.
                # Refer to:
                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                # Distinguish between Baichuan and Baichuan2 by checking the
                # vocab size. This is suggested by
                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
                if self.config.vocab_size == 125696:
                    loaded_weight = torch.nn.functional.normalize(
                        loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        # Baichuan2-7B (hidden_size 4096) uses rotary embeddings; the 13B
        # variants (Baichuan and Baichuan2) use ALiBi.
        pos_embedding = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding=pos_embedding)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The 7B model always uses rotary position embeddings.
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding="ROPE",
        )
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/blip.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Minimal implementation of BlipVisionModel intended to be only used
|
| 3 |
+
within a vision language model."""
|
| 4 |
+
from typing import Iterable, Optional, Set, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import Blip2VisionConfig, BlipVisionConfig
|
| 9 |
+
|
| 10 |
+
from vllm.attention.layer import MultiHeadAttention
|
| 11 |
+
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
| 12 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 13 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 14 |
+
QKVParallelLinear,
|
| 15 |
+
RowParallelLinear)
|
| 16 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 17 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the image.

    The image must be evenly divisible into square patches.
    """
    assert image_size % patch_size == 0
    return image_size // patch_size
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total patch count: the square of the grid side length."""
    # Inlined grid-length computation; image must tile evenly into patches.
    assert image_size % patch_size == 0
    side = image_size // patch_size
    return side * side
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
| 32 |
+
class BlipVisionEmbeddings(nn.Module):
    """BLIP patch + position embeddings.

    Converts an image into a sequence of patch embeddings, prepends a
    learned [CLS] embedding, and adds learned position embeddings.
    """

    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned [CLS] token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patchification via strided convolution.
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # +1 for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # shape = [*, width, grid, grid]
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
        # Flatten the spatial grid into a token sequence: [*, grid*grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # Slice position embeddings in case the sequence is shorter than
        # the maximum number of positions.
        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class BlipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Fused Q/K/V projection, sharded across tensor-parallel ranks.
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.projection",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape [bsz, seq, embed] -> [bsz, heads, seq, head_dim].
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        qkv_states, _ = self.qkv(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)
        attn_out = self.attn(q, k, v)
        projected, _ = self.projection(attn_out)

        # Second element (attention weights) is not computed.
        return projected, None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class BlipMLP(nn.Module):
    """BLIP feed-forward block: fc1 -> activation -> fc2.

    fc1/fc2 are column/row tensor-parallel so the intermediate dimension
    is sharded across ranks.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out, _ = self.fc1(hidden_states)
        out = self.activation_fn(out)
        out, _ = self.fc2(out)
        return out
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class BlipEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: LN -> attention -> residual,
    then LN -> MLP -> residual."""

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # fallback to sdpa attention if tp unavailable
        self.self_attn = BlipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block with residual connection.
        shortcut = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, _ = self.self_attn(hidden_states=normed)
        hidden_states = shortcut + attn_out

        # MLP sub-block with residual connection.
        shortcut = hidden_states
        normed = self.layer_norm2(hidden_states)
        hidden_states = shortcut + self.mlp(normed)

        return hidden_states
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipConfig
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # Allow instantiating only a prefix of the full layer stack
        # (used when downstream models only consume early features).
        num_layers = (config.num_hidden_layers
                      if num_hidden_layers_override is None else
                      num_hidden_layers_override)

        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class BlipVisionModel(nn.Module):
    """BLIP vision tower: patch embeddings, transformer encoder, and an
    optional final LayerNorm (skipped when only a layer prefix is used)."""

    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory.
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        self.post_layernorm = (nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
                               if require_post_norm else None)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return hidden_states

        return self.post_layernorm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in BlipVisionModel
            if (name.startswith("post_layernorm")
                    and self.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/clip.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Minimal implementation of CLIPVisionModel intended to be only used
|
| 3 |
+
within a vision language model."""
|
| 4 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from transformers import CLIPVisionConfig
|
| 11 |
+
|
| 12 |
+
from vllm.attention.layer import MultiHeadAttention
|
| 13 |
+
from vllm.config import ModelConfig
|
| 14 |
+
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
| 15 |
+
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
| 16 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 17 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 18 |
+
QKVParallelLinear,
|
| 19 |
+
RowParallelLinear)
|
| 20 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 21 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 22 |
+
from vllm.multimodal.utils import (cached_get_tokenizer,
|
| 23 |
+
consecutive_placeholder_ranges,
|
| 24 |
+
repeat_and_pad_placeholder_tokens)
|
| 25 |
+
from vllm.sequence import SequenceData
|
| 26 |
+
|
| 27 |
+
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the image.

    The image must be evenly divisible into square patches.
    """
    assert image_size % patch_size == 0
    return image_size // patch_size
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Total patch count: the per-side grid length squared."""
    # Inlined from get_clip_patch_grid_length: side length with an exactness
    # check, then squared for the full grid.
    assert image_size % patch_size == 0
    side = image_size // patch_size
    return side * side
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Number of output tokens: all patches plus one [CLS] token."""
    # Inlined patch-count computation (exactness check preserved).
    assert hf_config.image_size % hf_config.patch_size == 0
    grid = hf_config.image_size // hf_config.patch_size
    return grid * grid + 1
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Maximum image tokens; CLIP is fixed-resolution, so this equals the
    regular feature size (patches + 1 CLS token)."""
    # Inlined from get_clip_image_feature_size, exactness check preserved.
    assert hf_config.image_size % hf_config.patch_size == 0
    grid = hf_config.image_size // hf_config.patch_size
    return grid * grid + 1
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
                            seq_len: int,
                            num_images: int,
                            *,
                            image_token_id: int,
                            image_feature_size_override: Optional[int] = None,
                            mm_key: str = "image"):
    """Build dummy sequence data for profiling: a run of image-placeholder
    tokens followed by zero-padding up to ``seq_len``, plus the matching
    placeholder ranges keyed by ``mm_key``.
    """
    if image_feature_size_override is not None:
        image_feature_size = image_feature_size_override
    else:
        image_feature_size = get_clip_image_feature_size(hf_config)

    num_image_tokens = image_feature_size * num_images
    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )
    ranges = {
        mm_key:
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=image_feature_size)
    }
    return seq_data, ranges
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy black RGB image data at the model's configured size.

    Width/height overrides replace the configured square size per axis.
    Returns a single image when ``num_images == 1``, otherwise a list of
    the same image repeated.
    """
    width = (hf_config.image_size
             if image_width_override is None else image_width_override)
    height = (hf_config.image_size
              if image_height_override is None else image_height_override)

    image = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": image}
    return {"image": [image] * num_images}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def dummy_video_for_clip(
    hf_config: CLIPVisionConfig,
    num_frames: int,
    num_videos: int = 1,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy video data: each video is one dummy frame repeated
    ``num_frames`` times along a new leading axis."""
    frame_dict = dummy_image_for_clip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override)
    frame = np.array(frame_dict["image"])
    single_video = np.repeat([frame], num_frames, axis=0)
    return {"video": [single_video] * num_videos}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand each image-placeholder token in the prompt into the number of
    tokens the CLIP encoder will actually produce.

    Returns the inputs unchanged when there is no image data or when
    placeholders have already been inserted.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    if "multi_modal_placeholders" in inputs and "image" in inputs[
            "multi_modal_placeholders"]:
        # The inputs already have placeholders.
        return inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            # Raw PIL image: fixed-size CLIP feature count from the config.
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            # Pre-computed embeddings: the feature size is the middle
            # dimension of the (num_images, features, hidden) tensor.
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    # Repeat the placeholder token once per image feature and record the
    # resulting placeholder ranges.
    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
        tokenizer,
        inputs.get("prompt"),
        inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": ranges})
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Static geometry queries for a CLIP vision encoder.

    CLIP runs at a fixed resolution, so token counts do not depend on the
    dimensions of the input image.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # Fixed-resolution model: width/height do not change the count.
        return get_clip_image_feature_size(self.vision_config)

    def get_max_image_tokens(self) -> int:
        return get_max_clip_image_tokens(self.vision_config)

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        cfg = self.vision_config
        return get_clip_patch_grid_length(image_size=cfg.image_size,
                                          patch_size=cfg.patch_size)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
| 181 |
+
class CLIPVisionEmbeddings(nn.Module):
    """Patchify an image with a strided conv, prepend a learned [CLS]
    embedding, and add learned position embeddings."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned class-token embedding, expanded per batch in forward().
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patchifier: kernel == stride == patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # +1 for the class-token position.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Non-persistent: position_ids is derived, not part of checkpoints.
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return [batch, num_patches + 1, embed_dim] token embeddings."""
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # [B, D, g, g] -> [B, g*g, D]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        # NOTE(review): dropout is stored but not applied in this forward
        # path; MultiHeadAttention below takes no dropout argument here.
        self.dropout = config.attention_dropout

        # Fused Q/K/V projection, column-sharded across tensor-parallel ranks.
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Output projection, row-sharded to undo the column sharding above.
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    # NOTE(review): _shape is not called anywhere in this class — it looks
    # like a leftover from the HF implementation; confirm before removing.
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""

        qkv_states, _ = self.qkv_proj(hidden_states)
        # Equal thirds along the last dim: query, key, value.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        # Second element mimics the (output, attn_weights) HF return shape.
        return attn_output, None
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block used inside each CLIP encoder layer."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Column-parallel expansion then row-parallel projection keeps the
        # intermediate activations sharded across tensor-parallel ranks.
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class CLIPEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block (pre-norm + residual).
        attn_input = self.layer_norm1(hidden_states)
        attn_out, _ = self.self_attn(hidden_states=attn_input)
        hidden_states = hidden_states + attn_out

        # MLP sub-block (pre-norm + residual).
        mlp_input = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)

        return hidden_states
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # An override truncates (or sets) the layer count, e.g. to drop the
        # final layers when only intermediate features are needed.
        num_layers = (config.num_hidden_layers
                      if num_hidden_layers_override is None else
                      num_hidden_layers_override)
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{i}")
            for i in range(num_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states = inputs_embeds
        collected: list[torch.Tensor] = []

        for layer in self.layers:
            hidden_states = layer(hidden_states)
            if return_all_hidden_states:
                collected.append(hidden_states)

        # With multiple feature-sample layers, callers pick the hidden states
        # they need by index from the collected list.
        if return_all_hidden_states:
            return collected
        return hidden_states
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class CLIPVisionTransformer(nn.Module):
    """CLIP vision backbone: embeddings, pre-norm, encoder stack, and an
    optional final post-layernorm."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        # The override may shrink the stack but never grow it beyond the
        # checkpoint's layer count.
        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            # Default: apply the post-norm only when the full stack is kept.
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Encode images; with feature_sample_layers, selected intermediate
        hidden states are resolved instead of the final output."""

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states)

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        return encoder_outputs
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class CLIPVisionModel(nn.Module):
    """Thin wrapper around CLIPVisionTransformer providing the HF-style
    module layout ("vision_model.*") and checkpoint loading."""

    # HF-compatible class attributes.
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model")

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        return self.vision_model(pixel_values, feature_sample_layers)

    @property
    def device(self):
        # Device of the first parameter; assumes the whole model lives on one
        # device.
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing q/k/v into qkv_proj and skipping
        weights for layers/norms this instance does not have. Returns the
        set of parameter names that were loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # Index 3 of "vision_model.encoder.layers.<idx>..." is the
                # layer number.
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint has separate q/k/v; map onto the fused param
                # and load into the matching shard.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Plain (non-stacked) parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/commandr.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
# This file is based on the LLama model definition file in transformers
|
| 23 |
+
"""PyTorch Cohere model."""
|
| 24 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.utils.checkpoint
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers import CohereConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 33 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 34 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 35 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 36 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 37 |
+
QKVParallelLinear,
|
| 38 |
+
RowParallelLinear)
|
| 39 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 40 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 41 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 42 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 43 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 44 |
+
VocabParallelEmbedding)
|
| 45 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 46 |
+
default_weight_loader, maybe_remap_kv_scale_name,
|
| 47 |
+
row_parallel_weight_loader)
|
| 48 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 49 |
+
from vllm.model_executor.utils import set_weight_attrs
|
| 50 |
+
from vllm.platforms import current_platform
|
| 51 |
+
from vllm.sequence import IntermediateTensors
|
| 52 |
+
|
| 53 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 54 |
+
from .utils import (extract_layer_index, is_pp_missing_parameter,
|
| 55 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 56 |
+
maybe_prefix)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free LayerNorm computed in float32 and cast back to the input
    dtype."""
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    mean = x.mean(-1, keepdim=True)
    centered = x - mean
    variance = centered.pow(2).mean(-1, keepdim=True)
    normed = centered * torch.rsqrt(variance + variance_epsilon)
    scaled = weight.to(torch.float32) * normed
    return scaled.to(orig_dtype)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class LayerNorm(nn.Module):
    """Bias-free LayerNorm whose weight is loaded with the row-parallel
    loader.

    ``forward`` accepts and passes through an optional residual tensor
    unchanged, mirroring vLLM's fused-norm call signature.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        normed = layer_norm_func(hidden_states, self.weight,
                                 self.variance_epsilon)
        return normed, residuals
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """SwiGLU feed-forward block with fused gate/up projection."""

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up projections fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class CohereAttention(nn.Module):
|
| 120 |
+
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
config: CohereConfig,
|
| 124 |
+
cache_config: Optional[CacheConfig] = None,
|
| 125 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 126 |
+
prefix: str = "",
|
| 127 |
+
):
|
| 128 |
+
super().__init__()
|
| 129 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 130 |
+
self.config = config
|
| 131 |
+
self.attention_dropout = config.attention_dropout
|
| 132 |
+
self.hidden_size = config.hidden_size
|
| 133 |
+
self.total_num_heads = config.num_attention_heads
|
| 134 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 135 |
+
self.head_dim = self.hidden_size // self.total_num_heads
|
| 136 |
+
self.total_num_kv_heads = config.num_key_value_heads
|
| 137 |
+
if self.total_num_kv_heads >= tp_size:
|
| 138 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 139 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 140 |
+
assert self.total_num_kv_heads % tp_size == 0
|
| 141 |
+
else:
|
| 142 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 143 |
+
# the KV heads across multiple tensor parallel GPUs.
|
| 144 |
+
assert tp_size % self.total_num_kv_heads == 0
|
| 145 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 146 |
+
self.q_size = self.num_heads * self.head_dim
|
| 147 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 148 |
+
self.scaling = self.head_dim**-0.5
|
| 149 |
+
self.max_position_embeddings = getattr(
|
| 150 |
+
config, "model_max_length", None) or getattr(
|
| 151 |
+
config, "max_position_embeddings", 8192)
|
| 152 |
+
self.rope_theta = config.rope_theta
|
| 153 |
+
self.rope_scaling = getattr(config, "rope_scaling", None)
|
| 154 |
+
self.use_qk_norm = getattr(config, "use_qk_norm", False)
|
| 155 |
+
self.qkv_proj = QKVParallelLinear(
|
| 156 |
+
self.hidden_size,
|
| 157 |
+
self.head_dim,
|
| 158 |
+
self.total_num_heads,
|
| 159 |
+
self.total_num_kv_heads,
|
| 160 |
+
bias=False,
|
| 161 |
+
quant_config=quant_config,
|
| 162 |
+
)
|
| 163 |
+
self.o_proj = RowParallelLinear(
|
| 164 |
+
self.total_num_heads * self.head_dim,
|
| 165 |
+
self.hidden_size,
|
| 166 |
+
bias=False,
|
| 167 |
+
quant_config=quant_config,
|
| 168 |
+
)
|
| 169 |
+
self.rotary_emb = get_rope(
|
| 170 |
+
self.head_dim,
|
| 171 |
+
rotary_dim=self.head_dim,
|
| 172 |
+
max_position=self.max_position_embeddings,
|
| 173 |
+
base=self.rope_theta,
|
| 174 |
+
rope_scaling=self.rope_scaling,
|
| 175 |
+
is_neox_style=False,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Model v2 has interleaved sliding windows, v1 does not
|
| 179 |
+
interleaved_sliding_window = getattr(config,
|
| 180 |
+
"interleaved_sliding_window",
|
| 181 |
+
None)
|
| 182 |
+
self.v1 = interleaved_sliding_window is None
|
| 183 |
+
|
| 184 |
+
layer_idx = extract_layer_index(prefix)
|
| 185 |
+
layer_has_sliding_window = (
|
| 186 |
+
getattr(config, "sliding_window_pattern", False)
|
| 187 |
+
and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
|
| 188 |
+
|
| 189 |
+
self.sliding_window = (interleaved_sliding_window
|
| 190 |
+
if layer_has_sliding_window else None)
|
| 191 |
+
|
| 192 |
+
self.attn = Attention(self.num_heads,
|
| 193 |
+
self.head_dim,
|
| 194 |
+
self.scaling,
|
| 195 |
+
num_kv_heads=self.num_kv_heads,
|
| 196 |
+
cache_config=cache_config,
|
| 197 |
+
quant_config=quant_config,
|
| 198 |
+
per_layer_sliding_window=self.sliding_window,
|
| 199 |
+
prefix=f"{prefix}.attn")
|
| 200 |
+
if self.use_qk_norm:
|
| 201 |
+
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
| 202 |
+
self.head_dim),
|
| 203 |
+
eps=config.layer_norm_eps)
|
| 204 |
+
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
|
| 205 |
+
self.head_dim),
|
| 206 |
+
eps=config.layer_norm_eps)
|
| 207 |
+
|
| 208 |
+
def _apply_qk_norm(self, q, k):
|
| 209 |
+
q = q.view(*q.shape[:-1], -1, self.head_dim)
|
| 210 |
+
k = k.view(*k.shape[:-1], -1, self.head_dim)
|
| 211 |
+
q, _ = self.q_norm(q)
|
| 212 |
+
k, _ = self.k_norm(k)
|
| 213 |
+
q = q.view(*q.shape[:-2], -1)
|
| 214 |
+
k = k.view(*k.shape[:-2], -1)
|
| 215 |
+
return q, k
|
| 216 |
+
|
| 217 |
+
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    """Run self-attention for one Cohere layer.

    Projects hidden states to fused QKV, optionally normalizes Q/K per
    head, applies rotary embeddings where required, attends, and
    projects back to the model dimension.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    if self.use_qk_norm:
        q, k = self._apply_qk_norm(q, k)
    # v1 applies RoPE on every layer; for v2 (interleaved sliding
    # windows) RoPE is applied only on sliding-window layers.
    if self.v1 or self.sliding_window:
        q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
    output, _ = self.o_proj(attn_output)
    return output
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CohereDecoderLayer(nn.Module):
    """A single Command-R decoder layer.

    Cohere layers run attention and the MLP *in parallel* on the same
    normalized input, then add both results to the residual stream.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The incoming residual is overwritten: the pre-norm activations
        # become the residual for this layer.
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(hidden_states)
        # Parallel attention + MLP: both branches read the same normed input.
        hidden_states = residual + attn_out + mlp_out
        return hidden_states, residual
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@support_torch_compile
class CohereModel(nn.Module):
    """Command-R transformer stack: embeddings, decoder layers, final norm.

    Supports pipeline parallelism: non-first ranks resume from
    ``intermediate_tensors`` and non-last ranks return an
    ``IntermediateTensors`` instead of final hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Reserve extra vocab slots for LoRA adapters, if configured.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # make_layers handles pipeline-parallel partitioning of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CohereDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=(config.hidden_size),
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers owned by this pipeline rank.

        Returns hidden states on the last rank, otherwise the
        intermediate tensors for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Command-R causal LM on top of :class:`CohereModel`.

    Logits are computed against the (tied) input embedding matrix scaled
    by ``config.logit_scale``; there is no separate lm_head weight.
    """

    # Maps fused parameter names to the per-projection checkpoint names.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the underlying model."""
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer stack."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits against the tied embedding matrix.

        When LoRA wraps the embedding module, the underlying weight lives
        on ``base_layer`` rather than on the wrapper itself.
        """
        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
        if is_not_lora:
            logits = self.logits_processor(self.model.embed_tokens,
                                           hidden_states, sampling_metadata)
        else:
            logits = self.logits_processor(self.model.embed_tokens.base_layer,
                                           hidden_states, sampling_metadata)

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing stacked projections.

        Returns the set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:

            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Scales may arrive as 1-element tensors; unwrap to scalar.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/dbrx.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 9 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 10 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 11 |
+
get_tensor_model_parallel_world_size)
|
| 12 |
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
| 13 |
+
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
| 14 |
+
ReplicatedLinear,
|
| 15 |
+
RowParallelLinear)
|
| 16 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 17 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 18 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 19 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 20 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 21 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 22 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 23 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 24 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 25 |
+
from vllm.sequence import IntermediateTensors
|
| 26 |
+
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
| 27 |
+
|
| 28 |
+
from .interfaces import SupportsPP
|
| 29 |
+
from .utils import (is_pp_missing_parameter,
|
| 30 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 31 |
+
maybe_prefix)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DbrxRouter(nn.Module):
    """Expert router for DBRX: maps each token to per-expert logits."""

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.tp_size = get_tensor_model_parallel_world_size()
        # The router is tiny, so it is replicated (not sharded) and left
        # unquantized.
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits, _ = self.layer(hidden_states)
        return logits
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class DbrxExperts(FusedMoE):
    """Fused MoE experts for DBRX with a checkpoint-specific weight loader.

    DBRX checkpoints store each expert's GLU matrices as flattened ``w1``
    (gate), ``v1`` (up) and ``w2`` (down) tensors; the loader reshapes them
    and copies this rank's tensor-parallel shard into the fused ``w13`` /
    ``w2`` parameters.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(
            num_experts=config.ffn_config.moe_num_experts,
            top_k=config.ffn_config.moe_top_k,
            hidden_size=config.d_model,
            intermediate_size=config.ffn_config.ffn_hidden_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=get_tensor_model_parallel_world_size(),
        )
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.d_model = config.d_model
        # Per-rank slice of the FFN hidden size.
        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

    # Define custom weight loader for dbrx model
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, param_name: str):
        """Copy one checkpoint tensor into a fused expert parameter.

        ``weight_name`` identifies which GLU matrix the tensor is
        (``w1``/``v1``/``w2``); ``param_name`` distinguishes plain weights
        from quantization ``weight_scale`` entries.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses GLU for each experts.
        # GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith("w1"):
            if param_name.endswith("weight"):
                # Unflatten to (num_experts, ffn_hidden, d_model).
                loaded_weight = torch.reshape(
                    loaded_weight,
                    [-1, self.intermediate_size * self.tp_size, self.d_model],
                )
                # w1 occupies the first half of the fused w13 parameter.
                param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
            elif param_name.endswith("weight_scale"):
                param_data[:, 0] = loaded_weight
            else:
                # NOTE(review): this rebinds the local name only and does
                # not write into ``param`` — confirm whether this branch
                # is ever taken.
                param_data = loaded_weight
        if weight_name.endswith("v1"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(
                    loaded_weight,
                    [-1, self.intermediate_size * self.tp_size, self.d_model],
                )
                # v1 occupies the second half of the fused w13 parameter.
                param_data[:, shard_size:2 *
                           shard_size, :] = loaded_weight[:, shard, :]
            elif param_name.endswith("weight_scale"):
                param_data[:, 1] = loaded_weight
            else:
                param_data[:] = loaded_weight
        if weight_name.endswith("w2"):
            if param_name.endswith("weight"):
                # w2 is stored flattened; shard along the (transposed)
                # input dimension.
                loaded_weight = torch.reshape(
                    loaded_weight,
                    [-1, self.intermediate_size * self.tp_size, self.d_model],
                ).transpose(1, 2)
                param_data[:] = loaded_weight[:, :, shard]
            else:
                param_data[:] = loaded_weight
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class DbrxMoE(nn.Module):
    """Tensor-parallel mixture-of-experts layer for DBRX.

    Each expert's weights are sharded across all ranks; a fused MoE
    kernel computes the expert outputs, which are then reduced across
    ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.params_dtype = (params_dtype if params_dtype is not None else
                             torch.get_default_dtype())
        self.router = DbrxRouter(config, self.params_dtype)
        self.experts = DbrxExperts(config=config,
                                   quant_config=quant_config,
                                   params_dtype=self.params_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        flat = hidden_states.view(-1, self.d_model)
        # Routing logits have shape (num_tokens, n_experts).
        gate_logits = self.router(flat)
        mixed = self.experts(flat, gate_logits)
        return mixed.view(orig_shape)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class DbrxAttention(nn.Module):
    """Multi-head attention for DBRX: grouped KV heads, RoPE, optional
    QKV clipping, sharded for tensor parallelism."""

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        # Optional symmetric clamp applied to the fused QKV activations.
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Fused QKV projection, optional clipping, RoPE, attention, and
        output projection."""
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp, matching the HF DBRX reference.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class DbrxFusedNormAttention(nn.Module):
    """Attention sublayer fused with its surrounding LayerNorms.

    Returns the post-norm hidden states (the FFN input) together with the
    running residual, so the enclosing block can add the FFN output to it.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        residual = hidden_states + attn_out
        return self.norm_2(residual), residual
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class DbrxBlock(nn.Module):
    """One DBRX transformer block: fused norm/attention then MoE FFN."""

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.norm_attn_norm")
        self.ffn = DbrxMoE(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        normed, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        return self.ffn(normed) + residual
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class DbrxModel(nn.Module):
    """DBRX transformer stack: embeddings, blocks, and a bias-free final
    LayerNorm; supports pipeline parallelism."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        # make_layers handles pipeline-parallel partitioning of the stack.
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: DbrxBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the blocks owned by this pipeline rank.

        Returns hidden states on the last rank, otherwise the
        intermediate tensors for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm_f(hidden_states)
        return hidden_states
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class DbrxForCausalLM(nn.Module, SupportsPP):
|
| 384 |
+
|
| 385 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the DBRX causal LM: transformer stack, separate LM head,
    logits processor, and sampler."""
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.config = config
    # DBRX ships a separate lm_head weight; tied embeddings are
    # unsupported here.
    if config.tie_word_embeddings:
        raise ValueError(
            "tie_word_embeddings is not supported for Dbrx models.")
    self.quant_config = quant_config
    self.unpadded_vocab_size = config.vocab_size
    self.transformer = DbrxModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(
                                     prefix, "transformer"))
    self.lm_head = ParallelLMHead(
        config.vocab_size,
        config.d_model,
        org_num_embeddings=config.vocab_size,
        padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        quant_config=quant_config,
    )
    self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                            config.vocab_size)
    self.sampler = get_sampler()
    self.make_empty_intermediate_tensors = (
        self.transformer.make_empty_intermediate_tensors)
|
| 410 |
+
|
| 411 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings via the underlying transformer."""
    embed = self.transformer.get_input_embeddings
    return embed(input_ids)
|
| 413 |
+
|
| 414 |
+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate to the underlying DbrxModel transformer stack."""
    hidden_states = self.transformer(input_ids, positions, kv_caches,
                                     attn_metadata, intermediate_tensors,
                                     inputs_embeds)
    return hidden_states
|
| 427 |
+
|
| 428 |
+
def compute_logits(
|
| 429 |
+
self,
|
| 430 |
+
hidden_states: torch.Tensor,
|
| 431 |
+
sampling_metadata: SamplingMetadata,
|
| 432 |
+
) -> Optional[torch.Tensor]:
|
| 433 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 434 |
+
sampling_metadata)
|
| 435 |
+
return logits
|
| 436 |
+
|
| 437 |
+
def sample(
|
| 438 |
+
self,
|
| 439 |
+
logits: Optional[torch.Tensor],
|
| 440 |
+
sampling_metadata: SamplingMetadata,
|
| 441 |
+
) -> Optional[SamplerOutput]:
|
| 442 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 443 |
+
return next_tokens
|
| 444 |
+
|
| 445 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 446 |
+
torch.Tensor]]) -> Set[str]:
|
| 447 |
+
expert_params_mapping = [(
|
| 448 |
+
"w13" if weight_name in ["w1", "v1"] else "w2",
|
| 449 |
+
f"mlp.{weight_name}",
|
| 450 |
+
) for weight_name in ["w1", "v1", "w2"]]
|
| 451 |
+
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
| 452 |
+
loaded_params: Set[str] = set()
|
| 453 |
+
|
| 454 |
+
for name, loaded_weight in weights:
|
| 455 |
+
if (self.quant_config is not None and
|
| 456 |
+
(scale_name := self.quant_config.get_cache_scale(name))):
|
| 457 |
+
# Loading kv cache quantization scales
|
| 458 |
+
param = params_dict[scale_name]
|
| 459 |
+
weight_loader = getattr(param, "weight_loader",
|
| 460 |
+
default_weight_loader)
|
| 461 |
+
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
| 462 |
+
loaded_weight[0])
|
| 463 |
+
weight_loader(param, loaded_weight)
|
| 464 |
+
loaded_params.add(scale_name)
|
| 465 |
+
continue
|
| 466 |
+
|
| 467 |
+
if name.endswith(("w1", "w2", "v1")):
|
| 468 |
+
name = name + "_weight"
|
| 469 |
+
for param_name, weight_name in expert_params_mapping:
|
| 470 |
+
if weight_name not in name:
|
| 471 |
+
continue
|
| 472 |
+
name = name.replace(weight_name, param_name)
|
| 473 |
+
if is_pp_missing_parameter(name, self):
|
| 474 |
+
continue
|
| 475 |
+
param = params_dict[name]
|
| 476 |
+
weight_loader = param.weight_loader
|
| 477 |
+
weight_loader(param, loaded_weight, weight_name, name)
|
| 478 |
+
break
|
| 479 |
+
|
| 480 |
+
else:
|
| 481 |
+
# Remapping the name of FP8 kv-scale.
|
| 482 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 483 |
+
if name is None:
|
| 484 |
+
continue
|
| 485 |
+
|
| 486 |
+
if is_pp_missing_parameter(name, self):
|
| 487 |
+
continue
|
| 488 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 489 |
+
if name is None:
|
| 490 |
+
continue
|
| 491 |
+
param = params_dict[name]
|
| 492 |
+
weight_loader = getattr(param, "weight_loader",
|
| 493 |
+
default_weight_loader)
|
| 494 |
+
weight_loader(param, loaded_weight)
|
| 495 |
+
loaded_params.add(name)
|
| 496 |
+
return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/decilm.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 DeciAI Research Team. All rights reserved.
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
"""Inference-only DeciLM model compatible with HuggingFace weights."""
|
| 26 |
+
|
| 27 |
+
from typing import Iterable, Set, Tuple
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
from vllm.config import VllmConfig
|
| 32 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 33 |
+
from vllm.model_executor.models.llama import LlamaForCausalLM
|
| 34 |
+
|
| 35 |
+
from .utils import is_pp_missing_parameter
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DeciLMForCausalLM(LlamaForCausalLM):
    """
    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
    Based on the llama executor.

    DeciLM differs from llama in that it uses Variable Grouped Query
    Attention: each decoder layer may have a different number of KV heads
    (``config.num_key_value_heads_per_layer[i]`` instead of a single
    ``config.num_key_value_heads``).

    PagedAttention currently does not handle variable GQA well, so at load
    time every layer's KV weights are replicated up to the maximum head
    count and the model runs with uniform GQA.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        # Collapse the per-layer KV-head list into a single uniform count
        # (the maximum), then drop the per-layer attribute so the llama
        # executor sees a standard GQA config.
        hf_config.num_key_value_heads = max(
            hf_config.num_key_value_heads_per_layer)
        delattr(hf_config, "num_key_value_heads_per_layer")
        super().__init__(vllm_config=vllm_config)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, de-grouping KV projections and fusing
        q/k/v and gate/up shards into their merged parameters. Returns the
        set of parameter names that were loaded."""
        fused_mapping = [
            # (fused param fragment, checkpoint fragment, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "k_proj" in name or "v_proj" in name:
                # Replicate grouped KV heads up to the uniform head count.
                weight = self._degroup_weight(weight)

            for fused_name, shard_name, shard_id in fused_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                target = params[name]
                target.weight_loader(target, weight, shard_id)
                break

            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                target = params[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, weight)
            loaded.add(name)
        return loaded

    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Expand a KV projection weight from its per-layer head count to
        the uniform ``config.num_key_value_heads`` by repeating each head."""
        cfg = self.config
        head_size = cfg.hidden_size // cfg.num_attention_heads
        target_heads = cfg.num_key_value_heads
        source_heads = loaded_weight.shape[0] // head_size
        repeats, remainder = divmod(target_heads, source_heads)
        assert remainder == 0

        expanded = loaded_weight.view(source_heads, head_size,
                                      cfg.hidden_size)
        expanded = expanded.repeat_interleave(repeats, dim=0)
        return expanded.reshape(target_heads * head_size, cfg.hidden_size)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek_v2.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Inference-only DeepseekV2/DeepseekV3 model."""
|
| 25 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers import PretrainedConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 33 |
+
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
| 34 |
+
from vllm.distributed import (get_pp_group,
|
| 35 |
+
get_tensor_model_parallel_world_size,
|
| 36 |
+
tensor_model_parallel_all_reduce)
|
| 37 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 38 |
+
from vllm.model_executor.layers.fused_moe import FusedMoE
|
| 39 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 40 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 41 |
+
MergedColumnParallelLinear,
|
| 42 |
+
ReplicatedLinear,
|
| 43 |
+
RowParallelLinear)
|
| 44 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 45 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 46 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 47 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 48 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 49 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 50 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 51 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 52 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 53 |
+
from vllm.sequence import IntermediateTensors
|
| 54 |
+
|
| 55 |
+
from .interfaces import SupportsPP
|
| 56 |
+
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
| 57 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 58 |
+
maybe_prefix)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block (SiLU-gated) used by both the dense layers
    and the MoE experts of DeepseekV2."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        # reduce_results=False lets MoE callers defer the TP all-reduce.
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, then down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DeepseekV2MoE(nn.Module):
    """DeepseekV2 mixture-of-experts layer: a replicated router gate, fused
    routed experts, and (optionally) always-active shared experts whose
    output is added to the routed result."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        # FIX: routed_scaling_factor was assigned twice in the original;
        # keep a single assignment.
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        # Router gate is replicated (small) and intentionally unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            # DeepseekV3-style aux-loss-free routing bias.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts))
        else:
            self.gate.e_score_correction_bias = None

        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            # Shared experts run on every token; the all-reduce is deferred
            # (reduce_results=False) and done once for the combined output.
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; returns a tensor of the same
        (num_tokens, hidden_dim) shape."""
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # FIX: initialize shared_output so the later `is not None` check
        # does not raise NameError when n_shared_experts is None.
        shared_output = None
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits) * self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            # Single all-reduce for routed + shared contributions.
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Compute the YaRN magnitude-scale ("mscale") correction factor.

    Returns 1.0 when no context extension is applied (scale <= 1);
    otherwise the factor grows logarithmically with the scaling factor.
    """
    import math
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class DeepseekV2Attention(nn.Module):
    """Decomposed MLA attention for DeepseekV2, evaluated with a standard
    multi-head attention kernel.

    Queries and key/values are produced through low-rank down/up
    projections; rotary embedding is applied only to the
    ``qk_rope_head_dim`` tail of each head, while ``k_pe`` is shared
    across heads (broadcast from a single-head slice of the latent cache).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head q/k width = non-rotary part + rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        # q_lora_rank is None when the query path has no low-rank bottleneck.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query path: down-project, RMS-normalize, up-project.
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            # Full-rank query projection.
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Produces the compressed KV latent plus the shared rotary key part.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Up-projects the latent into per-head (k_nope, v) pairs.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        if rope_scaling:
            # DeepSeek checkpoints use a YaRN variant with its own kernel.
            rope_scaling["rope_type"] = 'deepseek_yarn'
            self.use_normal_rope = False
        else:
            self.use_normal_rope = True
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            # YaRN also rescales attention logits by mscale**2.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention output for `hidden_states`.

        Returns a tensor of shape (num_tokens, hidden_size).
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                         self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                   self.qk_head_dim)
        # Split each query head into its non-rotary and rotary parts.
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Insert a head dim of size 1 so k_pe below is (tokens, 1, rope_dim).
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        if self.use_normal_rope:
            # The non-YaRN rotary kernel expects 2-D (seq, dim) inputs,
            # so flatten heads before and restore shapes after.
            seq_len = positions.size(0)
            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
            q_pe = q_pe.reshape(seq_len, -1)
            k_pe = k_pe.reshape(seq_len, -1)

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        if self.use_normal_rope:
            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

        # Write the rotated part back into q; assemble k from the
        # per-head k_nope and the (broadcast) shared k_pe.
        q[..., self.qk_nope_head_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        # Strip the v padding back off before the output projection.
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py

    Unlike DeepseekV2Attention, this module keeps keys/values compressed:
    it hands the normalized latent, the rotary key part, and the relevant
    projection layers to an MLA-enabled Attention backend (use_mla=True),
    which performs rope/up-projection internally.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head q/k width = non-rotary part + rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        # q_lora_rank is None when the query path has no low-rank bottleneck.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query path: down-project, RMS-normalize, up-project.
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            # Full-rank query projection.
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        # Produces the compressed KV latent plus the shared rotary key part.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Up-projects the latent into per-head (k_nope, v); used lazily by
        # the MLA backend rather than in this module's forward.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # DeepSeek checkpoints use a YaRN variant with its own kernel.
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # YaRN also rescales attention logits by mscale**2.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # MLA backend: one latent "KV head" of width kv_lora_rank; the
        # projection/rope modules are handed over so the backend can run
        # the decompressed math itself.
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        # Layer index parsed from the prefix (".<idx>.self_attn" shape
        # assumed — second-to-last dotted component).
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute MLA attention; returns (num_tokens, hidden_size).

        Only the q down-projection (if any) and the kv latent split/norm
        happen here; everything else is inside self.mla_attn.
        """
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        # Split the latent projection into compressed kv and rotary key part.
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
                             attn_metadata)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class DeepseekV2DecoderLayer(nn.Module):
|
| 485 |
+
|
| 486 |
+
def __init__(
|
| 487 |
+
self,
|
| 488 |
+
config: PretrainedConfig,
|
| 489 |
+
prefix: str,
|
| 490 |
+
model_config: ModelConfig,
|
| 491 |
+
cache_config: Optional[CacheConfig] = None,
|
| 492 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 493 |
+
) -> None:
|
| 494 |
+
super().__init__()
|
| 495 |
+
self.hidden_size = config.hidden_size
|
| 496 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 497 |
+
rope_scaling = getattr(config, "rope_scaling", None)
|
| 498 |
+
max_position_embeddings = getattr(config, "max_position_embeddings",
|
| 499 |
+
8192)
|
| 500 |
+
# DecoderLayers are created with `make_layers` which passes the prefix
|
| 501 |
+
# with the layer's index.
|
| 502 |
+
layer_idx = int(prefix.split(sep='.')[-1])
|
| 503 |
+
if model_config.use_mla:
|
| 504 |
+
attn_cls = DeepseekV2MLAAttention
|
| 505 |
+
else:
|
| 506 |
+
attn_cls = DeepseekV2Attention
|
| 507 |
+
self.self_attn = attn_cls(
|
| 508 |
+
config=config,
|
| 509 |
+
hidden_size=self.hidden_size,
|
| 510 |
+
num_heads=config.num_attention_heads,
|
| 511 |
+
qk_nope_head_dim=config.qk_nope_head_dim,
|
| 512 |
+
qk_rope_head_dim=config.qk_rope_head_dim,
|
| 513 |
+
v_head_dim=config.v_head_dim,
|
| 514 |
+
q_lora_rank=config.q_lora_rank
|
| 515 |
+
if hasattr(config, "q_lora_rank") else None,
|
| 516 |
+
kv_lora_rank=config.kv_lora_rank,
|
| 517 |
+
rope_theta=rope_theta,
|
| 518 |
+
rope_scaling=rope_scaling,
|
| 519 |
+
max_position_embeddings=max_position_embeddings,
|
| 520 |
+
cache_config=cache_config,
|
| 521 |
+
quant_config=quant_config,
|
| 522 |
+
prefix=f"{prefix}.self_attn",
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if (config.n_routed_experts is not None
|
| 526 |
+
and layer_idx >= config.first_k_dense_replace
|
| 527 |
+
and layer_idx % config.moe_layer_freq == 0):
|
| 528 |
+
self.mlp = DeepseekV2MoE(
|
| 529 |
+
config=config,
|
| 530 |
+
quant_config=quant_config,
|
| 531 |
+
prefix=f"{prefix}.mlp",
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
self.mlp = DeepseekV2MLP(
|
| 535 |
+
hidden_size=config.hidden_size,
|
| 536 |
+
intermediate_size=config.intermediate_size,
|
| 537 |
+
hidden_act=config.hidden_act,
|
| 538 |
+
quant_config=quant_config,
|
| 539 |
+
prefix=f"{prefix}.mlp",
|
| 540 |
+
)
|
| 541 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 542 |
+
eps=config.rms_norm_eps)
|
| 543 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 544 |
+
eps=config.rms_norm_eps)
|
| 545 |
+
|
| 546 |
+
def forward(
|
| 547 |
+
self,
|
| 548 |
+
positions: torch.Tensor,
|
| 549 |
+
hidden_states: torch.Tensor,
|
| 550 |
+
kv_cache: torch.Tensor,
|
| 551 |
+
attn_metadata: AttentionMetadata,
|
| 552 |
+
residual: Optional[torch.Tensor],
|
| 553 |
+
) -> torch.Tensor:
|
| 554 |
+
# Self Attention
|
| 555 |
+
if residual is None:
|
| 556 |
+
residual = hidden_states
|
| 557 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 558 |
+
else:
|
| 559 |
+
hidden_states, residual = self.input_layernorm(
|
| 560 |
+
hidden_states, residual)
|
| 561 |
+
hidden_states = self.self_attn(
|
| 562 |
+
positions=positions,
|
| 563 |
+
hidden_states=hidden_states,
|
| 564 |
+
kv_cache=kv_cache,
|
| 565 |
+
attn_metadata=attn_metadata,
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# Fully Connected
|
| 569 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 570 |
+
hidden_states, residual)
|
| 571 |
+
hidden_states = self.mlp(hidden_states)
|
| 572 |
+
return hidden_states, residual
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
@support_torch_compile
|
| 576 |
+
class DeepseekV2Model(nn.Module):
|
| 577 |
+
|
| 578 |
+
fall_back_to_pt_during_load = False
|
| 579 |
+
|
| 580 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 581 |
+
super().__init__()
|
| 582 |
+
|
| 583 |
+
config = vllm_config.model_config.hf_config
|
| 584 |
+
model_config = vllm_config.model_config
|
| 585 |
+
cache_config = vllm_config.cache_config
|
| 586 |
+
quant_config = vllm_config.quant_config
|
| 587 |
+
|
| 588 |
+
self.padding_idx = config.pad_token_id
|
| 589 |
+
self.vocab_size = config.vocab_size
|
| 590 |
+
|
| 591 |
+
if get_pp_group().is_first_rank:
|
| 592 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 593 |
+
config.vocab_size,
|
| 594 |
+
config.hidden_size,
|
| 595 |
+
)
|
| 596 |
+
else:
|
| 597 |
+
self.embed_tokens = PPMissingLayer()
|
| 598 |
+
|
| 599 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 600 |
+
config.num_hidden_layers,
|
| 601 |
+
lambda prefix: DeepseekV2DecoderLayer(
|
| 602 |
+
config,
|
| 603 |
+
prefix,
|
| 604 |
+
model_config=model_config,
|
| 605 |
+
cache_config=cache_config,
|
| 606 |
+
quant_config=quant_config,
|
| 607 |
+
),
|
| 608 |
+
prefix=f"{prefix}.layers")
|
| 609 |
+
|
| 610 |
+
if get_pp_group().is_last_rank:
|
| 611 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 612 |
+
else:
|
| 613 |
+
self.norm = PPMissingLayer()
|
| 614 |
+
self.make_empty_intermediate_tensors = (
|
| 615 |
+
make_empty_intermediate_tensors_factory(
|
| 616 |
+
["hidden_states", "residual"], config.hidden_size))
|
| 617 |
+
|
| 618 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 619 |
+
return self.embed_tokens(input_ids)
|
| 620 |
+
|
| 621 |
+
def forward(
|
| 622 |
+
self,
|
| 623 |
+
input_ids: torch.Tensor,
|
| 624 |
+
positions: torch.Tensor,
|
| 625 |
+
kv_caches: List[torch.Tensor],
|
| 626 |
+
attn_metadata: AttentionMetadata,
|
| 627 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 628 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 629 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 630 |
+
if get_pp_group().is_first_rank:
|
| 631 |
+
if inputs_embeds is not None:
|
| 632 |
+
hidden_states = inputs_embeds
|
| 633 |
+
else:
|
| 634 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 635 |
+
residual = None
|
| 636 |
+
else:
|
| 637 |
+
assert intermediate_tensors is not None
|
| 638 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 639 |
+
residual = intermediate_tensors["residual"]
|
| 640 |
+
|
| 641 |
+
for i in range(self.start_layer, self.end_layer):
|
| 642 |
+
layer = self.layers[i]
|
| 643 |
+
hidden_states, residual = layer(positions, hidden_states,
|
| 644 |
+
kv_caches[i - self.start_layer],
|
| 645 |
+
attn_metadata, residual)
|
| 646 |
+
|
| 647 |
+
if not get_pp_group().is_last_rank:
|
| 648 |
+
return IntermediateTensors({
|
| 649 |
+
"hidden_states": hidden_states,
|
| 650 |
+
"residual": residual
|
| 651 |
+
})
|
| 652 |
+
|
| 653 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 654 |
+
return hidden_states
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
| 658 |
+
|
| 659 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 660 |
+
super().__init__()
|
| 661 |
+
config = vllm_config.model_config.hf_config
|
| 662 |
+
quant_config = vllm_config.quant_config
|
| 663 |
+
self.config = config
|
| 664 |
+
self.quant_config = quant_config
|
| 665 |
+
self.model = DeepseekV2Model(vllm_config=vllm_config,
|
| 666 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 667 |
+
self.lm_head = ParallelLMHead(config.vocab_size,
|
| 668 |
+
config.hidden_size,
|
| 669 |
+
quant_config=quant_config)
|
| 670 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 671 |
+
self.sampler = get_sampler()
|
| 672 |
+
self.make_empty_intermediate_tensors = (
|
| 673 |
+
self.model.make_empty_intermediate_tensors)
|
| 674 |
+
|
| 675 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 676 |
+
return self.model.get_input_embeddings(input_ids)
|
| 677 |
+
|
| 678 |
+
def forward(
|
| 679 |
+
self,
|
| 680 |
+
input_ids: torch.Tensor,
|
| 681 |
+
positions: torch.Tensor,
|
| 682 |
+
kv_caches: List[torch.Tensor],
|
| 683 |
+
attn_metadata: AttentionMetadata,
|
| 684 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 685 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 686 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 687 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 688 |
+
attn_metadata, intermediate_tensors,
|
| 689 |
+
inputs_embeds)
|
| 690 |
+
return hidden_states
|
| 691 |
+
|
| 692 |
+
def compute_logits(
|
| 693 |
+
self,
|
| 694 |
+
hidden_states: torch.Tensor,
|
| 695 |
+
sampling_metadata: SamplingMetadata,
|
| 696 |
+
) -> Optional[torch.Tensor]:
|
| 697 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 698 |
+
sampling_metadata)
|
| 699 |
+
return logits
|
| 700 |
+
|
| 701 |
+
def sample(
|
| 702 |
+
self,
|
| 703 |
+
logits: Optional[torch.Tensor],
|
| 704 |
+
sampling_metadata: SamplingMetadata,
|
| 705 |
+
) -> Optional[SamplerOutput]:
|
| 706 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 707 |
+
return next_tokens
|
| 708 |
+
|
| 709 |
+
def make_empty_intermediate_tensors(
|
| 710 |
+
self, batch_size: int, dtype: torch.dtype,
|
| 711 |
+
device: torch.device) -> IntermediateTensors:
|
| 712 |
+
return IntermediateTensors({
|
| 713 |
+
"hidden_states":
|
| 714 |
+
torch.zeros((batch_size, self.config.hidden_size),
|
| 715 |
+
dtype=dtype,
|
| 716 |
+
device=device),
|
| 717 |
+
"residual":
|
| 718 |
+
torch.zeros((batch_size, self.config.hidden_size),
|
| 719 |
+
dtype=dtype,
|
| 720 |
+
device=device),
|
| 721 |
+
})
|
| 722 |
+
|
| 723 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 724 |
+
torch.Tensor]]) -> Set[str]:
|
| 725 |
+
stacked_params_mapping = [
|
| 726 |
+
# (param_name, shard_name, shard_id)
|
| 727 |
+
("gate_up_proj", "gate_proj", 0),
|
| 728 |
+
("gate_up_proj", "up_proj", 1),
|
| 729 |
+
]
|
| 730 |
+
|
| 731 |
+
# Params for weights, fp8 weight scales, fp8 activation scales
|
| 732 |
+
# (param_name, weight_name, expert_id, shard_id)
|
| 733 |
+
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
| 734 |
+
ckpt_gate_proj_name="gate_proj",
|
| 735 |
+
ckpt_down_proj_name="down_proj",
|
| 736 |
+
ckpt_up_proj_name="up_proj",
|
| 737 |
+
num_experts=self.config.n_routed_experts)
|
| 738 |
+
|
| 739 |
+
params_dict = dict(self.named_parameters())
|
| 740 |
+
loaded_params: Set[str] = set()
|
| 741 |
+
for name, loaded_weight in weights:
|
| 742 |
+
if "rotary_emb.inv_freq" in name:
|
| 743 |
+
continue
|
| 744 |
+
|
| 745 |
+
# TODO(simon): support nextn predict layers
|
| 746 |
+
if hasattr(self.config, "num_nextn_predict_layers"
|
| 747 |
+
) and self.config.num_nextn_predict_layers > 0:
|
| 748 |
+
assert self.config.num_nextn_predict_layers == 1
|
| 749 |
+
layer_idx = self.config.num_hidden_layers
|
| 750 |
+
if name.startswith(f"model.layers.{layer_idx}"):
|
| 751 |
+
continue
|
| 752 |
+
|
| 753 |
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
| 754 |
+
# Skip non-stacked layers and experts (experts handled below).
|
| 755 |
+
if weight_name not in name:
|
| 756 |
+
continue
|
| 757 |
+
# We have mlp.experts[0].gate_proj in the checkpoint.
|
| 758 |
+
# Since we handle the experts below in expert_params_mapping,
|
| 759 |
+
# we need to skip here BEFORE we update the name, otherwise
|
| 760 |
+
# name will be updated to mlp.experts[0].gate_up_proj, which
|
| 761 |
+
# will then be updated below in expert_params_mapping
|
| 762 |
+
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
| 763 |
+
if (("mlp.experts." in name) and name not in params_dict):
|
| 764 |
+
continue
|
| 765 |
+
name = name.replace(weight_name, param_name)
|
| 766 |
+
# Skip loading extra bias for GPTQ models.
|
| 767 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 768 |
+
continue
|
| 769 |
+
|
| 770 |
+
if is_pp_missing_parameter(name, self):
|
| 771 |
+
continue
|
| 772 |
+
|
| 773 |
+
param = params_dict[name]
|
| 774 |
+
weight_loader = param.weight_loader
|
| 775 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 776 |
+
break
|
| 777 |
+
else:
|
| 778 |
+
for mapping in expert_params_mapping:
|
| 779 |
+
param_name, weight_name, expert_id, shard_id = mapping
|
| 780 |
+
if weight_name not in name:
|
| 781 |
+
continue
|
| 782 |
+
name = name.replace(weight_name, param_name)
|
| 783 |
+
|
| 784 |
+
if is_pp_missing_parameter(name, self):
|
| 785 |
+
continue
|
| 786 |
+
|
| 787 |
+
param = params_dict[name]
|
| 788 |
+
weight_loader = param.weight_loader
|
| 789 |
+
weight_loader(param,
|
| 790 |
+
loaded_weight,
|
| 791 |
+
name,
|
| 792 |
+
shard_id=shard_id,
|
| 793 |
+
expert_id=expert_id)
|
| 794 |
+
break
|
| 795 |
+
else:
|
| 796 |
+
# Skip loading extra bias for GPTQ models.
|
| 797 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 798 |
+
continue
|
| 799 |
+
|
| 800 |
+
# Remapping the name of FP8 kv-scale.
|
| 801 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 802 |
+
if name is None:
|
| 803 |
+
continue
|
| 804 |
+
|
| 805 |
+
if is_pp_missing_parameter(name, self):
|
| 806 |
+
continue
|
| 807 |
+
|
| 808 |
+
param = params_dict[name]
|
| 809 |
+
weight_loader = getattr(param, "weight_loader",
|
| 810 |
+
default_weight_loader)
|
| 811 |
+
weight_loader(param, loaded_weight)
|
| 812 |
+
loaded_params.add(name)
|
| 813 |
+
return loaded_params
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
| 817 |
+
pass
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/deepseek_vl2.py
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
|
| 4 |
+
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
|
| 5 |
+
import math
|
| 6 |
+
from functools import cached_property
|
| 7 |
+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
| 8 |
+
TypedDict, Union)
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
from transformers import BatchFeature
|
| 15 |
+
|
| 16 |
+
from vllm.attention import AttentionMetadata
|
| 17 |
+
from vllm.config import VllmConfig
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.model_executor import SamplingMetadata
|
| 20 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 21 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 22 |
+
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
| 23 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 24 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 25 |
+
NestedTensors)
|
| 26 |
+
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
| 27 |
+
ImageSize, MultiModalDataItems)
|
| 28 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 29 |
+
BaseProcessingInfo, PromptReplacement)
|
| 30 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 31 |
+
from vllm.multimodal.utils import cached_get_tokenizer
|
| 32 |
+
from vllm.sequence import IntermediateTensors
|
| 33 |
+
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
|
| 34 |
+
MlpProjectorConfig,
|
| 35 |
+
VisionEncoderConfig)
|
| 36 |
+
from vllm.transformers_utils.processors.deepseek_vl2 import (
|
| 37 |
+
DeepseekVLV2Processor)
|
| 38 |
+
from vllm.utils import is_list_of
|
| 39 |
+
|
| 40 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 41 |
+
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
| 42 |
+
init_vllm_registered_model, maybe_prefix,
|
| 43 |
+
merge_multimodal_embeddings)
|
| 44 |
+
|
| 45 |
+
logger = init_logger(__name__)
|
| 46 |
+
|
| 47 |
+
# The image token id may be various
|
| 48 |
+
_IMAGE_TOKEN = "<image>"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DeepseekVL2ImagePixelInputs(TypedDict):
|
| 52 |
+
type: Literal["pixel_values"]
|
| 53 |
+
data: Union[torch.Tensor, List[torch.Tensor]]
|
| 54 |
+
"""
|
| 55 |
+
Shape: `(batch_size * num_images, num_channels, height, width)`
|
| 56 |
+
"""
|
| 57 |
+
images_spatial_crop: torch.Tensor
|
| 58 |
+
"""
|
| 59 |
+
Shape: `(batch_size * num_images, 2)`
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class DeepseekVL2VImageEmbeddingInputs(TypedDict):
|
| 64 |
+
type: Literal["image_embeds"]
|
| 65 |
+
data: Union[torch.Tensor, List[torch.Tensor]]
|
| 66 |
+
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
| 67 |
+
|
| 68 |
+
`hidden_size` must match the hidden size of language model backbone.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs,
|
| 73 |
+
DeepseekVL2VImageEmbeddingInputs]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MlpProjector(nn.Module):
|
| 77 |
+
|
| 78 |
+
def __init__(self, cfg: MlpProjectorConfig):
|
| 79 |
+
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
self.cfg = cfg
|
| 83 |
+
assert not cfg.token_pooling, (
|
| 84 |
+
"Token pooling is not supported currently.")
|
| 85 |
+
|
| 86 |
+
if cfg.projector_type == "downsample_mlp_gelu":
|
| 87 |
+
mlp_depth = cfg.depth
|
| 88 |
+
mlp_ratio = cfg.mlp_ratio
|
| 89 |
+
modules = [
|
| 90 |
+
nn.Linear(
|
| 91 |
+
cfg.input_dim * cfg.downsample_ratio *
|
| 92 |
+
cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
|
| 93 |
+
]
|
| 94 |
+
for _ in range(1, mlp_depth - 1):
|
| 95 |
+
modules.append(nn.GELU())
|
| 96 |
+
modules.append(
|
| 97 |
+
nn.Linear(cfg.n_embed * mlp_ratio,
|
| 98 |
+
cfg.n_embed * mlp_ratio))
|
| 99 |
+
modules.append(nn.GELU())
|
| 100 |
+
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
|
| 101 |
+
modules = nn.Sequential(*modules)
|
| 102 |
+
|
| 103 |
+
else:
|
| 104 |
+
raise NotImplementedError(
|
| 105 |
+
f"Unsupported projector type: {cfg.projector_type}")
|
| 106 |
+
|
| 107 |
+
self.layers = modules
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
bs, hw, input_dim = x.shape
|
| 111 |
+
h = w = int((hw)**0.5)
|
| 112 |
+
"""compute padding"""
|
| 113 |
+
if h % self.cfg.downsample_ratio:
|
| 114 |
+
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
|
| 115 |
+
else:
|
| 116 |
+
pad = 0
|
| 117 |
+
x = x.reshape(bs, h, w, input_dim)
|
| 118 |
+
if pad > 0:
|
| 119 |
+
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
|
| 120 |
+
"""4 to 1 concat"""
|
| 121 |
+
x = x.permute(0, 3, 1, 2) # B, C, H, W
|
| 122 |
+
x = F.unfold(x,
|
| 123 |
+
kernel_size=self.cfg.downsample_ratio,
|
| 124 |
+
stride=self.cfg.downsample_ratio,
|
| 125 |
+
padding=0) # B, C*4, HW // 4
|
| 126 |
+
x = x.permute(0, 2, 1)
|
| 127 |
+
|
| 128 |
+
return self.layers(x)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
|
| 132 |
+
|
| 133 |
+
def get_hf_config(self):
|
| 134 |
+
return self.ctx.get_hf_config(DeepseekVLV2Config)
|
| 135 |
+
|
| 136 |
+
def get_hf_processor(self) -> DeepseekVLV2Processor:
|
| 137 |
+
return self.ctx.get_hf_processor(DeepseekVLV2Processor)
|
| 138 |
+
|
| 139 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 140 |
+
return {"image": None}
|
| 141 |
+
|
| 142 |
+
def get_num_image_tokens(self, *, image_width: int,
|
| 143 |
+
image_height: int) -> int:
|
| 144 |
+
hf_processor = self.get_hf_processor()
|
| 145 |
+
image_size = hf_processor.image_size
|
| 146 |
+
patch_size = hf_processor.patch_size
|
| 147 |
+
downsample_ratio = hf_processor.downsample_ratio
|
| 148 |
+
|
| 149 |
+
best_width, best_height = hf_processor.select_best_resolution(
|
| 150 |
+
(image_width, image_height))
|
| 151 |
+
|
| 152 |
+
num_width_tiles, num_height_tiles = (best_width // image_size,
|
| 153 |
+
best_height // image_size)
|
| 154 |
+
h = w = math.ceil((image_size // patch_size) / downsample_ratio)
|
| 155 |
+
|
| 156 |
+
global_views_tokens = h * (w + 1)
|
| 157 |
+
local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
|
| 158 |
+
return global_views_tokens + local_views_tokens + 1
|
| 159 |
+
|
| 160 |
+
def get_image_size_with_most_features(self) -> ImageSize:
|
| 161 |
+
hf_config = self.get_hf_config()
|
| 162 |
+
candidate_resolutions = hf_config.candidate_resolutions
|
| 163 |
+
height, width = max(candidate_resolutions,
|
| 164 |
+
key=lambda x: self.get_num_image_tokens(
|
| 165 |
+
image_width=x[1], image_height=x[0]))
|
| 166 |
+
return ImageSize(width=width, height=height)
|
| 167 |
+
|
| 168 |
+
def get_mm_max_tokens_per_item(
|
| 169 |
+
self,
|
| 170 |
+
seq_len: int,
|
| 171 |
+
mm_counts: Mapping[str, int],
|
| 172 |
+
) -> Mapping[str, int]:
|
| 173 |
+
max_image_size = self.get_image_size_with_most_features()
|
| 174 |
+
max_image_tokens = self.get_num_image_tokens(
|
| 175 |
+
image_height=max_image_size.height,
|
| 176 |
+
image_width=max_image_size.width)
|
| 177 |
+
|
| 178 |
+
return {"image": max_image_tokens}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class DeepseekVL2DummyInputsBuilder(
|
| 182 |
+
BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
|
| 183 |
+
|
| 184 |
+
def get_dummy_processor_inputs(
|
| 185 |
+
self,
|
| 186 |
+
seq_len: int,
|
| 187 |
+
mm_counts: Mapping[str, int],
|
| 188 |
+
) -> ProcessorInputs:
|
| 189 |
+
num_images = mm_counts.get("image", 0)
|
| 190 |
+
hf_processor = self.info.get_hf_processor()
|
| 191 |
+
image_token: str = hf_processor.image_token
|
| 192 |
+
|
| 193 |
+
max_image_size = self.info.get_image_size_with_most_features()
|
| 194 |
+
|
| 195 |
+
mm_data = {
|
| 196 |
+
"image":
|
| 197 |
+
self._get_dummy_images(width=max_image_size.width,
|
| 198 |
+
height=max_image_size.height,
|
| 199 |
+
num_images=num_images)
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
return ProcessorInputs(
|
| 203 |
+
prompt_text=image_token * num_images,
|
| 204 |
+
mm_data=mm_data,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class DeepseekVL2MultiModalProcessor(
|
| 209 |
+
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
|
| 210 |
+
|
| 211 |
+
def _call_hf_processor(
|
| 212 |
+
self,
|
| 213 |
+
prompt: str,
|
| 214 |
+
mm_data: Mapping[str, object],
|
| 215 |
+
mm_kwargs: Mapping[str, object],
|
| 216 |
+
) -> BatchFeature:
|
| 217 |
+
if mm_data:
|
| 218 |
+
processed_outputs = self.info.ctx.call_hf_processor(
|
| 219 |
+
self.info.get_hf_processor(**mm_kwargs),
|
| 220 |
+
dict(prompt=prompt, **mm_data),
|
| 221 |
+
mm_kwargs,
|
| 222 |
+
)
|
| 223 |
+
target_dtype = self.info.ctx.model_config.dtype
|
| 224 |
+
pixel_values = processed_outputs.pop("pixel_values").to(
|
| 225 |
+
target_dtype)
|
| 226 |
+
# split pixel values into patches corresponding to each image
|
| 227 |
+
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
| 228 |
+
patches_per_image = [
|
| 229 |
+
x.prod().item() + 1 for x in images_spatial_crop
|
| 230 |
+
]
|
| 231 |
+
pixel_values = pixel_values.split(patches_per_image)
|
| 232 |
+
processed_outputs["pixel_values"] = pixel_values
|
| 233 |
+
else:
|
| 234 |
+
tokenizer = self.info.get_tokenizer()
|
| 235 |
+
processed_outputs = tokenizer(prompt,
|
| 236 |
+
add_special_tokens=True,
|
| 237 |
+
return_tensors="pt")
|
| 238 |
+
|
| 239 |
+
return processed_outputs
|
| 240 |
+
|
| 241 |
+
def _get_mm_fields_config(
|
| 242 |
+
self,
|
| 243 |
+
hf_inputs: BatchFeature,
|
| 244 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 245 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 246 |
+
return dict(
|
| 247 |
+
pixel_values=MultiModalFieldConfig.batched("image"),
|
| 248 |
+
images_spatial_crop=MultiModalFieldConfig.batched("image"),
|
| 249 |
+
image_embeds=MultiModalFieldConfig.batched("image"),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
def _get_prompt_replacements(
|
| 253 |
+
self,
|
| 254 |
+
mm_items: MultiModalDataItems,
|
| 255 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 256 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 257 |
+
) -> list[PromptReplacement]:
|
| 258 |
+
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
| 259 |
+
|
| 260 |
+
image_token_id = hf_processor.image_token_id
|
| 261 |
+
assert isinstance(image_token_id, int)
|
| 262 |
+
|
| 263 |
+
def get_replacement_deepseek_vl2(item_idx: int):
|
| 264 |
+
images = mm_items.get_items(
|
| 265 |
+
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
| 266 |
+
|
| 267 |
+
if isinstance(images, ImageEmbeddingItems):
|
| 268 |
+
num_image_tokens = images.get_feature_size(item_idx)
|
| 269 |
+
else:
|
| 270 |
+
image_size = images.get_image_size(item_idx)
|
| 271 |
+
|
| 272 |
+
num_image_tokens = self.info.get_num_image_tokens(
|
| 273 |
+
image_width=image_size.width,
|
| 274 |
+
image_height=image_size.height,
|
| 275 |
+
)
|
| 276 |
+
return [image_token_id] * num_image_tokens
|
| 277 |
+
|
| 278 |
+
return [
|
| 279 |
+
PromptReplacement(
|
| 280 |
+
modality="image",
|
| 281 |
+
target=[image_token_id],
|
| 282 |
+
replacement=get_replacement_deepseek_vl2,
|
| 283 |
+
)
|
| 284 |
+
]
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@MULTIMODAL_REGISTRY.register_processor(
|
| 288 |
+
DeepseekVL2MultiModalProcessor,
|
| 289 |
+
info=DeepseekVL2ProcessingInfo,
|
| 290 |
+
dummy_inputs=DeepseekVL2DummyInputsBuilder)
|
| 291 |
+
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
| 292 |
+
|
| 293 |
+
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
|
| 294 |
+
"language.": "language_model.",
|
| 295 |
+
})
|
| 296 |
+
|
| 297 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 298 |
+
super().__init__()
|
| 299 |
+
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
|
| 300 |
+
quant_config = vllm_config.quant_config
|
| 301 |
+
multimodal_config = vllm_config.model_config.multimodal_config
|
| 302 |
+
|
| 303 |
+
self.config = config
|
| 304 |
+
self.multimodal_config = multimodal_config
|
| 305 |
+
|
| 306 |
+
self.vision_config = config.vision_config
|
| 307 |
+
self.projector_config = config.projector_config
|
| 308 |
+
self.text_config = config.text_config
|
| 309 |
+
|
| 310 |
+
model_config = vllm_config.model_config
|
| 311 |
+
tokenizer = cached_get_tokenizer(
|
| 312 |
+
model_config.tokenizer,
|
| 313 |
+
tokenizer_mode=model_config.tokenizer_mode,
|
| 314 |
+
tokenizer_revision=model_config.tokenizer_revision,
|
| 315 |
+
trust_remote_code=model_config.trust_remote_code,
|
| 316 |
+
)
|
| 317 |
+
self.image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN)
|
| 318 |
+
|
| 319 |
+
self.vision = self._init_vision_module(self.vision_config,
|
| 320 |
+
quant_config,
|
| 321 |
+
maybe_prefix(prefix, "vision"))
|
| 322 |
+
|
| 323 |
+
self.projector = MlpProjector(self.projector_config)
|
| 324 |
+
self.tile_tag = config.tile_tag
|
| 325 |
+
self.global_view_pos = config.global_view_pos
|
| 326 |
+
|
| 327 |
+
# special token for image token sequence format
|
| 328 |
+
embed_std = 1 / torch.sqrt(
|
| 329 |
+
torch.tensor(self.projector_config.n_embed, dtype=torch.float32))
|
| 330 |
+
if self.tile_tag == "2D":
|
| 331 |
+
# <|view_separator|>, <|\n|>
|
| 332 |
+
self.image_newline = nn.Parameter(
|
| 333 |
+
torch.randn(self.projector_config.n_embed) * embed_std)
|
| 334 |
+
# This is a typo in original implementation
|
| 335 |
+
self.view_seperator = nn.Parameter(
|
| 336 |
+
torch.randn(self.projector_config.n_embed) * embed_std)
|
| 337 |
+
else:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if self.text_config.topk_method == "noaux_tc":
|
| 343 |
+
architectures = ["DeepseekV3ForCausalLM"]
|
| 344 |
+
elif not self.text_config.use_mla:
|
| 345 |
+
architectures = ["DeepseekForCausalLM"]
|
| 346 |
+
else:
|
| 347 |
+
architectures = ["DeepseekV2ForCausalLM"]
|
| 348 |
+
|
| 349 |
+
self.language_model = init_vllm_registered_model(
|
| 350 |
+
vllm_config=vllm_config,
|
| 351 |
+
hf_config=self.text_config,
|
| 352 |
+
prefix=maybe_prefix(prefix, "language"),
|
| 353 |
+
architectures=architectures,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
self.make_empty_intermediate_tensors = (
|
| 357 |
+
self.language_model.make_empty_intermediate_tensors)
|
| 358 |
+
|
| 359 |
+
def _init_vision_module(
|
| 360 |
+
self,
|
| 361 |
+
vision_config: VisionEncoderConfig,
|
| 362 |
+
quant_config: Optional[QuantizationConfig],
|
| 363 |
+
prefix: str = "",
|
| 364 |
+
) -> nn.Module:
|
| 365 |
+
# TODO: refactor vision model through timm wrapper from transformers
|
| 366 |
+
try:
|
| 367 |
+
import timm
|
| 368 |
+
except ImportError:
|
| 369 |
+
raise ImportError("Please install timm") from ImportError
|
| 370 |
+
|
| 371 |
+
with set_default_torch_dtype(torch.float16):
|
| 372 |
+
model = timm.create_model(
|
| 373 |
+
"vit_so400m_patch14_siglip_384.webli",
|
| 374 |
+
pretrained=False,
|
| 375 |
+
num_classes=0,
|
| 376 |
+
dynamic_img_size=True,
|
| 377 |
+
dynamic_img_pad=True,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
model = model.to(dtype=torch.get_default_dtype())
|
| 381 |
+
return model
|
| 382 |
+
|
| 383 |
+
@cached_property
|
| 384 |
+
def sampler(self):
|
| 385 |
+
if hasattr(self.language_model, "sampler"):
|
| 386 |
+
return self.language_model.sampler
|
| 387 |
+
|
| 388 |
+
return get_sampler()
|
| 389 |
+
|
| 390 |
+
def _validate_pixel_values(
|
| 391 |
+
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
| 392 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
| 393 |
+
|
| 394 |
+
h = w = self.vision_config.image_size
|
| 395 |
+
expected_dims = (3, h, w)
|
| 396 |
+
|
| 397 |
+
def _validate_shape(d: torch.Tensor):
|
| 398 |
+
actual_dims = tuple(d.shape[1:])
|
| 399 |
+
|
| 400 |
+
if actual_dims != expected_dims:
|
| 401 |
+
expected_expr = ("num_patches", *map(str, expected_dims))
|
| 402 |
+
raise ValueError(
|
| 403 |
+
"The expected shape of pixel values per image per batch "
|
| 404 |
+
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
| 405 |
+
|
| 406 |
+
for d in data:
|
| 407 |
+
_validate_shape(d)
|
| 408 |
+
|
| 409 |
+
return data
|
| 410 |
+
|
| 411 |
+
def _validate_images_spatial_crop(
|
| 412 |
+
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
| 413 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
| 414 |
+
expected_dims = 2
|
| 415 |
+
|
| 416 |
+
def _validate_shape(d: torch.Tensor):
|
| 417 |
+
actual_dims = d.size(-1)
|
| 418 |
+
|
| 419 |
+
if actual_dims != expected_dims:
|
| 420 |
+
expected_expr = str(expected_dims)
|
| 421 |
+
raise ValueError(
|
| 422 |
+
f"The expected shape of image sizes per image per batch "
|
| 423 |
+
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
| 424 |
+
|
| 425 |
+
for d in data:
|
| 426 |
+
_validate_shape(d)
|
| 427 |
+
|
| 428 |
+
return data
|
| 429 |
+
|
| 430 |
+
def _parse_and_validate_image_input(
|
| 431 |
+
self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
|
| 432 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 433 |
+
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
|
| 434 |
+
image_embeds = kwargs.pop("image_embeds", None)
|
| 435 |
+
|
| 436 |
+
if pixel_values is None and image_embeds is None:
|
| 437 |
+
return None
|
| 438 |
+
|
| 439 |
+
if pixel_values is not None:
|
| 440 |
+
if not isinstance(pixel_values, (torch.Tensor, list)):
|
| 441 |
+
raise ValueError("Incorrect type of pixel values. "
|
| 442 |
+
f"Got type: {type(pixel_values)}")
|
| 443 |
+
|
| 444 |
+
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
|
| 445 |
+
raise ValueError("Incorrect type of image sizes. "
|
| 446 |
+
f"Got type: {type(images_spatial_crop)}")
|
| 447 |
+
|
| 448 |
+
return DeepseekVL2ImagePixelInputs(
|
| 449 |
+
type="pixel_values",
|
| 450 |
+
data=self._validate_pixel_values(flatten_bn(pixel_values)),
|
| 451 |
+
images_spatial_crop=self._validate_images_spatial_crop(
|
| 452 |
+
flatten_bn(images_spatial_crop, concat=True)))
|
| 453 |
+
|
| 454 |
+
if image_embeds is not None:
|
| 455 |
+
if not isinstance(image_embeds, torch.Tensor):
|
| 456 |
+
raise ValueError("Incorrect type of image embeddings. "
|
| 457 |
+
f"Got type: {type(image_embeds)}")
|
| 458 |
+
|
| 459 |
+
return DeepseekVL2VImageEmbeddingInputs(
|
| 460 |
+
type="image_embeds",
|
| 461 |
+
data=flatten_bn(image_embeds),
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
raise AssertionError("This line should be unreachable.")
|
| 465 |
+
|
| 466 |
+
    def _pixel_values_to_embedding(
        self,
        pixel_values: NestedTensors,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        """Encode image tiles and format them into per-image token sequences.

        Runs all tiles through the vision tower and projector, then for each
        image interleaves learned newline embeddings into the global view and
        the rearranged local tiles, and joins the two views with the learned
        ``view_seperator`` token. Returns one [seq_len, D] tensor per image.
        """
        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
        total_tiles = [x for x in pixel_values]

        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(total_tiles, dim=0)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(total_tiles)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)

        _, hw, n_dim = images_embeds.shape
        # Assumes the projector output is a square grid of patches —
        # TODO(review): confirm hw is always a perfect square.
        h = w = int(hw**0.5)

        # Fill the image token sequence according to self.tile_tag &
        # self.global_view_pos. (Comment translated from the original
        # Chinese.)
        tile_index = 0
        vision_embeddings = []
        for jdx in range(images_spatial_crop.size(0)):
            # extra global & local features
            num_width_tiles, num_height_tiles = images_spatial_crop[jdx]
            if num_width_tiles == 0 or num_height_tiles == 0:
                # A (0, 0) entry ends the valid images — presumably padding
                # rows from the preprocessor; verify against the caller.
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[tile_index + 1:tile_index + 1 +
                                           num_tiles_in_image]
            # +1 accounts for the global-view tile preceding the local tiles.
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global],
                                        dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(local_features,
                                       "(th tw) (h w) d -> (th h) (tw w) d",
                                       th=num_height_tiles,
                                       tw=num_width_tiles,
                                       h=h,
                                       w=w)

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(self.image_newline,
                                        "d -> (th h) 1 d",
                                        th=num_height_tiles,
                                        h=h)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local],
                                       dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat([
                    global_features,
                    self.view_seperator[None, :],
                    local_features,
                ])
            else:
                global_local_features = torch.cat([
                    local_features,
                    self.view_seperator[None, :],
                    global_features,
                ])

            vision_embeddings.append(global_local_features)
        return vision_embeddings
|
| 559 |
+
|
| 560 |
+
def _process_image_input(
|
| 561 |
+
self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
|
| 562 |
+
if image_input["type"] == "image_embeds":
|
| 563 |
+
image_data = image_input["data"]
|
| 564 |
+
if is_list_of(image_data, torch.Tensor):
|
| 565 |
+
# it's already a list of tensors
|
| 566 |
+
return image_data
|
| 567 |
+
if len(image_data.shape) == 3:
|
| 568 |
+
# 3D tensor
|
| 569 |
+
return list(torch.unbind(image_data, dim=0))
|
| 570 |
+
raise ValueError(
|
| 571 |
+
"We expect batched 2D tensors;"
|
| 572 |
+
"this can be either a list of 2D tensors or a single 3D tensor."
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
pixel_values = image_input["data"]
|
| 576 |
+
images_spatial_crop = image_input["images_spatial_crop"]
|
| 577 |
+
|
| 578 |
+
return self._pixel_values_to_embedding(
|
| 579 |
+
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
|
| 580 |
+
|
| 581 |
+
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
|
| 582 |
+
image_input = self._parse_and_validate_image_input(**kwargs)
|
| 583 |
+
if image_input is None:
|
| 584 |
+
return None
|
| 585 |
+
vision_embeddings = self._process_image_input(image_input)
|
| 586 |
+
return vision_embeddings
|
| 587 |
+
|
| 588 |
+
def get_input_embeddings(
|
| 589 |
+
self,
|
| 590 |
+
input_ids: torch.Tensor,
|
| 591 |
+
multimodal_embeddings: Optional[NestedTensors] = None,
|
| 592 |
+
) -> torch.Tensor:
|
| 593 |
+
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
| 594 |
+
if multimodal_embeddings is not None:
|
| 595 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 596 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 597 |
+
self.image_token_id)
|
| 598 |
+
return inputs_embeds
|
| 599 |
+
|
| 600 |
+
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the multimodal model: optionally embed image inputs from
        ``kwargs`` (v0 path), then run the wrapped language model.
        """

        if intermediate_tensors is not None:
            # Non-first pipeline rank: hidden states arrive via
            # intermediate_tensors, so input embeddings are not used here.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # Once embeddings are computed, token ids must not be re-embedded
            # downstream.
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
|
| 628 |
+
|
| 629 |
+
def compute_logits(
|
| 630 |
+
self,
|
| 631 |
+
hidden_states: torch.Tensor,
|
| 632 |
+
sampling_metadata: SamplingMetadata,
|
| 633 |
+
) -> Optional[torch.Tensor]:
|
| 634 |
+
return self.language_model.compute_logits(hidden_states,
|
| 635 |
+
sampling_metadata)
|
| 636 |
+
|
| 637 |
+
def sample(
|
| 638 |
+
self,
|
| 639 |
+
logits: torch.Tensor,
|
| 640 |
+
sampling_metadata: SamplingMetadata,
|
| 641 |
+
) -> Optional[SamplerOutput]:
|
| 642 |
+
return self.language_model.sample(logits, sampling_metadata)
|
| 643 |
+
|
| 644 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 645 |
+
torch.Tensor]]) -> Set[str]:
|
| 646 |
+
|
| 647 |
+
loader = AutoWeightsLoader(self)
|
| 648 |
+
autoloaded_weights = loader.load_weights(weights,
|
| 649 |
+
mapper=self.hf_to_vllm_mapper)
|
| 650 |
+
return autoloaded_weights
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/exaone.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
|
| 5 |
+
# Copyright 2024 The LG U+ CTO AI Tech Lab.
|
| 6 |
+
# Copyright 2021 The LG AI Research EXAONE Lab
|
| 7 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
"""Inference-only Exaone model compatible with HuggingFace weights."""
|
| 26 |
+
|
| 27 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch import nn
|
| 31 |
+
|
| 32 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 33 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 34 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 35 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 36 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 37 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 38 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 39 |
+
QKVParallelLinear,
|
| 40 |
+
RowParallelLinear)
|
| 41 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 42 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 43 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 44 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 45 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 46 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 47 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 48 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 49 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 50 |
+
from vllm.sequence import IntermediateTensors
|
| 51 |
+
from vllm.transformers_utils.configs.exaone import ExaoneConfig
|
| 52 |
+
|
| 53 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 54 |
+
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
| 55 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 56 |
+
maybe_prefix)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ExaoneGatedMLP(nn.Module):
    """Gated feed-forward block for Exaone.

    A fused gate/up column-parallel projection, SiLU-and-multiply
    activation, and a row-parallel down projection (named ``c_proj`` to
    match the HF checkpoint layout).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into a single matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation, and down projection."""
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        down, _ = self.c_proj(activated)
        return down
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ExaoneAttention(nn.Module):
    """Multi-head attention for Exaone with tensor-parallel sharding,
    rotary position embeddings, and grouped-query KV heads."""

    def __init__(
        self,
        config: ExaoneConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Per-rank projection widths for splitting the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # GGUF checkpoints store rotary embeddings in GPT-J (non-NeoX)
        # interleaved layout.
        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embeddings to Q/K, run attention
        against the KV cache, and project the result back to hidden size."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ExaoneBlockAttention(nn.Module):
    """Wrapper that nests ``ExaoneAttention`` under an ``attention``
    submodule, mirroring the HF Exaone checkpoint's parameter layout so
    weight names line up during loading."""

    def __init__(
        self,
        config: ExaoneConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # All arguments are forwarded unchanged to the inner attention.
        self.attention = ExaoneAttention(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            cache_config=cache_config,
            prefix=f"{prefix}.attention",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate directly to the wrapped attention module."""
        return self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ExaoneDecoderLayer(nn.Module):
    """One Exaone transformer layer: pre-norm self-attention followed by a
    pre-norm gated MLP, with residual additions fused into the RMSNorm
    calls (two-argument form returns (normed, new_residual))."""

    def __init__(
        self,
        config: ExaoneConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # NOTE: mutates the config's rope_scaling dict in place so get_rope
        # sees the original context length for scaled-rope variants.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.attn = ExaoneBlockAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = ExaoneGatedMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.activation_function,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        # ln_1 precedes attention, ln_2 precedes the MLP (HF naming).
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sublayers; returns (hidden, residual)
        for the fused-residual pattern used across the layer stack."""
        # Self Attention
        if residual is None:
            # First layer of this pipeline stage: seed the residual stream.
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@support_torch_compile
class ExaoneModel(nn.Module):
    """Exaone decoder-only backbone: token embedding (``wte``), a
    pipeline-sharded stack of ``ExaoneDecoderLayer``s (``h``), and a final
    RMSNorm (``ln_f``)."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra vocab slots reserved for LoRA adapters, when LoRA is enabled.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        # The embedding lives only on the first PP rank (and on the last
        # rank when it is tied to the LM head). The original code also had
        # a dead `self.wte = config.vocab_size` placeholder here, which was
        # immediately overwritten on both branches; it has been removed.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.wte = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.wte = PPMissingLayer()
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: ExaoneDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.h",
        )
        if get_pp_group().is_last_rank:
            self.ln_f = RMSNorm(config.hidden_size,
                                eps=config.layer_norm_epsilon)
        else:
            self.ln_f = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the decoder stack.

        On the first PP rank, embeds ``input_ids`` (or uses the supplied
        ``inputs_embeds``); later ranks resume from ``intermediate_tensors``.
        Returns the final hidden states on the last rank, otherwise an
        ``IntermediateTensors`` to forward to the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed per local (stage-relative) layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Exaone decoder-only causal LM with LoRA and pipeline-parallel support.

    Wraps :class:`ExaoneModel` (stored as ``self.transformer``) and adds the
    LM head, logits processing and sampling.
    """

    # Maps fused parameter names to the checkpoint shards they pack.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "c_fc_0",
            "c_fc_1",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "out_proj",
        "gate_up_proj",
        "c_proj",
        "wte",
        "lm_head",
    ]
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.transformer = ExaoneModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        # The LM head only lives on the last pipeline-parallel rank.
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                # Share weights with the input embedding table.
                self.lm_head.weight = self.transformer.wte.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings from the wrapped transformer.

        Fixed: the model is stored as ``self.transformer`` (see __init__);
        the previous ``self.model`` attribute never existed and raised
        AttributeError.
        """
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying transformer stack."""
        model_output = self.transformer(input_ids, positions, kv_caches,
                                        attn_metadata, intermediate_tensors,
                                        inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to (scaled) vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing packed shards and skipping
        rotary-embedding caches; returns the set of loaded parameter names.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".c_fc_0", 0),
            (".gate_up_proj", ".c_fc_1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Rotary caches are recomputed at runtime; never load them.
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/fairseq2_llama.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The vLLM team.
|
| 4 |
+
# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""Llama model for fairseq2 weights."""
|
| 18 |
+
|
| 19 |
+
from typing import Iterable, Set, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch.nn import Parameter
|
| 23 |
+
|
| 24 |
+
from vllm.config import VllmConfig
|
| 25 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 26 |
+
get_tensor_model_parallel_world_size)
|
| 27 |
+
from vllm.model_executor.layers.linear import set_weight_attrs
|
| 28 |
+
from vllm.model_executor.models.llama import LlamaForCausalLM
|
| 29 |
+
|
| 30 |
+
from .utils import AutoWeightsLoader, WeightsMapper
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
    """Llama variant that loads fairseq2-serialized checkpoints.

    Unwraps fairseq2's state-dict container, remaps fairseq2 parameter names
    to vLLM's Llama layout, and reshapes/flags weights so both full and
    tensor-parallel (TP) sharded checkpoints load correctly.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        # For the model loader to read only the relevant checkpoint files
        self.allow_patterns_overrides = [
            # either the full checkpoint
            "model.pt",
            # or the tp-sharded checkpoint of the current rank
            f"model.{self.tp_rank}.pt",
        ]

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Unwrap, rename, reshape and load fairseq2 weights.

        Returns the set of parameter names that were actually loaded.
        """
        # fairseq2's serialization adds a wrapper to usual .pt state_dict's:
        # { "model_key": my_model_name, "my_model_name": state_dict }
        # which we first need to unpack
        weights_wrapped = dict(weights)
        weights = weights_wrapped[
            weights_wrapped["model_key"]].items()  # type: ignore

        # remap keys from fairseq2's module layout to vLLM's Llama layout
        fs2_to_vllm_mapper = WeightsMapper(
            orig_to_new_prefix={
                "decoder_frontend.embed.": "model.embed_tokens.",
                "decoder.": "model.",
                "final_proj.": "lm_head.",
            },
            orig_to_new_substr={
                ".self_attn_layer_norm.": ".input_layernorm.",
                ".ffn_layer_norm.": ".post_attention_layernorm.",
                ".self_attn.output_proj.": ".self_attn.o_proj.",
                ".ffn.gate_proj.": ".mlp.gate_proj.",
                ".ffn.inner_proj.": ".mlp.up_proj.",
                ".ffn.output_proj.": ".mlp.down_proj.",
                ".layer_norm.": ".norm.",
            },
        )
        weights = fs2_to_vllm_mapper.apply(weights)

        params = dict(self.named_parameters())

        loader = AutoWeightsLoader(
            self,
            # lm_head is shared with embed_tokens when embeddings are tied,
            # so its checkpoint entry (if any) must be skipped.
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            (self.reshape_fairseq2_weights(name, loaded_weight, params)
             for name, loaded_weight in weights))

    def flag_sharded_weights(self, params: dict[str, Parameter]):
        """Sets the `is_sharded_weight` flag to True for all sharded weights"""
        for name, param in params.items():
            modules = name.split(".")
            if "norm" in name and len(param.size()) < 2:
                # layer norms are not sharded
                continue
            elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
                # for now we repeat embedding layers for compatibility
                continue
            else:
                # all other layers are sharded
                set_weight_attrs(param, {"is_sharded_weight": True})

    def reshape_fairseq2_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params: dict[str, Parameter],
    ) -> Tuple[str, torch.Tensor]:
        """Reshape fairseq2's weights.

        Permutes q/k projections into vLLM's interleaved rotary layout,
        repeats embeddings for TP-sharded checkpoints, and flags sharded
        tensors so the weight loaders skip re-narrowing them.
        """

        def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
            # Convert fairseq2's (pairwise) rotary weight layout into the
            # half-split layout expected by vLLM's rotary embedding.
            attn_in = self.config.head_dim * n_heads
            # check for a sharded weight on dim 0
            if attn_in // self.tp_size == w.size()[0]:
                attn_in //= self.tp_size
                n_heads //= self.tp_size
            attn_out = self.config.hidden_size
            return (w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1,
                                               2).reshape(attn_in, attn_out))

        modules = name.split(".")

        # rotary embeds should be sliced
        if "k_proj" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)

        elif "q_proj" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        # We make the loaded weights compatible with both
        # full checkpoints and tp sharded checkpoints.
        # Embeddings are repeated to fit the vocab size.
        # Other weights are flagged for the weight_loader calls.
        if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
            # Embeddings are sharded on dim 0
            dim = 0
            # In fairseq2, vocab size has to be divisible by tp_size
            # so we don't worry about padding
            if self.tp_size > 1 and loaded_weight.shape[
                    dim] < self.config.vocab_size:
                assert loaded_weight.shape[
                    dim] * self.tp_size == self.config.vocab_size, \
                    "vocab_size should be divisible by tp_size."
                repeats = [1] * len(loaded_weight.size())
                repeats[dim] = self.tp_size
                # repeat to match vocab size and to be easily 'narrow'able
                loaded_weight = loaded_weight.repeat(repeats)
                set_weight_attrs(params[name], {"is_sharded_weight": False})
                # if embeddings are sharded, the rest is too
                if "embed_tokens" in modules:
                    self.flag_sharded_weights(params)

        return name, loaded_weight
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The vLLM team.
|
| 4 |
+
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
from transformers import Gemma2Config
|
| 23 |
+
|
| 24 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 25 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 26 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 27 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 28 |
+
from vllm.logger import init_logger
|
| 29 |
+
from vllm.model_executor.layers.activation import GeluAndMul
|
| 30 |
+
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
| 31 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 32 |
+
QKVParallelLinear,
|
| 33 |
+
RowParallelLinear)
|
| 34 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 35 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 36 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 37 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 38 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 39 |
+
VocabParallelEmbedding)
|
| 40 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 41 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 46 |
+
from .utils import (AutoWeightsLoader, extract_layer_index,
|
| 47 |
+
is_pp_missing_parameter,
|
| 48 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 49 |
+
maybe_prefix)
|
| 50 |
+
|
| 51 |
+
logger = init_logger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Gemma2MLP(nn.Module):
    """Gemma2 feed-forward block.

    A fused gate/up column-parallel projection followed by tanh-approximated
    GELU gating and a row-parallel down projection back to the hidden size.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Gate and up projections are merged into a single matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Gemma2 supports only the tanh-approximated GELU activation.
        if hidden_act != "gelu_pytorch_tanh" or \
                hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Gemma2Attention(nn.Module):
    """Multi-head attention for Gemma2.

    Partitions query and KV heads across tensor-parallel ranks, applies
    rotary position embeddings, optional attention-logit soft-capping, and
    interleaved sliding-window attention on even-indexed layers.
    """

    def __init__(self,
                 config: Gemma2Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 rope_theta: float,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 attn_logits_soft_cap: Optional[float] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Gemma2 scales queries by query_pre_attn_scalar**-0.5, which may
        # differ from the conventional head_dim**-0.5.
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )

        # reference:
        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
        # Gemma2 alternates between sliding-window and full attention;
        # even-indexed layers use the sliding window when configured.
        layer_idx = extract_layer_index(prefix)
        use_sliding_window = (layer_idx % 2 == 0 and
                              config.interleaved_sliding_window is not None)
        sliding_window = config.interleaved_sliding_window if \
            use_sliding_window else None
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              logits_soft_cap=attn_logits_soft_cap,
                              per_layer_sliding_window=sliding_window,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project
        the attention output back to the hidden size."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class Gemma2DecoderLayer(nn.Module):
    """One Gemma2 transformer layer.

    Uses Gemma2's "sandwich" normalization: both the attention and MLP
    sub-blocks are wrapped by a pre- and a post-norm (four RMSNorms total),
    with residuals threaded through the fused norm calls.
    """

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma2Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=config.attn_logit_softcapping,
            prefix=f"{prefix}.self_attn",
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply pre-norm -> attention -> post-norm, then
        pre-norm -> MLP -> post-norm, returning (hidden_states, residual).

        ``residual`` is None only for the first layer; afterwards the fused
        norm calls add the residual and return the new one.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Post-attention norm is applied to the attention output itself
        # (no residual add here) per Gemma2's sandwich-norm scheme.
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@support_torch_compile
|
| 249 |
+
class Gemma2Model(nn.Module):
|
| 250 |
+
|
| 251 |
+
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Gemma2 decoder stack for this pipeline stage.

        Creates the shared embedding table, the locally-owned decoder layers
        (via make_layers, which handles PP layer assignment), the final norm,
        and the embedding normalizer buffer.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma2DecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
|
| 279 |
+
|
| 280 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 281 |
+
return self.embed_tokens(input_ids)
|
| 282 |
+
|
| 283 |
+
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline-parallel (PP) stage of the Gemma2 decoder.

        On the first PP rank, embeddings are scaled by ``self.normalizer``
        (sqrt(hidden_size), per Gemma2's embedding convention); later ranks
        consume ``intermediate_tensors`` from the previous stage. Non-final
        ranks return an ``IntermediateTensors`` handoff; the last rank
        returns the final normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # Gemma2 multiplies embeddings by sqrt(hidden_size).
            hidden_states *= self.normalizer
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        # kv_caches is indexed relative to this stage's first layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 319 |
+
|
| 320 |
+
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Fuses per-shard q/k/v and gate/up checkpoint weights into the stacked
        vLLM parameters, remaps KV-cache quantization scale names, and skips
        parameters owned by other pipeline-parallel ranks.

        Returns:
            The set of parameter names that were loaded; a warning is logged
            for any parameter left uninitialized.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for-else: runs only when no stacked mapping matched, i.e.
                # the weight maps 1:1 onto a model parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)
        return loaded_params
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma-2 causal language model wrapper around ``Gemma2Model``.

    The word embeddings are tied to the LM head, so logits are computed
    directly against ``model.embed_tokens``.
    """

    # Checkpoint shard names that are fused into single stacked parameters.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings`
        # enabled, so the tied-embedding path below is unconditional.
        assert config.tie_word_embeddings
        self.quant_config = quant_config
        self.model = Gemma2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        # Gemma-2 soft-caps the final logits (final_logit_softcapping).
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying decoder stack."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Tied embeddings: reuse embed_tokens as the LM head.
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings, any lm_head.* checkpoint weights are
            # redundant and skipped.
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/glm.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Inference-only HF format GLM-4 model compatible with THUDM weights."""
|
| 3 |
+
from vllm.config import VllmConfig
|
| 4 |
+
from vllm.model_executor.models.llama import LlamaForCausalLM
|
| 5 |
+
|
| 6 |
+
from .utils import PPMissingLayer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GlmForCausalLM(LlamaForCausalLM):
    """HF-format GLM-4, implemented by patching a Llama model in place."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # GLM attention differs from Llama in two ways:
        #   1. partial rotary embedding (half the dim), GPT-J (non-Neox)
        #      style;
        #   2. no bias on the attention output projection.
        # Patch each layer held by this pipeline-parallel rank accordingly.
        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                # Layer lives on another PP rank; nothing to patch here.
                continue
            attn = decoder_layer.self_attn
            attn.rotary_emb.rotary_dim //= 2
            attn.rotary_emb.is_neox_style = False
            attn.o_proj.bias = None
            attn.o_proj.skip_bias_add = True
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_j.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
"""Inference-only GPT-J model compatible with HuggingFace weights."""
|
| 20 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from transformers import GPTJConfig
|
| 25 |
+
|
| 26 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 27 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 28 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 29 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 30 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 31 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 32 |
+
QKVParallelLinear,
|
| 33 |
+
RowParallelLinear)
|
| 34 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 35 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 36 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 37 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 38 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 39 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 40 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 41 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 42 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 43 |
+
from vllm.sequence import IntermediateTensors
|
| 44 |
+
|
| 45 |
+
from .interfaces import SupportsPP
|
| 46 |
+
from .utils import (is_pp_missing_parameter,
|
| 47 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 48 |
+
maybe_prefix)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class GPTJAttention(nn.Module):
    """GPT-J self-attention with GPT-J-style (interleaved) rotary embeddings.

    Heads are sharded across the tensor-parallel group; q/k/v are computed by
    one fused projection.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # GPT-J uses no biases in its attention projections.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # Heads must divide evenly across tensor-parallel ranks.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # This implementation only supports the rotary variant of GPT-J.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # is_neox_style=False selects the interleaved (GPT-J) rotation.
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into equal q/k/v chunks.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class GPTJMLP(nn.Module):
    """GPT-J feed-forward block: fc_in -> activation -> fc_out."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Parallel linears return (output, bias); the bias is unused here.
        expanded, _ = self.fc_in(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.fc_out(activated)
        return projected
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class GPTJBlock(nn.Module):
    """One GPT-J decoder layer.

    GPT-J applies attention and the MLP *in parallel* on the same
    layer-normalized input, then adds both to the residual.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # MLP consumes the same normed input as attention (parallel blocks).
        mlp_out = self.mlp(normed)
        return attn_out + mlp_out + residual
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J transformer backbone: embeddings, decoder blocks, final LN."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        # Only the blocks assigned to this pipeline-parallel rank are
        # materialized; make_layers returns this rank's [start, end) range.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's blocks; ship hidden states onward if not last."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later ranks resume from the previous rank's output.
            hidden_states = intermediate_tensors["hidden_states"]
        # kv_caches only covers this rank's layers, hence the offset index.
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model: ``GPTJModel`` plus a biased LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # GPT-J checkpoints keep lm_head separate from the embeddings.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Note: GPT-J's lm_head has a bias, unlike many decoder-only models.
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer backbone."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # The lm_head bias is passed explicitly to the logits processor.
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, fusing q/k/v shards and remapping
        KV-cache scale names; returns the set of loaded parameter names."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # HF GPT-J checkpoints carry causal-mask buffers that vLLM
            # does not use.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Scales may be stored as 0-dim or 1-element tensors.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for-else: only reached when no stacked mapping matched.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gpt_neox.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
|
| 20 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from transformers import GPTNeoXConfig
|
| 25 |
+
|
| 26 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 27 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 28 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 29 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 30 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 31 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 32 |
+
QKVParallelLinear,
|
| 33 |
+
RowParallelLinear)
|
| 34 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 35 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 36 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 37 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 38 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 39 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 40 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 41 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 42 |
+
from vllm.sequence import IntermediateTensors
|
| 43 |
+
|
| 44 |
+
from .interfaces import SupportsPP
|
| 45 |
+
from .utils import (is_pp_missing_parameter,
|
| 46 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 47 |
+
maybe_prefix)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with partial NeoX-style rotary embeddings."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Older configs omit attention_bias; NeoX historically used biases.
        self.bias = getattr(config, "attention_bias", True)

        # Heads must divide evenly across tensor-parallel ranks.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5
        # Only a fraction (rotary_pct) of each head is rotated.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is split into equal q/k/v chunks.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block (hidden -> intermediate -> hidden)."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        # Parallel linears return (output, bias); the bias is unused here.
        expanded, _ = self.dense_h_to_4h(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.dense_4h_to_h(activated)
        return projected
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block.

    Supports both the parallel-residual form (attention and MLP computed
    from the same input, summed together) and the sequential form,
    selected by ``config.use_parallel_residual``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config,
                                          cache_config,
                                          quant_config,
                                          prefix=f"{prefix}.attention")
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        attn_out = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
            return mlp_out + attn_out + hidden_states

        # x = x + attn(ln1(x)); x = x + mlp(ln2(x))
        post_attn = attn_out + hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(post_attn))
        return mlp_out + post_attn
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX transformer backbone: embeddings, decoder layers, final norm.

    Supports pipeline parallelism: only the layers assigned to this PP
    rank are built, and non-first ranks consume ``intermediate_tensors``
    produced by the previous rank instead of token embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers builds only this PP rank's slice of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; returns hidden states on the last PP rank,
        otherwise an IntermediateTensors to forward to the next rank."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Non-first ranks must receive activations from the previous
            # rank; fail loudly instead of raising TypeError on None.
            # (Matches the guard used by GraniteModel in this codebase.)
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX with a causal language-modeling head, for vLLM inference."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share the LM head weight with the input embedding table.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings from the backbone."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the GPTNeoXModel backbone."""
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors,
                                      inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module.

        Returns the set of parameter names that were actually loaded.
        Skips buffers that vLLM recomputes at runtime, and permutes the
        fused QKV weight from HF's (num_heads * 3 * head_size) layout to
        the (3 * num_heads * head_size) layout vLLM expects.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Attention masks and rotary inv_freq are recomputed; skip.
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            # Parameters belonging to other PP ranks are not loaded here.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # (..., heads, 3, head) -> (..., 3, heads, head) -> flat.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/granite.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
|
| 25 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import nn
|
| 29 |
+
from transformers import GraniteConfig
|
| 30 |
+
|
| 31 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 32 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 33 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 34 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 35 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 36 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 37 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 38 |
+
QKVParallelLinear,
|
| 39 |
+
RowParallelLinear)
|
| 40 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 41 |
+
from vllm.model_executor.layers.quantization.base_config import (
|
| 42 |
+
QuantizationConfig)
|
| 43 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 44 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 45 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 46 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 47 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 48 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 49 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 50 |
+
from vllm.sequence import IntermediateTensors
|
| 51 |
+
|
| 52 |
+
from .interfaces import SupportsLoRA, SupportsPP
|
| 53 |
+
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
|
| 54 |
+
maybe_prefix)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GraniteMLP(nn.Module):
    """SwiGLU feed-forward block used by Granite decoder layers."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fused gate+up projection: one column-parallel matmul produces
        # both halves consumed by SiluAndMul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class GraniteAttention(nn.Module):
    """Multi-head attention for Granite with rotary embeddings.

    Unlike most Llama-family models, the softmax scale comes from
    ``config.attention_multiplier`` rather than 1/sqrt(head_dim).
    """

    def __init__(
        self,
        config: GraniteConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_world = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world == 0
        self.num_heads = self.total_num_heads // tp_world
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_world:
            # More KV heads than TP ranks: partition them across GPUs.
            assert self.total_num_kv_heads % tp_world == 0
        else:
            # Fewer KV heads than TP ranks: replicate them across GPUs.
            assert tp_world % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Granite's config supplies the attention scale directly.
        self.scaling = config.attention_multiplier
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        fused, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into per-role tensors (GQA-aware).
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size],
                              dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        out, _ = self.o_proj(context)
        return out
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class GraniteDecoderLayer(nn.Module):
    """One Granite decoder block: attention + SwiGLU MLP.

    Both residual additions are scaled by ``config.residual_multiplier``
    (a Granite-specific depth-scaling of the residual stream).
    """

    def __init__(
        self,
        config: GraniteConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            # Propagate the pre-extension context length to the rope impl.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = GraniteAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        self.mlp = GraniteMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # FIX: the annotation previously claimed Tuple[Tensor, Tensor],
        # but a single tensor is returned (and consumed as such by
        # GraniteModel.forward).
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier
        return hidden_states
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@support_torch_compile
class GraniteModel(nn.Module):
    """Granite transformer backbone: embeddings, decoder layers, final norm.

    Pipeline-parallel aware: only this rank's layers are built; non-first
    ranks consume hidden states from ``intermediate_tensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra vocab rows reserved for LoRA adapters, if enabled.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The last rank also needs the embedding table when weights are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteDecoderLayer(config=config,
                                               cache_config=cache_config,
                                               quant_config=quant_config,
                                               prefix=prefix),
            prefix=f"{prefix}.layers")
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; returns hidden states on the last PP rank,
        otherwise an IntermediateTensors for the next rank."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None

            # Granite scales the embedding output by a config multiplier.
            hidden_states *= self.config.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states = self.norm(hidden_states)
        return hidden_states
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
| 341 |
+
packed_modules_mapping = {
|
| 342 |
+
"qkv_proj": [
|
| 343 |
+
"q_proj",
|
| 344 |
+
"k_proj",
|
| 345 |
+
"v_proj",
|
| 346 |
+
],
|
| 347 |
+
"gate_up_proj": [
|
| 348 |
+
"gate_proj",
|
| 349 |
+
"up_proj",
|
| 350 |
+
],
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
# LoRA specific attributes
|
| 354 |
+
supported_lora_modules = [
|
| 355 |
+
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
|
| 356 |
+
"lm_head"
|
| 357 |
+
]
|
| 358 |
+
embedding_modules = {
|
| 359 |
+
"embed_tokens": "input_embeddings",
|
| 360 |
+
"lm_head": "output_embeddings",
|
| 361 |
+
}
|
| 362 |
+
embedding_padding_modules = ["lm_head"]
|
| 363 |
+
|
| 364 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the Granite causal-LM wrapper (backbone + LM head + sampler)."""
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config

    self.config = config
    self.lora_config = lora_config
    self.quant_config = quant_config

    self.model = GraniteModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
    # The LM head, logits processor, and their vocab bookkeeping live only
    # on the last pipeline-parallel rank.
    if get_pp_group().is_last_rank:
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the LM head weight with the input embedding table.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        # Granite divides logits by config.logits_scaling before softmax.
        if hasattr(config, "logits_scaling"):
            logit_scale /= config.logits_scaling

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=logit_scale)
    else:
        self.lm_head = PPMissingLayer()

    self.sampler = get_sampler()
|
| 404 |
+
|
| 405 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings via the backbone's embedding table."""
    return self.model.get_input_embeddings(input_ids)
|
| 407 |
+
|
| 408 |
+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate the forward pass to the GraniteModel backbone."""
    return self.model(input_ids, positions, kv_caches, attn_metadata,
                      intermediate_tensors, inputs_embeds)
|
| 421 |
+
|
| 422 |
+
def compute_logits(
        self, hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
    """Project final hidden states to vocabulary logits via the LM head."""
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
|
| 428 |
+
|
| 429 |
+
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Draw next tokens from the logits with the configured sampler."""
    return self.sampler(logits, sampling_metadata)
|
| 436 |
+
|
| 437 |
+
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    """Allocate zeroed hidden-state/residual placeholders for PP transfer."""

    def _zeros() -> torch.Tensor:
        # Both tensors share the same (batch, hidden) shape.
        return torch.zeros((batch_size, self.config.hidden_size),
                           dtype=dtype,
                           device=device)

    return IntermediateTensors({
        "hidden_states": _zeros(),
        "residual": _zeros(),
    })
|
| 450 |
+
|
| 451 |
+
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Per checkpoint tensor, three cases are handled in order:
        * tensors with no runtime use (rotary caches, tied ``lm_head``) are
          skipped,
        * kv-cache quantization scales are remapped and loaded specially,
        * q/k/v and gate/up shards are fused into the model's stacked
          parameters; everything else is loaded one-to-one.

        Returns:
            The set of parameter names that actually received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # inv_freq is recomputed at runtime; never loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Scales may arrive as 1-element tensors; unwrap to a scalar.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Try to fuse this tensor into one of the stacked parameters.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Parameter lives on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Not a stacked shard: load the tensor directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/gritlm.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from array import array
|
| 4 |
+
from typing import List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
| 9 |
+
|
| 10 |
+
from vllm.attention import AttentionMetadata
|
| 11 |
+
from vllm.attention.backends.xformers import XFormersImpl
|
| 12 |
+
from vllm.config import ModelConfig, VllmConfig
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
from vllm.model_executor.layers.pooler import PoolerHead
|
| 15 |
+
from vllm.model_executor.models.llama import LlamaForCausalLM
|
| 16 |
+
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
| 17 |
+
PoolingTensors)
|
| 18 |
+
from vllm.multimodal.utils import cached_get_tokenizer
|
| 19 |
+
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
| 20 |
+
PoolingSequenceGroupOutput)
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GritLMPooler(nn.Module):
    """Pools hidden states for GritLM embeddings.

    Sums the hidden states of non-instruction tokens (located by token-ID
    pattern matching against the prompt), divides by their count, and
    normalizes the mean via ``PoolerHead``.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self.model_config = model_config

        tokenizer = cached_get_tokenizer(
            self.model_config.tokenizer,
            tokenizer_mode=self.model_config.tokenizer_mode,
            tokenizer_revision=self.model_config.tokenizer_revision,
            trust_remote_code=self.model_config.trust_remote_code,
        )

        # Collect the tokens needed for pattern matching.
        # "▁<" is different from "_<". The former uses "▁" to indicate that
        # the next token is the start of a word.
        # "<0x0A>" is the newline token (i.e. "\n")."
        self.token_ids = {
            tok: tokenizer.convert_tokens_to_ids([tok])[0]
            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
        }

        # Nested helper closing over `self`: maps token strings to an int
        # array of IDs (arrays support cheap slice comparison in _find_array).
        def tokens_to_ids(tokens: list[str]) -> array:
            return array("i", [self.token_ids[token] for token in tokens])

        self.user_pattern_ids = tokens_to_ids(
            ["▁<", "|", "user", "|", ">", "<0x0A>"])
        self.embed_newline_pattern_ids = tokens_to_ids(
            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
        self.embed_pattern_ids = tokens_to_ids(
            ["▁<", "|", "embed", "|", ">", "<0x0A>"])

        # Mean-pooled embeddings are L2-normalized; no softmax is applied.
        self.head = PoolerHead(normalize=True, softmax=False)

    def _find_array(self, arr: array, target: array, start_idx: int) -> int:
        """
        Find the first occurrence of target in arr starting from start_idx.

        Args:
            arr: The array to search within
            target: The consecutive subsequence to find
            start_idx: The starting index to search from

        Returns:
            int: The index of the first occurrence of target in arr,
                or -1 if target does not occur at or after start_idx.
        """
        if start_idx < 0:
            raise ValueError("start_idx must be non-negative")
        if not target or not arr:
            raise ValueError("Empty arr or target not allowed")

        target_len = len(target)
        for i in range(start_idx, len(arr) - target_len + 1):
            if arr[i:i + target_len] == target:
                return i
        return -1

    def _get_instruction_len(self, prompt_token_ids: array) -> int:
        """
        Get the length of the instruction in the prompt.

        We do a pattern matching to find the instruction in the prompt,
        and then return the length of the instruction.

        The pattern matching is done using integers instead of strings
        because the prompt is given as a list of token IDs.
        """

        instruction_len = 0

        # Return no instruction in case of missing BOS token.
        if prompt_token_ids[0] != self.token_ids["<s>"]:
            logger.warning("BOS token not found in prompt,"
                           "thus using empty string for instruction."
                           "GritLM requires BOS token in prompt.")
            return instruction_len

        # If user pattern is found in the prompt, that means there should be
        # a newline token before the embed pattern.
        embed_pattern_ids = self.embed_pattern_ids
        if self._find_array(prompt_token_ids,
                            self.user_pattern_ids,
                            start_idx=1) == 1:
            embed_pattern_ids = self.embed_newline_pattern_ids

        # Find the embed pattern in the prompt.
        found_embed_pattern_idx = self._find_array(prompt_token_ids,
                                                   embed_pattern_ids,
                                                   start_idx=1)

        if found_embed_pattern_idx != -1:
            # Instruction spans everything up to and including the pattern.
            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
        else:
            logger.warning("Query instruction not found in prompt,"
                           "thus using BOS token as instruction instead."
                           "GritLM requires query instruction in prompt.")
            instruction_len = 1

        return instruction_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """
        Pool the hidden states by summing the embeddings of
        non-instruction tokens.
        """
        # hidden_states is the flat concatenation of all sequences' tokens.
        prompts_token_ids = [
            token_ids.prompt_token_ids_array
            for _, token_ids in pooling_metadata.seq_data.items()
        ]

        instruction_lens = torch.tensor(
            [
                self._get_instruction_len(prompt_token_ids)
                for prompt_token_ids in prompts_token_ids
            ],
            device=hidden_states.device,
        )

        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        # Boolean mask selecting non-instruction token rows per sequence.
        mask = torch.zeros_like(hidden_states, dtype=torch.bool)

        start_idx = 0
        for prompt_len, instruction_len in zip(prompt_lens, instruction_lens):
            end_idx = start_idx + prompt_len
            mask[start_idx + instruction_len:end_idx] = True
            start_idx = end_idx

        # Zero out instruction-token rows so they drop out of the sums.
        masked_hidden_states = hidden_states.masked_fill(~mask, 0.0)

        sum_embeddings = torch.zeros(len(prompt_lens),
                                     hidden_states.size(1),
                                     device=hidden_states.device)

        # Segment-wise sum over each sequence's span of the flat tensor.
        start_idx = 0
        for i, prompt_len in enumerate(prompt_lens):
            end_idx = start_idx + prompt_len
            sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum(
                dim=0)
            start_idx = end_idx

        # Mean over the tokens that actually contributed to each sum.
        num_non_instruction_tokens = prompt_lens - instruction_lens
        mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
            1)

        pooled_data = self.head(mean_embeddings)

        pooled_outputs = [
            PoolingSequenceGroupOutput(data) for data in pooled_data
        ]

        return PoolerOutput(outputs=pooled_outputs)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class GritLM(LlamaForCausalLM):
    """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

    The class inherits from LlamaForCausalLM and provides a custom pooling
    layer.

    The main difference between the pooling layer in GritLM and the one in
    LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
    when pooling the hidden states.

    Embedding prompts should be in the following format:
    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
    - Without instruction: "<|embed|>\nPROMPT".

    Generation prompts should be in the following format:
    - "<|user|>\nPROMPT\n<|assistant|>\n"
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

        # "pooling" (embedding) runs need non-causal attention; see forward().
        self.runner_type = vllm_config.model_config.runner_type

        self._pooler = GritLMPooler(vllm_config.model_config)

        # Pooling mode relies on injecting an xformers attention bias, so
        # every attention layer must use the xformers backend.
        for layer in self.model.layers:
            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
                    "GritLM embedding is only supported by XFormers backend, "
                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the Llama stack, with bidirectional attention in pooling mode."""

        # Change attention to non-causal for pooling tasks.
        if self.runner_type == "pooling":
            # NOTE: mutates the shared prefill metadata in place; the assert
            # guards against overwriting a bias someone else installed.
            assert attn_metadata.prefill_metadata.attn_bias is None
            attn_metadata.prefill_metadata.attn_bias = [
                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
            ]

        return super().forward(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **kwargs,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling to the instruction-aware GritLM pooler."""
        return self._pooler(hidden_states, pooling_metadata)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/idefics2_vision_model.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
|
| 4 |
+
# Copyright 2024 The vLLM team.
|
| 5 |
+
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
"""PyTorch Idefics2 model."""
|
| 19 |
+
|
| 20 |
+
from typing import Iterable, Optional, Set, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from transformers.models.idefics2.configuration_idefics2 import (
|
| 25 |
+
Idefics2Config, Idefics2VisionConfig)
|
| 26 |
+
|
| 27 |
+
from vllm.attention.layer import MultiHeadAttention
|
| 28 |
+
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
| 29 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 30 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 31 |
+
QKVParallelLinear,
|
| 32 |
+
RowParallelLinear)
|
| 33 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 34 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Idefics2VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings
    ` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision
    Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the
    need to resize them to the same fixed size. In particular, we start from the
    original pre-trained SigLIP model(which uses images of fixed-size square
    images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: Idefics2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Non-overlapping patchification: stride == kernel == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)

    def forward(self,
                pixel_values: torch.FloatTensor,
                patch_attention_mask: torch.BoolTensor,
                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
        """Patchify, embed, and add resolution-aware position embeddings.

        Each image's valid patch grid (from *tgt_sizes* or derived from the
        attention mask) is mapped onto the fixed pre-trained position grid by
        bucketizing fractional coordinates.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
        # (batch, embed_dim, h, w) -> (batch, h*w, embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)
        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        # Bucket edges of the pre-trained position grid in [0, 1).
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
                                  1 / self.num_patches_per_side)
        position_ids = torch.full(size=(batch_size,
                                        max_nb_patches_h * max_nb_patches_w),
                                  fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):

            if tgt_sizes is not None:
                nb_patches_h = tgt_sizes[batch_idx][0]
                nb_patches_w = tgt_sizes[batch_idx][1]
            else:
                # Derive the valid grid from the mask's first column/row.
                nb_patches_h = p_attn_mask[:, 0].sum()
                nb_patches_w = p_attn_mask[0].sum()
            # Fractional coordinates of each valid patch; 1 - 1e-6 keeps the
            # last coordinate strictly below 1.0 so it lands in a bucket.
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
            bucket_coords_h = torch.bucketize(fractional_coords_h,
                                              boundaries,
                                              right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w,
                                              boundaries,
                                              right=True)
            # Row-major index into the num_patches_per_side² position table.
            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
                       bucket_coords_w).flatten()
            # position_ids lives on CPU here; index assignment needs the mask
            # on the same device.
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Idefics2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # The embedding width must split evenly across heads.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"  # noqa: E501
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Fused Q/K/V projection, sharded across tensor-parallel ranks.
        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention over *hidden_states* and project the result."""
        # Fused projection yields (batch, q_len,
        # 3 * num_heads_per_partition * head_dim); split into Q, K, V.
        fused_qkv, _ = self.qkv_proj(hidden_states)
        queries, keys, values = fused_qkv.chunk(3, dim=-1)
        context = self.attn(queries, keys, values)
        projected, _ = self.out_proj(context)
        return projected
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Idefics2VisionMLP(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> fc2, built from
    tensor-parallel linear layers."""

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand to the intermediate size, apply the nonlinearity, project
        back to the hidden size."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class Idefics2EncoderLayer(nn.Module):
    """Single pre-norm transformer encoder layer: LayerNorm -> self-attention
    with residual, then LayerNorm -> MLP with residual."""

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionAttention(config,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.self_attn")
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Idefics2VisionMLP(config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.

        """
        # Attention sub-block with residual connection (pre-norm).
        attn_input = self.layer_norm1(hidden_states)
        hidden_states = hidden_states + self.self_attn(attn_input)
        # MLP sub-block with residual connection (pre-norm).
        mlp_input = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)
        return hidden_states
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class Idefics2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention
    layers. Each layer is a
    [`Idefics2EncoderLayer`].

    Args:
        config: Idefics2Config
    """

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList([
            Idefics2EncoderLayer(config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds (torch.Tensor):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation.
                This is useful if you want more control over how to convert
                `input_ids` indices into associated vectors than the model's
                internal embedding lookup matrix.
        """
        # Thread the activations sequentially through every encoder layer.
        activations = inputs_embeds
        for layer in self.layers:
            activations = layer(activations)
        return activations
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class Idefics2VisionTransformer(nn.Module):
    """SigLIP-style vision transformer for Idefics2: variable-resolution patch
    embeddings, an encoder stack, and a final layer norm."""

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        embed_dim = config.hidden_size
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(config)
        self.encoder = Idefics2Encoder(config,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.encoder")
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def get_input_embeddings(self):
        # Returns the embedding *module*, not an embedded tensor.
        return self.embeddings

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        """Embed the image patches, run the encoder, and normalize the output."""
        hidden_states = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            tgt_sizes=tgt_sizes,
        )
        encoder_outputs = self.encoder(hidden_states)
        last_hidden_state = self.post_layernorm(encoder_outputs)
        return last_hidden_state

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, fusing q/k/v shards into the stacked
        qkv_proj parameter; returns the names of parameters loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # First try to fuse the tensor into a stacked parameter.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor one-to-one.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/interfaces.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
|
| 4 |
+
Protocol, Type, Union, overload, runtime_checkable)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing_extensions import TypeIs, TypeVar
|
| 8 |
+
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
from vllm.utils import supports_kw
|
| 11 |
+
|
| 12 |
+
from .interfaces_base import is_pooling_model
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from vllm.attention import AttentionMetadata
|
| 16 |
+
from vllm.multimodal.inputs import NestedTensors # noqa: F401
|
| 17 |
+
from vllm.sequence import IntermediateTensors
|
| 18 |
+
|
| 19 |
+
logger = init_logger(__name__)
|
| 20 |
+
|
| 21 |
+
T = TypeVar("T", default="NestedTensors")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@runtime_checkable
|
| 25 |
+
class SupportsMultiModal(Protocol):
|
| 26 |
+
"""The interface required for all multi-modal models."""
|
| 27 |
+
|
| 28 |
+
supports_multimodal: ClassVar[Literal[True]] = True
|
| 29 |
+
"""
|
| 30 |
+
A flag that indicates this model supports multi-modal inputs.
|
| 31 |
+
|
| 32 |
+
Note:
|
| 33 |
+
There is no need to redefine this flag if this class is in the
|
| 34 |
+
MRO of your model class.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
|
| 38 |
+
"""
|
| 39 |
+
Returns multimodal embeddings generated from multimodal kwargs
|
| 40 |
+
to be merged with text embeddings.
|
| 41 |
+
|
| 42 |
+
The output embeddings must be one of the following formats:
|
| 43 |
+
|
| 44 |
+
- A list or tuple of 2D tensors, where each tensor corresponds to
|
| 45 |
+
each input multimodal data item (e.g, image).
|
| 46 |
+
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
|
| 47 |
+
|
| 48 |
+
Note:
|
| 49 |
+
The returned multimodal embeddings must be in the same order as
|
| 50 |
+
the appearances of their corresponding multimodal data item in the
|
| 51 |
+
input prompt.
|
| 52 |
+
"""
|
| 53 |
+
...
|
| 54 |
+
|
| 55 |
+
# Only for models that support v0 chunked prefill
|
| 56 |
+
# TODO(ywang96): Remove this overload once v0 is deprecated
|
| 57 |
+
@overload
|
| 58 |
+
def get_input_embeddings(
|
| 59 |
+
self,
|
| 60 |
+
input_ids: torch.Tensor,
|
| 61 |
+
multimodal_embeddings: Optional[T] = None,
|
| 62 |
+
attn_metadata: Optional["AttentionMetadata"] = None,
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
...
|
| 65 |
+
|
| 66 |
+
@overload
|
| 67 |
+
def get_input_embeddings(
|
| 68 |
+
self,
|
| 69 |
+
input_ids: torch.Tensor,
|
| 70 |
+
multimodal_embeddings: Optional[T] = None,
|
| 71 |
+
) -> torch.Tensor:
|
| 72 |
+
"""
|
| 73 |
+
Returns the input embeddings merged from the text embeddings from
|
| 74 |
+
input_ids and the multimodal embeddings generated from multimodal
|
| 75 |
+
kwargs.
|
| 76 |
+
"""
|
| 77 |
+
...
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# We can't use runtime_checkable with ClassVar for issubclass checks
|
| 81 |
+
# so we need to treat the class as an instance and use isinstance instead
|
| 82 |
+
@runtime_checkable
|
| 83 |
+
class _SupportsMultiModalType(Protocol):
|
| 84 |
+
supports_multimodal: Literal[True]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@overload
|
| 88 |
+
def supports_multimodal(
|
| 89 |
+
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
|
| 90 |
+
...
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@overload
|
| 94 |
+
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
|
| 95 |
+
...
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def supports_multimodal(
|
| 99 |
+
model: Union[Type[object], object],
|
| 100 |
+
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
|
| 101 |
+
if isinstance(model, type):
|
| 102 |
+
return isinstance(model, _SupportsMultiModalType)
|
| 103 |
+
|
| 104 |
+
return isinstance(model, SupportsMultiModal)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@runtime_checkable
|
| 108 |
+
class SupportsLoRA(Protocol):
|
| 109 |
+
"""The interface required for all models that support LoRA."""
|
| 110 |
+
|
| 111 |
+
supports_lora: ClassVar[Literal[True]] = True
|
| 112 |
+
"""
|
| 113 |
+
A flag that indicates this model supports LoRA.
|
| 114 |
+
|
| 115 |
+
Note:
|
| 116 |
+
There is no need to redefine this flag if this class is in the
|
| 117 |
+
MRO of your model class.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
packed_modules_mapping: ClassVar[Dict[str, List[str]]]
|
| 121 |
+
supported_lora_modules: ClassVar[List[str]]
|
| 122 |
+
embedding_modules: ClassVar[Dict[str, str]]
|
| 123 |
+
embedding_padding_modules: ClassVar[List[str]]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# We can't use runtime_checkable with ClassVar for issubclass checks
|
| 127 |
+
# so we need to treat the class as an instance and use isinstance instead
|
| 128 |
+
@runtime_checkable
|
| 129 |
+
class _SupportsLoRAType(Protocol):
|
| 130 |
+
supports_lora: Literal[True]
|
| 131 |
+
|
| 132 |
+
packed_modules_mapping: Dict[str, List[str]]
|
| 133 |
+
supported_lora_modules: List[str]
|
| 134 |
+
embedding_modules: Dict[str, str]
|
| 135 |
+
embedding_padding_modules: List[str]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@overload
|
| 139 |
+
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
|
| 140 |
+
...
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@overload
|
| 144 |
+
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
|
| 145 |
+
...
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def supports_lora(
|
| 149 |
+
model: Union[Type[object], object],
|
| 150 |
+
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
|
| 151 |
+
result = _supports_lora(model)
|
| 152 |
+
|
| 153 |
+
if not result:
|
| 154 |
+
lora_attrs = (
|
| 155 |
+
"packed_modules_mapping",
|
| 156 |
+
"supported_lora_modules",
|
| 157 |
+
"embedding_modules",
|
| 158 |
+
"embedding_padding_modules",
|
| 159 |
+
)
|
| 160 |
+
missing_attrs = tuple(attr for attr in lora_attrs
|
| 161 |
+
if not hasattr(model, attr))
|
| 162 |
+
|
| 163 |
+
if getattr(model, "supports_lora", False):
|
| 164 |
+
if missing_attrs:
|
| 165 |
+
logger.warning(
|
| 166 |
+
"The model (%s) sets `supports_lora=True`, "
|
| 167 |
+
"but is missing LoRA-specific attributes: %s",
|
| 168 |
+
model,
|
| 169 |
+
missing_attrs,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
if not missing_attrs:
|
| 173 |
+
logger.warning(
|
| 174 |
+
"The model (%s) contains all LoRA-specific attributes, "
|
| 175 |
+
"but does not set `supports_lora=True`.", model)
|
| 176 |
+
|
| 177 |
+
return result
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _supports_lora(model: Union[Type[object], object]) -> bool:
|
| 181 |
+
if isinstance(model, type):
|
| 182 |
+
return isinstance(model, _SupportsLoRAType)
|
| 183 |
+
|
| 184 |
+
return isinstance(model, SupportsLoRA)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@runtime_checkable
|
| 188 |
+
class SupportsPP(Protocol):
|
| 189 |
+
"""The interface required for all models that support pipeline parallel."""
|
| 190 |
+
|
| 191 |
+
supports_pp: ClassVar[Literal[True]] = True
|
| 192 |
+
"""
|
| 193 |
+
A flag that indicates this model supports pipeline parallel.
|
| 194 |
+
|
| 195 |
+
Note:
|
| 196 |
+
There is no need to redefine this flag if this class is in the
|
| 197 |
+
MRO of your model class.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def make_empty_intermediate_tensors(
|
| 201 |
+
self,
|
| 202 |
+
batch_size: int,
|
| 203 |
+
dtype: torch.dtype,
|
| 204 |
+
device: torch.device,
|
| 205 |
+
) -> "IntermediateTensors":
|
| 206 |
+
"""Called when PP rank > 0 for profiling purposes."""
|
| 207 |
+
...
|
| 208 |
+
|
| 209 |
+
def forward(
|
| 210 |
+
self,
|
| 211 |
+
*,
|
| 212 |
+
intermediate_tensors: Optional["IntermediateTensors"],
|
| 213 |
+
) -> Union[torch.Tensor, "IntermediateTensors"]:
|
| 214 |
+
"""
|
| 215 |
+
Accept :class:`IntermediateTensors` when PP rank > 0.
|
| 216 |
+
|
| 217 |
+
Return :class:`IntermediateTensors` only for the last PP rank.
|
| 218 |
+
"""
|
| 219 |
+
...
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# We can't use runtime_checkable with ClassVar for issubclass checks
|
| 223 |
+
# so we need to treat the class as an instance and use isinstance instead
|
| 224 |
+
@runtime_checkable
|
| 225 |
+
class _SupportsPPType(Protocol):
|
| 226 |
+
supports_pp: Literal[True]
|
| 227 |
+
|
| 228 |
+
def make_empty_intermediate_tensors(
|
| 229 |
+
self,
|
| 230 |
+
batch_size: int,
|
| 231 |
+
dtype: torch.dtype,
|
| 232 |
+
device: torch.device,
|
| 233 |
+
) -> "IntermediateTensors":
|
| 234 |
+
...
|
| 235 |
+
|
| 236 |
+
def forward(
|
| 237 |
+
self,
|
| 238 |
+
*,
|
| 239 |
+
intermediate_tensors: Optional["IntermediateTensors"],
|
| 240 |
+
) -> Union[torch.Tensor, "IntermediateTensors"]:
|
| 241 |
+
...
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@overload
|
| 245 |
+
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
|
| 246 |
+
...
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@overload
|
| 250 |
+
def supports_pp(model: object) -> TypeIs[SupportsPP]:
|
| 251 |
+
...
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def supports_pp(
|
| 255 |
+
model: Union[Type[object], object],
|
| 256 |
+
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
|
| 257 |
+
supports_attributes = _supports_pp_attributes(model)
|
| 258 |
+
supports_inspect = _supports_pp_inspect(model)
|
| 259 |
+
|
| 260 |
+
if supports_attributes and not supports_inspect:
|
| 261 |
+
logger.warning(
|
| 262 |
+
"The model (%s) sets `supports_pp=True`, but does not accept "
|
| 263 |
+
"`intermediate_tensors` in its `forward` method", model)
|
| 264 |
+
|
| 265 |
+
if not supports_attributes:
|
| 266 |
+
pp_attrs = ("make_empty_intermediate_tensors", )
|
| 267 |
+
missing_attrs = tuple(attr for attr in pp_attrs
|
| 268 |
+
if not hasattr(model, attr))
|
| 269 |
+
|
| 270 |
+
if getattr(model, "supports_pp", False):
|
| 271 |
+
if missing_attrs:
|
| 272 |
+
logger.warning(
|
| 273 |
+
"The model (%s) sets `supports_pp=True`, "
|
| 274 |
+
"but is missing PP-specific attributes: %s",
|
| 275 |
+
model,
|
| 276 |
+
missing_attrs,
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
if not missing_attrs:
|
| 280 |
+
logger.warning(
|
| 281 |
+
"The model (%s) contains all PP-specific attributes, "
|
| 282 |
+
"but does not set `supports_pp=True`.", model)
|
| 283 |
+
|
| 284 |
+
return supports_attributes and supports_inspect
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
|
| 288 |
+
if isinstance(model, type):
|
| 289 |
+
return isinstance(model, _SupportsPPType)
|
| 290 |
+
|
| 291 |
+
return isinstance(model, SupportsPP)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
|
| 295 |
+
model_forward = getattr(model, "forward", None)
|
| 296 |
+
if not callable(model_forward):
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
return supports_kw(model_forward, "intermediate_tensors")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@runtime_checkable
|
| 303 |
+
class HasInnerState(Protocol):
|
| 304 |
+
"""The interface required for all models that has inner state."""
|
| 305 |
+
|
| 306 |
+
has_inner_state: ClassVar[Literal[True]] = True
|
| 307 |
+
"""
|
| 308 |
+
A flag that indicates this model has inner state.
|
| 309 |
+
Models that has inner state usually need access to the scheduler_config
|
| 310 |
+
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@runtime_checkable
|
| 315 |
+
class _HasInnerStateType(Protocol):
|
| 316 |
+
has_inner_state: ClassVar[Literal[True]]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
@overload
|
| 320 |
+
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
|
| 321 |
+
...
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@overload
|
| 325 |
+
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
|
| 326 |
+
...
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def has_inner_state(
|
| 330 |
+
model: Union[Type[object], object]
|
| 331 |
+
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
|
| 332 |
+
if isinstance(model, type):
|
| 333 |
+
return isinstance(model, _HasInnerStateType)
|
| 334 |
+
|
| 335 |
+
return isinstance(model, HasInnerState)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@runtime_checkable
|
| 339 |
+
class IsAttentionFree(Protocol):
|
| 340 |
+
"""The interface required for all models like Mamba that lack attention,
|
| 341 |
+
but do have state whose size is constant wrt the number of tokens."""
|
| 342 |
+
|
| 343 |
+
is_attention_free: ClassVar[Literal[True]] = True
|
| 344 |
+
"""
|
| 345 |
+
A flag that indicates this model has no attention.
|
| 346 |
+
Used for block manager and attention backend selection.
|
| 347 |
+
True for Mamba but not Jamba.
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@runtime_checkable
|
| 352 |
+
class _IsAttentionFreeType(Protocol):
|
| 353 |
+
is_attention_free: ClassVar[Literal[True]]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@overload
|
| 357 |
+
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
|
| 358 |
+
...
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
@overload
|
| 362 |
+
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
|
| 363 |
+
...
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def is_attention_free(
|
| 367 |
+
model: Union[Type[object], object]
|
| 368 |
+
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
|
| 369 |
+
if isinstance(model, type):
|
| 370 |
+
return isinstance(model, _IsAttentionFreeType)
|
| 371 |
+
|
| 372 |
+
return isinstance(model, IsAttentionFree)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@runtime_checkable
|
| 376 |
+
class IsHybrid(Protocol):
|
| 377 |
+
"""The interface required for all models like Jamba that have both
|
| 378 |
+
attention and mamba blocks, indicates that
|
| 379 |
+
hf_config has 'layers_block_type'"""
|
| 380 |
+
|
| 381 |
+
is_hybrid: ClassVar[Literal[True]] = True
|
| 382 |
+
"""
|
| 383 |
+
A flag that indicates this model has both mamba and attention blocks
|
| 384 |
+
, also indicates that the model's hf_config has
|
| 385 |
+
'layers_block_type' """
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@runtime_checkable
|
| 389 |
+
class _IsHybridType(Protocol):
|
| 390 |
+
is_hybrid: ClassVar[Literal[True]]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@overload
|
| 394 |
+
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
|
| 395 |
+
...
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@overload
|
| 399 |
+
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
|
| 400 |
+
...
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def is_hybrid(
|
| 404 |
+
model: Union[Type[object], object]
|
| 405 |
+
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
|
| 406 |
+
if isinstance(model, type):
|
| 407 |
+
return isinstance(model, _IsHybridType)
|
| 408 |
+
|
| 409 |
+
return isinstance(model, IsHybrid)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
@runtime_checkable
|
| 413 |
+
class SupportsCrossEncoding(Protocol):
|
| 414 |
+
"""The interface required for all models that support cross encoding."""
|
| 415 |
+
|
| 416 |
+
supports_cross_encoding: ClassVar[Literal[True]] = True
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
@overload
|
| 420 |
+
def supports_cross_encoding(
|
| 421 |
+
model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
|
| 422 |
+
...
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
@overload
|
| 426 |
+
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
|
| 427 |
+
...
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _supports_cross_encoding(
|
| 431 |
+
model: Union[Type[object], object],
|
| 432 |
+
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
| 433 |
+
|
| 434 |
+
if isinstance(model, type):
|
| 435 |
+
return isinstance(model, SupportsCrossEncoding)
|
| 436 |
+
|
| 437 |
+
return isinstance(model, SupportsCrossEncoding)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def supports_cross_encoding(
|
| 441 |
+
model: Union[Type[object], object],
|
| 442 |
+
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
| 443 |
+
return is_pooling_model(model) and _supports_cross_encoding(model)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/interfaces_base.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
|
| 4 |
+
overload, runtime_checkable)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing_extensions import TypeIs, TypeVar
|
| 9 |
+
|
| 10 |
+
from vllm.logger import init_logger
|
| 11 |
+
from vllm.utils import supports_kw
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from vllm.attention import AttentionMetadata
|
| 15 |
+
from vllm.config import VllmConfig
|
| 16 |
+
from vllm.model_executor.layers.pooler import PoolerOutput
|
| 17 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 18 |
+
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
| 19 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
# The type of hidden states
|
| 24 |
+
# Currently, T = torch.Tensor for all models except for Medusa
|
| 25 |
+
# which has T = List[torch.Tensor]
|
| 26 |
+
T = TypeVar("T", default=torch.Tensor)
|
| 27 |
+
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
|
| 28 |
+
|
| 29 |
+
# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
|
| 30 |
+
# for the base interfaces to avoid breaking OOT registration for existing models
|
| 31 |
+
# that don't inherit from the base interface classes
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@runtime_checkable
|
| 35 |
+
class VllmModel(Protocol[T_co]):
|
| 36 |
+
"""The interface required for all models in vLLM."""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
vllm_config: "VllmConfig",
|
| 41 |
+
prefix: str = "",
|
| 42 |
+
) -> None:
|
| 43 |
+
...
|
| 44 |
+
|
| 45 |
+
def forward(
|
| 46 |
+
self,
|
| 47 |
+
input_ids: torch.Tensor,
|
| 48 |
+
positions: torch.Tensor,
|
| 49 |
+
kv_caches: List[torch.Tensor],
|
| 50 |
+
attn_metadata: "AttentionMetadata",
|
| 51 |
+
) -> T_co:
|
| 52 |
+
...
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
|
| 56 |
+
model_init = model.__init__
|
| 57 |
+
return supports_kw(model_init, "vllm_config")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
|
| 61 |
+
model_forward = getattr(model, "forward", None)
|
| 62 |
+
if not callable(model_forward):
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
|
| 66 |
+
missing_kws = tuple(kw for kw in vllm_kws
|
| 67 |
+
if not supports_kw(model_forward, kw))
|
| 68 |
+
|
| 69 |
+
if missing_kws and (isinstance(model, type)
|
| 70 |
+
and issubclass(model, nn.Module)):
|
| 71 |
+
logger.warning(
|
| 72 |
+
"The model (%s) is missing "
|
| 73 |
+
"vLLM-specific keywords from its `forward` method: %s",
|
| 74 |
+
model,
|
| 75 |
+
missing_kws,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return len(missing_kws) == 0
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@overload
|
| 82 |
+
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
|
| 83 |
+
...
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@overload
|
| 87 |
+
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
|
| 88 |
+
...
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def is_vllm_model(
|
| 92 |
+
model: Union[Type[object], object],
|
| 93 |
+
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
|
| 94 |
+
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@runtime_checkable
|
| 98 |
+
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
|
| 99 |
+
"""The interface required for all generative models in vLLM."""
|
| 100 |
+
|
| 101 |
+
def compute_logits(
|
| 102 |
+
self,
|
| 103 |
+
hidden_states: T,
|
| 104 |
+
sampling_metadata: "SamplingMetadata",
|
| 105 |
+
) -> Optional[T]:
|
| 106 |
+
"""Return `None` if TP rank > 0."""
|
| 107 |
+
...
|
| 108 |
+
|
| 109 |
+
def sample(
|
| 110 |
+
self,
|
| 111 |
+
logits: T,
|
| 112 |
+
sampling_metadata: "SamplingMetadata",
|
| 113 |
+
) -> "SamplerOutput":
|
| 114 |
+
"""Only called on TP rank 0."""
|
| 115 |
+
...
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@overload
|
| 119 |
+
def is_text_generation_model(
|
| 120 |
+
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
|
| 121 |
+
...
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@overload
|
| 125 |
+
def is_text_generation_model(
|
| 126 |
+
model: object) -> TypeIs[VllmModelForTextGeneration]:
|
| 127 |
+
...
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def is_text_generation_model(
|
| 131 |
+
model: Union[Type[object], object],
|
| 132 |
+
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
|
| 133 |
+
TypeIs[VllmModelForTextGeneration]]:
|
| 134 |
+
if not is_vllm_model(model):
|
| 135 |
+
return False
|
| 136 |
+
|
| 137 |
+
if isinstance(model, type):
|
| 138 |
+
return isinstance(model, VllmModelForTextGeneration)
|
| 139 |
+
|
| 140 |
+
return isinstance(model, VllmModelForTextGeneration)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@runtime_checkable
|
| 144 |
+
class VllmModelForPooling(VllmModel[T], Protocol[T]):
|
| 145 |
+
"""The interface required for all pooling models in vLLM."""
|
| 146 |
+
|
| 147 |
+
def pooler(
|
| 148 |
+
self,
|
| 149 |
+
hidden_states: T,
|
| 150 |
+
pooling_metadata: "PoolingMetadata",
|
| 151 |
+
) -> "PoolerOutput":
|
| 152 |
+
"""Only called on TP rank 0."""
|
| 153 |
+
...
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@overload
|
| 157 |
+
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
|
| 158 |
+
...
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@overload
|
| 162 |
+
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
|
| 163 |
+
...
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def is_pooling_model(
|
| 167 |
+
model: Union[Type[object], object],
|
| 168 |
+
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
|
| 169 |
+
if not is_vllm_model(model):
|
| 170 |
+
return False
|
| 171 |
+
|
| 172 |
+
if isinstance(model, type):
|
| 173 |
+
return isinstance(model, VllmModelForPooling)
|
| 174 |
+
|
| 175 |
+
return isinstance(model, VllmModelForPooling)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/intern_vit.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# InternVL
|
| 6 |
+
# Copyright (c) 2023 OpenGVLab
|
| 7 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import Iterable, Optional, Set, Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from transformers import PretrainedConfig
|
| 16 |
+
|
| 17 |
+
from vllm.attention.layer import MultiHeadAttention
|
| 18 |
+
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
| 19 |
+
get_tensor_model_parallel_world_size,
|
| 20 |
+
split_tensor_along_last_dim,
|
| 21 |
+
tensor_model_parallel_all_gather)
|
| 22 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 23 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 24 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 25 |
+
QKVParallelLinear,
|
| 26 |
+
RowParallelLinear)
|
| 27 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 28 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 29 |
+
|
| 30 |
+
NORM2FN = {
|
| 31 |
+
'rms_norm': RMSNorm,
|
| 32 |
+
'layer_norm': nn.LayerNorm,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class InternVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with resolution-adaptive positions.

    Converts pixel input into a token sequence: a Conv2d patch projection,
    a prepended learnable [CLS] token, and learned position embeddings that
    are bicubically resampled when the input resolution differs from the
    training resolution.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable [CLS] token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patch projection (kernel == stride == patch size).
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1  # +1 for the [CLS] slot

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        # Resample the trained-resolution position grid to an H x W grid.
        src_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        grid = pos_embed.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid,
                             size=(H, W),
                             mode='bicubic',
                             align_corners=False)
        return grid.reshape(1, -1, H * W).permute(0, 2, 1).to(src_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        pos = self.position_embedding
        if self.num_patches == H * W:
            # Input matches the training resolution; no resampling needed.
            return pos
        # Keep the [CLS] slot untouched; interpolate only the patch grid.
        return torch.cat(
            [pos[:, :1, :],
             self._get_pos_embed(pos[:, 1:, :], H, W)],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        # [B, 3, H, W] -> [B, D, H/ps, W/ps]
        patches = self.patch_embedding(pixel_values.to(target_dtype))
        batch_size, _, height, width = patches.shape
        patches = patches.flatten(2).transpose(1, 2)  # [B, N, D]
        cls_tok = self.class_embedding.expand(batch_size, 1,
                                              -1).to(target_dtype)
        tokens = torch.cat([cls_tok, patches], dim=1)
        pos = self._get_position_embedding(height, width)
        return tokens + pos.to(target_dtype)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class InternVisionPatchModel(nn.Module):
    """Vision tower that applies only patch embeddings (no transformer)."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Embed pixels, or pass precomputed embeddings straight through."""
        if pixel_values is None and pixel_embeds is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            # Precomputed embeddings take precedence over raw pixels.
            hidden_states = pixel_embeds
        elif pixel_values is not None:
            if pixel_values.ndim != 4:
                raise ValueError(
                    f'wrong pixel_values size: {pixel_values.shape}')
            hidden_states = self.embeddings(pixel_values)

        return hidden_states
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                              self.tp_size)

        self.scale = self.head_dim**-0.5
        # Fused QKV projection, sharded column-wise across TP ranks.
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            num_dummy_heads + self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.qk_normalization = config.qk_normalization
        if self.qk_normalization:
            # Variance is computed over the real hidden size only
            # (var_hidden_size) even though dummy heads widen the tensor.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = RowParallelLinear(
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        # QK-norm spans the full head dimension, so gather shards across TP
        # ranks, normalize, then split back to this rank's partition.
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        qkv_out, _ = self.qkv(x)
        q, k, v = qkv_out.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        attn_out = self.attn(q, k, v)
        attn_out, _ = self.proj(attn_out)
        return attn_out
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class InternSdpaAttention(nn.Module):
|
| 216 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 217 |
+
|
| 218 |
+
def __init__(
|
| 219 |
+
self,
|
| 220 |
+
config: PretrainedConfig,
|
| 221 |
+
*,
|
| 222 |
+
num_dummy_heads: int = 0,
|
| 223 |
+
) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
|
| 226 |
+
self.config = config
|
| 227 |
+
self.embed_dim = config.hidden_size
|
| 228 |
+
self.num_heads = config.num_attention_heads
|
| 229 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 230 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 231 |
+
raise ValueError(
|
| 232 |
+
f'embed_dim must be divisible by num_heads '
|
| 233 |
+
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
| 234 |
+
f' {self.num_heads}).')
|
| 235 |
+
|
| 236 |
+
# Additional dummy heads are used to enable TP for common GPU counts.
|
| 237 |
+
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
|
| 238 |
+
|
| 239 |
+
self.scale = self.head_dim**-0.5
|
| 240 |
+
self.qkv = nn.Linear(self.embed_dim,
|
| 241 |
+
3 * self.dummy_dim,
|
| 242 |
+
bias=config.qkv_bias)
|
| 243 |
+
|
| 244 |
+
self.qk_normalization = config.qk_normalization
|
| 245 |
+
|
| 246 |
+
if self.qk_normalization:
|
| 247 |
+
self.q_norm = RMSNorm(self.dummy_dim,
|
| 248 |
+
eps=config.layer_norm_eps,
|
| 249 |
+
var_hidden_size=self.embed_dim)
|
| 250 |
+
self.k_norm = RMSNorm(self.dummy_dim,
|
| 251 |
+
eps=config.layer_norm_eps,
|
| 252 |
+
var_hidden_size=self.embed_dim)
|
| 253 |
+
|
| 254 |
+
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
|
| 255 |
+
|
| 256 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 257 |
+
B, N, C = x.shape
|
| 258 |
+
qkv = self.qkv(x)
|
| 259 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 260 |
+
|
| 261 |
+
q = q.view(B, N, self.num_heads, self.head_dim)
|
| 262 |
+
k = k.view(B, N, self.num_heads, self.head_dim)
|
| 263 |
+
v = v.view(B, N, self.num_heads, self.head_dim)
|
| 264 |
+
|
| 265 |
+
if self.qk_normalization:
|
| 266 |
+
B_, N_, H_, D_ = q.shape
|
| 267 |
+
q = self.q_norm.forward_native(q.flatten(-2,
|
| 268 |
+
-1)).view(B_, N_, H_, D_)
|
| 269 |
+
k = self.k_norm.forward_native(k.flatten(-2,
|
| 270 |
+
-1)).view(B_, N_, H_, D_)
|
| 271 |
+
q = q.transpose(1, 2)
|
| 272 |
+
k = k.transpose(1, 2)
|
| 273 |
+
v = v.transpose(1, 2)
|
| 274 |
+
|
| 275 |
+
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
| 276 |
+
x = x.transpose(1, 2).reshape(B, N, -1)
|
| 277 |
+
|
| 278 |
+
x = self.proj(x)
|
| 279 |
+
return x
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class InternMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Column-parallel up-projection followed by row-parallel
        # down-projection keeps the intermediate activation sharded.
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        up, _ = self.fc1(hidden_states)
        activated = self.activation_fn(up)
        down, _ = self.fc2(activated)
        return down
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class InternVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer with LayerScale residual weighting."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads,
                                    prefix=f"{prefix}.attn")

        self.mlp = InternMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")

        norm_cls = NORM2FN[self.norm_type]
        self.norm1 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)

        # LayerScale: learnable per-channel residual weights.
        init_scale = config.initializer_factor
        self.ls1 = nn.Parameter(init_scale * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(init_scale * torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
        prefix: str = "",
    ):
        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        total_heads = config.num_attention_heads + num_dummy_heads

        if total_heads % tp_size != 0:
            return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

        return InternParallelAttention(config,
                                       quant_config=quant_config,
                                       num_dummy_heads=num_dummy_heads,
                                       prefix=prefix)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        attn_out = self.attn(self.norm1(hidden_states))
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = hidden_states + mlp_out * self.ls2

        return hidden_states
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class InternVisionEncoder(nn.Module):
    """Sequential stack of InternVisionEncoderLayer modules."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config

        # Callers may truncate the tower (e.g. to take features from an
        # intermediate layer).
        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class InternVisionModel(nn.Module):
    """Full InternViT tower: patch/position embeddings + encoder stack."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode pixels (or precomputed embeddings) into hidden states."""
        if pixel_values is None and pixel_embeds is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            # Precomputed embeddings bypass the embedding module.
            hidden_states = pixel_embeds
        elif pixel_values is not None:
            if pixel_values.ndim != 4:
                raise ValueError(
                    f'wrong pixel_values size: {pixel_values.shape}')
            hidden_states = self.embeddings(pixel_values)

        return self.encoder(inputs_embeds=hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Straight name-for-name copy; a parameter may carry its own
        # specialized loader, otherwise fall back to the default.
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/internlm2_ve.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
|
| 9 |
+
from vllm.attention import AttentionMetadata
|
| 10 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 11 |
+
from vllm.distributed import get_pp_group
|
| 12 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 13 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 14 |
+
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
|
| 15 |
+
InternLM2ForCausalLM,
|
| 16 |
+
InternLM2MLP, InternLM2Model)
|
| 17 |
+
from vllm.sequence import IntermediateTensors
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with a separate MLP ("visual expert") that
    processes visual tokens, while text tokens use the regular MLP."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        # Two MLPs of identical shape but separate weights: text tokens go
        # through feed_forward, visual tokens through feed_forward_ve.
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.feed_forward_ve = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward_ve",
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention (fused add+norm when a residual is carried in).
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected: route tokens to the matching expert MLP.
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        if visual_token_mask is not None and visual_token_mask.any():
            # Broadcast the per-token mask over the hidden dimension.
            # NOTE(review): assumes visual_token_mask is (num_tokens, 1)
            # -- confirm against the caller.
            visual_token_mask = visual_token_mask.repeat(
                1, self.hidden_size).bool()
            text_token_mask = ~visual_token_mask
            hidden_states[visual_token_mask] = self.feed_forward_ve(
                hidden_states[visual_token_mask].reshape(
                    -1, self.hidden_size)).flatten()
            if text_token_mask.any():
                hidden_states[text_token_mask] = self.feed_forward(
                    hidden_states[text_token_mask].reshape(
                        -1, self.hidden_size)).flatten()
        else:
            # No visual tokens in this batch: single dense MLP pass.
            hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class InternLM2VEModel(InternLM2Model):
    """InternLM2 model whose decoder layers route visual tokens through a
    dedicated visual-expert MLP (see InternLM2VEDecoderLayer)."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=InternLM2VEDecoderLayer)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # First pipeline stage embeds tokens unless embeddings were
            # precomputed (e.g. multimodal merge happened upstream).
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.tok_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for idx in range(self.start_layer, self.end_layer):
            layer = self.layers[idx]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[idx - self.start_layer],
                attn_metadata,
                residual,
                visual_token_mask=visual_token_mask,
            )
        if not get_pp_group().is_last_rank:
            # Hand off activations to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    """Causal-LM wrapper that plugs in the visual-expert model variant."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Identical to InternLM2ForCausalLM except for the backbone type.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=InternLM2VEModel)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/jais.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
|
| 7 |
+
# reserved.
|
| 8 |
+
# Copyright 2023 Cerebras Systems.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
"""Inference-only Jais model compatible with HuggingFace weights."""
|
| 22 |
+
|
| 23 |
+
import math
|
| 24 |
+
from typing import Iterable, List, Optional, Set, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from torch import nn
|
| 28 |
+
|
| 29 |
+
from vllm.attention import Attention, AttentionMetadata
|
| 30 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 31 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 32 |
+
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
| 33 |
+
get_tensor_model_parallel_world_size)
|
| 34 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 35 |
+
QKVParallelLinear,
|
| 36 |
+
RowParallelLinear)
|
| 37 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 38 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 39 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 40 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 41 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 42 |
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 43 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 44 |
+
from vllm.sequence import IntermediateTensors
|
| 45 |
+
from vllm.transformers_utils.configs import JAISConfig
|
| 46 |
+
|
| 47 |
+
from .interfaces import SupportsPP
|
| 48 |
+
from .utils import (is_pp_missing_parameter,
|
| 49 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 50 |
+
maybe_prefix)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SwiGLUActivation(nn.Module):
    """SwiGLU: gate the first input elementwise by SiLU of the second."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(x2)
        return gate * x1
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _get_alibi_slopes(n):
|
| 60 |
+
|
| 61 |
+
def get_slopes_power_of_2(n):
|
| 62 |
+
start = 2**(-(2**-(math.log2(n) - 3)))
|
| 63 |
+
ratio = start
|
| 64 |
+
return [start * ratio**i for i in range(n)]
|
| 65 |
+
|
| 66 |
+
if math.log2(n).is_integer():
|
| 67 |
+
return get_slopes_power_of_2(n)
|
| 68 |
+
else:
|
| 69 |
+
closest_power_of_2 = 2**math.floor(math.log2(n))
|
| 70 |
+
return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
|
| 71 |
+
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class JAISAttention(nn.Module):
    """JAIS self-attention with ALiBi positional biases and optional
    muP-style logit scaling."""

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tp_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tp_world_size == 0
        self.num_heads = total_num_heads // tp_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # muP option: scale attention logits by d (power 1.0) rather than
        # by sqrt(d) (power 0.5); older configs spell the flag differently.
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Each TP rank keeps only the ALiBi slopes for its own heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)[head_start:head_end]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class JAISMLP(nn.Module):
    """JAIS feed-forward block; supports the SwiGLU activation variant."""

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Second up-projection acts as the SwiGLU gate; only built when the
        # config requests swiglu.
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # NOTE(review): SwiGLUActivation.forward takes two tensors; the
        # non-swiglu branch in forward() calls it with one and would raise
        # TypeError. Released JAIS configs use "swiglu", so that path looks
        # unused -- confirm before relying on it.
        self.act = SwiGLUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            up, _ = self.c_fc(hidden_states)
            hidden_states = self.act(up, gate)
        else:
            up, _ = self.c_fc(hidden_states)
            hidden_states = self.act(up)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class JAISBlock(nn.Module):
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
config: JAISConfig,
|
| 183 |
+
cache_config: Optional[CacheConfig] = None,
|
| 184 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 185 |
+
prefix: str = "",
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
hidden_size = config.hidden_size
|
| 189 |
+
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
| 190 |
+
hidden_size)
|
| 191 |
+
|
| 192 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 193 |
+
self.attn = JAISAttention(config,
|
| 194 |
+
cache_config,
|
| 195 |
+
quant_config,
|
| 196 |
+
prefix=f"{prefix}.attn")
|
| 197 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 198 |
+
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
| 199 |
+
|
| 200 |
+
def forward(
|
| 201 |
+
self,
|
| 202 |
+
hidden_states: torch.Tensor,
|
| 203 |
+
kv_cache: torch.Tensor,
|
| 204 |
+
attn_metadata: AttentionMetadata,
|
| 205 |
+
) -> torch.Tensor:
|
| 206 |
+
residual = hidden_states
|
| 207 |
+
hidden_states = self.ln_1(hidden_states)
|
| 208 |
+
attn_output = self.attn(
|
| 209 |
+
hidden_states=hidden_states,
|
| 210 |
+
kv_cache=kv_cache,
|
| 211 |
+
attn_metadata=attn_metadata,
|
| 212 |
+
)
|
| 213 |
+
# residual connection
|
| 214 |
+
hidden_states = attn_output + residual
|
| 215 |
+
|
| 216 |
+
residual = hidden_states
|
| 217 |
+
hidden_states = self.ln_2(hidden_states)
|
| 218 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
| 219 |
+
# residual connection
|
| 220 |
+
hidden_states = residual + feed_forward_hidden_states
|
| 221 |
+
return hidden_states
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@support_torch_compile
|
| 225 |
+
class JAISModel(nn.Module):
|
| 226 |
+
|
| 227 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 228 |
+
super().__init__()
|
| 229 |
+
|
| 230 |
+
config = vllm_config.model_config.hf_config
|
| 231 |
+
cache_config = vllm_config.cache_config
|
| 232 |
+
quant_config = vllm_config.quant_config
|
| 233 |
+
|
| 234 |
+
self.config = config
|
| 235 |
+
assert not config.add_cross_attention
|
| 236 |
+
assert not config.scale_attn_by_inverse_layer_idx
|
| 237 |
+
assert not config.reorder_and_upcast_attn
|
| 238 |
+
self.embed_dim = config.hidden_size
|
| 239 |
+
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
| 240 |
+
self.wpe = (nn.Embedding(config.max_position_embeddings,
|
| 241 |
+
self.embed_dim)
|
| 242 |
+
if config.position_embedding_type != "alibi" else None)
|
| 243 |
+
if hasattr(config, "embeddings_scale"):
|
| 244 |
+
self.embeddings_scale = config.embeddings_scale
|
| 245 |
+
else:
|
| 246 |
+
self.embeddings_scale = config.mup_embeddings_scale
|
| 247 |
+
|
| 248 |
+
self.start_layer, self.end_layer, self.h = make_layers(
|
| 249 |
+
config.num_hidden_layers,
|
| 250 |
+
lambda prefix: JAISBlock(config=config,
|
| 251 |
+
cache_config=cache_config,
|
| 252 |
+
quant_config=quant_config,
|
| 253 |
+
prefix=prefix),
|
| 254 |
+
prefix=f"{prefix}.h",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
| 258 |
+
self.make_empty_intermediate_tensors = (
|
| 259 |
+
make_empty_intermediate_tensors_factory(["hidden_states"],
|
| 260 |
+
config.n_embd))
|
| 261 |
+
|
| 262 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 263 |
+
return self.wte(input_ids)
|
| 264 |
+
|
| 265 |
+
def forward(
|
| 266 |
+
self,
|
| 267 |
+
input_ids: torch.Tensor,
|
| 268 |
+
position_ids: torch.Tensor,
|
| 269 |
+
kv_caches: List[torch.Tensor],
|
| 270 |
+
attn_metadata: AttentionMetadata,
|
| 271 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 272 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 273 |
+
) -> Union[IntermediateTensors, torch.Tensor]:
|
| 274 |
+
if get_pp_group().is_first_rank:
|
| 275 |
+
if inputs_embeds is None:
|
| 276 |
+
inputs_embeds = self.get_input_embeddings(input_ids)
|
| 277 |
+
if self.wpe is not None:
|
| 278 |
+
position_embeds = self.wpe(position_ids)
|
| 279 |
+
hidden_states = inputs_embeds + position_embeds
|
| 280 |
+
else:
|
| 281 |
+
hidden_states = inputs_embeds
|
| 282 |
+
hidden_states *= torch.tensor(float(self.embeddings_scale),
|
| 283 |
+
dtype=hidden_states.dtype)
|
| 284 |
+
else:
|
| 285 |
+
assert intermediate_tensors is not None
|
| 286 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 287 |
+
|
| 288 |
+
for i in range(self.start_layer, self.end_layer):
|
| 289 |
+
layer = self.h[i]
|
| 290 |
+
hidden_states = layer(hidden_states,
|
| 291 |
+
kv_caches[i - self.start_layer],
|
| 292 |
+
attn_metadata)
|
| 293 |
+
|
| 294 |
+
if not get_pp_group().is_last_rank:
|
| 295 |
+
return IntermediateTensors({"hidden_states": hidden_states})
|
| 296 |
+
|
| 297 |
+
hidden_states = self.ln_f(hidden_states)
|
| 298 |
+
return hidden_states
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class JAISLMHeadModel(nn.Module, SupportsPP):
|
| 302 |
+
|
| 303 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 304 |
+
super().__init__()
|
| 305 |
+
config = vllm_config.model_config.hf_config
|
| 306 |
+
quant_config = vllm_config.quant_config
|
| 307 |
+
self.config = config
|
| 308 |
+
self.quant_config = quant_config
|
| 309 |
+
self.transformer = JAISModel(vllm_config=vllm_config,
|
| 310 |
+
prefix=maybe_prefix(
|
| 311 |
+
prefix, "transformer"))
|
| 312 |
+
if self.config.tie_word_embeddings:
|
| 313 |
+
self.lm_head = self.transformer.wte
|
| 314 |
+
else:
|
| 315 |
+
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
| 316 |
+
self.config.hidden_size)
|
| 317 |
+
if hasattr(config, "width_scale"):
|
| 318 |
+
self.output_logits_scale = config.width_scale
|
| 319 |
+
else:
|
| 320 |
+
self.output_logits_scale = (config.mup_output_alpha *
|
| 321 |
+
config.mup_width_scale)
|
| 322 |
+
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
|
| 323 |
+
scale=self.output_logits_scale)
|
| 324 |
+
self.sampler = get_sampler()
|
| 325 |
+
self.make_empty_intermediate_tensors = (
|
| 326 |
+
self.transformer.make_empty_intermediate_tensors)
|
| 327 |
+
|
| 328 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 329 |
+
return self.transformer.get_input_embeddings(input_ids)
|
| 330 |
+
|
| 331 |
+
def forward(
|
| 332 |
+
self,
|
| 333 |
+
input_ids: torch.Tensor,
|
| 334 |
+
positions: torch.Tensor,
|
| 335 |
+
kv_caches: List[torch.Tensor],
|
| 336 |
+
attn_metadata: AttentionMetadata,
|
| 337 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 338 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 339 |
+
) -> Union[IntermediateTensors, torch.Tensor]:
|
| 340 |
+
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
| 341 |
+
attn_metadata, intermediate_tensors,
|
| 342 |
+
inputs_embeds)
|
| 343 |
+
return hidden_states
|
| 344 |
+
|
| 345 |
+
def compute_logits(
|
| 346 |
+
self,
|
| 347 |
+
hidden_states: torch.Tensor,
|
| 348 |
+
sampling_metadata: SamplingMetadata,
|
| 349 |
+
) -> Optional[torch.Tensor]:
|
| 350 |
+
logits = self.logits_processor(self.lm_head, hidden_states,
|
| 351 |
+
sampling_metadata)
|
| 352 |
+
return logits
|
| 353 |
+
|
| 354 |
+
def sample(
|
| 355 |
+
self,
|
| 356 |
+
logits: torch.Tensor,
|
| 357 |
+
sampling_metadata: SamplingMetadata,
|
| 358 |
+
) -> Optional[SamplerOutput]:
|
| 359 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
| 360 |
+
return next_tokens
|
| 361 |
+
|
| 362 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 363 |
+
torch.Tensor]]) -> Set[str]:
|
| 364 |
+
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
| 365 |
+
loaded_params: Set[str] = set()
|
| 366 |
+
for name, loaded_weight in weights:
|
| 367 |
+
if "lm_head.weight" in name:
|
| 368 |
+
# GPT-2 ties the weights of the embedding layer and the final
|
| 369 |
+
# linear layer.
|
| 370 |
+
continue
|
| 371 |
+
if ".attn.bias" in name or ".attn.masked_bias" in name:
|
| 372 |
+
# Skip attention mask.
|
| 373 |
+
# NOTE: "c_attn.bias" should not be skipped.
|
| 374 |
+
continue
|
| 375 |
+
if "relative_pe" in name:
|
| 376 |
+
continue
|
| 377 |
+
if not name.startswith("transformer."):
|
| 378 |
+
name = "transformer." + name
|
| 379 |
+
|
| 380 |
+
if is_pp_missing_parameter(name, self):
|
| 381 |
+
continue
|
| 382 |
+
|
| 383 |
+
param = params_dict[name]
|
| 384 |
+
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
| 385 |
+
# Because of this, we need to transpose the weights.
|
| 386 |
+
# Note(zhuohan): the logic below might break quantized models.
|
| 387 |
+
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
| 388 |
+
if conv1d_weight_name not in name:
|
| 389 |
+
continue
|
| 390 |
+
if not name.endswith(".weight"):
|
| 391 |
+
continue
|
| 392 |
+
loaded_weight = loaded_weight.t()
|
| 393 |
+
weight_loader = getattr(param, "weight_loader",
|
| 394 |
+
default_weight_loader)
|
| 395 |
+
weight_loader(param, loaded_weight)
|
| 396 |
+
loaded_params.add(name)
|
| 397 |
+
return loaded_params
|
.venv/lib/python3.11/site-packages/vllm/model_executor/models/llava_next_video.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import cached_property
|
| 5 |
+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
| 6 |
+
TypedDict, Union)
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from transformers import (BatchFeature, LlavaNextVideoConfig,
|
| 11 |
+
LlavaNextVideoProcessor)
|
| 12 |
+
|
| 13 |
+
from vllm.attention import AttentionMetadata
|
| 14 |
+
from vllm.config import VllmConfig
|
| 15 |
+
from vllm.model_executor.layers.activation import get_act_fn
|
| 16 |
+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
| 17 |
+
from vllm.model_executor.models.clip import CLIPVisionModel
|
| 18 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 19 |
+
from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 20 |
+
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
| 21 |
+
NestedTensors)
|
| 22 |
+
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
|
| 23 |
+
VideoEmbeddingItems, VideoProcessorItems)
|
| 24 |
+
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
| 25 |
+
BaseProcessingInfo, PromptReplacement)
|
| 26 |
+
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
| 27 |
+
from vllm.sequence import IntermediateTensors
|
| 28 |
+
from vllm.utils import is_list_of
|
| 29 |
+
|
| 30 |
+
from .interfaces import SupportsMultiModal, SupportsPP
|
| 31 |
+
from .llava import init_vision_tower_for_llava
|
| 32 |
+
from .siglip import SiglipVisionModel
|
| 33 |
+
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
| 34 |
+
maybe_prefix, merge_multimodal_embeddings)
|
| 35 |
+
from .vision import get_vision_encoder_info
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LlavaNextVideoPixelInputs(TypedDict):
|
| 39 |
+
type: Literal["pixel_values_videos"]
|
| 40 |
+
data: Union[torch.Tensor, List[torch.Tensor]]
|
| 41 |
+
"""
|
| 42 |
+
Shape: `(batch_size, num_frames, num_channels, height, width)`
|
| 43 |
+
|
| 44 |
+
Note that `num_frames` may be different for each batch, in which case
|
| 45 |
+
the data is passed as a list instead of a batched tensor.
|
| 46 |
+
|
| 47 |
+
Note that it only supports one video input for one batch.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
|
| 52 |
+
|
| 53 |
+
def get_hf_config(self):
|
| 54 |
+
return self.ctx.get_hf_config(LlavaNextVideoConfig)
|
| 55 |
+
|
| 56 |
+
def get_vision_encoder_info(self):
|
| 57 |
+
return get_vision_encoder_info(self.get_hf_config())
|
| 58 |
+
|
| 59 |
+
def get_hf_processor(self):
|
| 60 |
+
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
|
| 61 |
+
|
| 62 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 63 |
+
return {"video": 1}
|
| 64 |
+
|
| 65 |
+
def get_mm_max_tokens_per_item(
|
| 66 |
+
self,
|
| 67 |
+
seq_len: int,
|
| 68 |
+
mm_counts: Mapping[str, int],
|
| 69 |
+
) -> Mapping[str, int]:
|
| 70 |
+
target_width, target_height = self.get_image_size_with_most_features()
|
| 71 |
+
|
| 72 |
+
max_video_tokens = self.get_num_video_tokens(
|
| 73 |
+
image_width=target_width,
|
| 74 |
+
image_height=target_height,
|
| 75 |
+
num_frames=self.get_num_frames_with_most_features(seq_len),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return {"video": max_video_tokens}
|
| 79 |
+
|
| 80 |
+
def get_image_size_with_most_features(self) -> ImageSize:
|
| 81 |
+
vision_encoder_info = self.get_vision_encoder_info()
|
| 82 |
+
width = height = vision_encoder_info.get_image_size()
|
| 83 |
+
return ImageSize(width=width, height=height)
|
| 84 |
+
|
| 85 |
+
def _get_num_frame_tokens(
|
| 86 |
+
self,
|
| 87 |
+
*,
|
| 88 |
+
image_width: int,
|
| 89 |
+
image_height: int,
|
| 90 |
+
) -> int:
|
| 91 |
+
hf_config = self.get_hf_config()
|
| 92 |
+
spatial_pool_stride = hf_config.spatial_pool_stride
|
| 93 |
+
|
| 94 |
+
vision_encoder_info = self.get_vision_encoder_info()
|
| 95 |
+
patch_grid_length = vision_encoder_info.get_patch_grid_length()
|
| 96 |
+
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
|
| 97 |
+
|
| 98 |
+
return pooled_grid_length * pooled_grid_length
|
| 99 |
+
|
| 100 |
+
def get_num_video_tokens(
|
| 101 |
+
self,
|
| 102 |
+
*,
|
| 103 |
+
image_width: int,
|
| 104 |
+
image_height: int,
|
| 105 |
+
num_frames: int,
|
| 106 |
+
) -> int:
|
| 107 |
+
num_frame_tokens = self._get_num_frame_tokens(
|
| 108 |
+
image_width=image_width,
|
| 109 |
+
image_height=image_height,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
return num_frame_tokens * num_frames
|
| 113 |
+
|
| 114 |
+
def _get_max_video_frames(self, max_tokens: int) -> int:
|
| 115 |
+
target_width, target_height = self.get_image_size_with_most_features()
|
| 116 |
+
|
| 117 |
+
num_frames = 0
|
| 118 |
+
|
| 119 |
+
while True:
|
| 120 |
+
next_num_frames = num_frames + 1
|
| 121 |
+
next_max_tokens = self.get_num_video_tokens(
|
| 122 |
+
image_width=target_width,
|
| 123 |
+
image_height=target_height,
|
| 124 |
+
num_frames=next_num_frames,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
if next_max_tokens > max_tokens:
|
| 128 |
+
break
|
| 129 |
+
|
| 130 |
+
num_frames = next_num_frames
|
| 131 |
+
|
| 132 |
+
return num_frames
|
| 133 |
+
|
| 134 |
+
def get_num_frames_with_most_features(self, seq_len: int) -> int:
|
| 135 |
+
mm_config = self.ctx.get_mm_config()
|
| 136 |
+
max_videos = mm_config.limit_per_prompt.get("video", 1)
|
| 137 |
+
|
| 138 |
+
max_total_frames = self._get_max_video_frames(seq_len)
|
| 139 |
+
|
| 140 |
+
return max(max_total_frames // max(max_videos, 1), 1)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class LlavaNextVideoDummyInputsBuilder(
|
| 144 |
+
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
|
| 145 |
+
|
| 146 |
+
def get_dummy_processor_inputs(
|
| 147 |
+
self,
|
| 148 |
+
seq_len: int,
|
| 149 |
+
mm_counts: Mapping[str, int],
|
| 150 |
+
) -> ProcessorInputs:
|
| 151 |
+
num_videos = mm_counts.get("video", 0)
|
| 152 |
+
|
| 153 |
+
processor = self.info.get_hf_processor()
|
| 154 |
+
video_token = processor.video_token
|
| 155 |
+
|
| 156 |
+
target_width, target_height = \
|
| 157 |
+
self.info.get_image_size_with_most_features()
|
| 158 |
+
target_num_frames = \
|
| 159 |
+
self.info.get_num_frames_with_most_features(seq_len)
|
| 160 |
+
|
| 161 |
+
mm_data = {
|
| 162 |
+
"video":
|
| 163 |
+
self._get_dummy_videos(
|
| 164 |
+
width=target_width,
|
| 165 |
+
height=target_height,
|
| 166 |
+
num_frames=target_num_frames,
|
| 167 |
+
num_videos=num_videos,
|
| 168 |
+
)
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
return ProcessorInputs(
|
| 172 |
+
prompt_text=video_token * num_videos,
|
| 173 |
+
mm_data=mm_data,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class LlavaNextVideoMultiModalProcessor(
|
| 178 |
+
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
|
| 179 |
+
|
| 180 |
+
def _get_mm_fields_config(
|
| 181 |
+
self,
|
| 182 |
+
hf_inputs: BatchFeature,
|
| 183 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 184 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 185 |
+
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
|
| 186 |
+
|
| 187 |
+
def _get_prompt_replacements(
|
| 188 |
+
self,
|
| 189 |
+
mm_items: MultiModalDataItems,
|
| 190 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 191 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 192 |
+
) -> list[PromptReplacement]:
|
| 193 |
+
hf_config = self.info.get_hf_config()
|
| 194 |
+
video_token_id = hf_config.video_token_index
|
| 195 |
+
|
| 196 |
+
def get_replacement(item_idx: int):
|
| 197 |
+
videos = mm_items.get_items(
|
| 198 |
+
"video", (VideoEmbeddingItems, VideoProcessorItems))
|
| 199 |
+
|
| 200 |
+
if isinstance(videos, VideoEmbeddingItems):
|
| 201 |
+
num_video_tokens = videos.get_feature_size(item_idx)
|
| 202 |
+
else:
|
| 203 |
+
image_size = videos.get_frame_size(item_idx)
|
| 204 |
+
num_video_tokens = self.info.get_num_video_tokens(
|
| 205 |
+
image_width=image_size.width,
|
| 206 |
+
image_height=image_size.height,
|
| 207 |
+
num_frames=videos.get_num_frames(item_idx),
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
return [video_token_id] * num_video_tokens
|
| 211 |
+
|
| 212 |
+
return [
|
| 213 |
+
PromptReplacement(
|
| 214 |
+
modality="video",
|
| 215 |
+
target=[video_token_id],
|
| 216 |
+
replacement=get_replacement,
|
| 217 |
+
),
|
| 218 |
+
]
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# adopted from transformers modeling_llava_next_video.py
|
| 222 |
+
class LlavaNextVideoPooler(nn.Module):
|
| 223 |
+
|
| 224 |
+
def __init__(self, config: LlavaNextVideoConfig):
|
| 225 |
+
super().__init__()
|
| 226 |
+
|
| 227 |
+
mode = config.spatial_pool_mode
|
| 228 |
+
stride = config.spatial_pool_stride
|
| 229 |
+
image_size = config.vision_config.image_size
|
| 230 |
+
patch_size = config.vision_config.patch_size
|
| 231 |
+
self.image_size = image_size // patch_size**2
|
| 232 |
+
|
| 233 |
+
if mode == "average":
|
| 234 |
+
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
| 235 |
+
elif mode == "max":
|
| 236 |
+
self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
| 237 |
+
else:
|
| 238 |
+
# TODO: Support Conv2d pooling layer, need to load weights
|
| 239 |
+
raise ValueError(
|
| 240 |
+
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
|
| 241 |
+
|
| 242 |
+
def forward(self, image_features: torch.Tensor):
|
| 243 |
+
ori_width = int(
|
| 244 |
+
math.sqrt(image_features.shape[1] * self.image_size //
|
| 245 |
+
self.image_size))
|
| 246 |
+
ori_height = int(ori_width * self.image_size // self.image_size)
|
| 247 |
+
|
| 248 |
+
batch_size, _, dim = image_features.shape
|
| 249 |
+
image_features_spatial = image_features \
|
| 250 |
+
.view(batch_size, ori_height, ori_height, dim) \
|
| 251 |
+
.permute(0, 3, 1, 2)
|
| 252 |
+
image_features_spatial = self.pool(image_features_spatial)
|
| 253 |
+
|
| 254 |
+
return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class LlavaNextMultiModalProjector(nn.Module):
|
| 258 |
+
|
| 259 |
+
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
|
| 260 |
+
projector_hidden_act: str, multimodal_projector_bias: bool):
|
| 261 |
+
super().__init__()
|
| 262 |
+
|
| 263 |
+
self.linear_1 = nn.Linear(vision_hidden_size,
|
| 264 |
+
text_hidden_size,
|
| 265 |
+
bias=multimodal_projector_bias)
|
| 266 |
+
self.act = get_act_fn(projector_hidden_act)
|
| 267 |
+
self.linear_2 = nn.Linear(text_hidden_size,
|
| 268 |
+
text_hidden_size,
|
| 269 |
+
bias=multimodal_projector_bias)
|
| 270 |
+
|
| 271 |
+
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
| 272 |
+
hidden_states = self.linear_1(image_features)
|
| 273 |
+
hidden_states = self.act(hidden_states)
|
| 274 |
+
hidden_states = self.linear_2(hidden_states)
|
| 275 |
+
return hidden_states
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@MULTIMODAL_REGISTRY.register_processor(
|
| 279 |
+
LlavaNextVideoMultiModalProcessor,
|
| 280 |
+
info=LlavaNextVideoProcessingInfo,
|
| 281 |
+
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
|
| 282 |
+
)
|
| 283 |
+
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
| 284 |
+
SupportsPP):
|
| 285 |
+
|
| 286 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
| 287 |
+
super().__init__()
|
| 288 |
+
config = vllm_config.model_config.hf_config
|
| 289 |
+
quant_config = vllm_config.quant_config
|
| 290 |
+
multimodal_config = vllm_config.model_config.multimodal_config
|
| 291 |
+
|
| 292 |
+
self.config = config
|
| 293 |
+
self.multimodal_config = multimodal_config
|
| 294 |
+
|
| 295 |
+
# Initialize the vision tower only up to the required feature layer
|
| 296 |
+
self.vision_tower = init_vision_tower_for_llava(
|
| 297 |
+
config,
|
| 298 |
+
quant_config,
|
| 299 |
+
require_post_norm=False,
|
| 300 |
+
prefix=maybe_prefix(prefix, "vision_tower"))
|
| 301 |
+
self.vision_resampler = LlavaNextVideoPooler(config)
|
| 302 |
+
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
| 303 |
+
vision_hidden_size=config.vision_config.hidden_size,
|
| 304 |
+
text_hidden_size=config.text_config.hidden_size,
|
| 305 |
+
projector_hidden_act=config.projector_hidden_act,
|
| 306 |
+
multimodal_projector_bias=config.multimodal_projector_bias)
|
| 307 |
+
self.language_model = init_vllm_registered_model(
|
| 308 |
+
vllm_config=vllm_config,
|
| 309 |
+
hf_config=config.text_config,
|
| 310 |
+
prefix=maybe_prefix(prefix, "language_model"),
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
self.make_empty_intermediate_tensors = (
|
| 314 |
+
self.language_model.model.make_empty_intermediate_tensors)
|
| 315 |
+
|
| 316 |
+
@cached_property
|
| 317 |
+
def sampler(self):
|
| 318 |
+
if hasattr(self.language_model, "sampler"):
|
| 319 |
+
return self.language_model.sampler
|
| 320 |
+
|
| 321 |
+
return get_sampler()
|
| 322 |
+
|
| 323 |
+
def _validate_video_pixel_values(
|
| 324 |
+
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
| 325 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
| 326 |
+
|
| 327 |
+
h = w = self.config.vision_config.image_size
|
| 328 |
+
expected_dims = (3, h, w)
|
| 329 |
+
|
| 330 |
+
def _validate_shape(d: torch.Tensor):
|
| 331 |
+
actual_dims = tuple(d.shape[2:])
|
| 332 |
+
|
| 333 |
+
if actual_dims != expected_dims:
|
| 334 |
+
expected_expr = ("num_frames", *map(str, expected_dims))
|
| 335 |
+
raise ValueError(
|
| 336 |
+
"The expected shape of pixel values in each video frame "
|
| 337 |
+
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
| 338 |
+
|
| 339 |
+
for d in data:
|
| 340 |
+
_validate_shape(d)
|
| 341 |
+
|
| 342 |
+
return data
|
| 343 |
+
|
| 344 |
+
def _parse_and_validate_video_input(
|
| 345 |
+
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
|
| 346 |
+
"""
|
| 347 |
+
A legal video input should have the following dimensions:
|
| 348 |
+
{
|
| 349 |
+
"pixel_values_videos" :
|
| 350 |
+
List[b, Tensor(nb_frames, nb_channels, height, width)]
|
| 351 |
+
}
|
| 352 |
+
"""
|
| 353 |
+
pixel_values = kwargs.pop("pixel_values_videos", None)
|
| 354 |
+
|
| 355 |
+
if pixel_values is None:
|
| 356 |
+
return None
|
| 357 |
+
|
| 358 |
+
if not (is_list_of(pixel_values,
|
| 359 |
+
(torch.Tensor)) # different shape videos
|
| 360 |
+
or isinstance(pixel_values,
|
| 361 |
+
torch.Tensor)): # same shape videos
|
| 362 |
+
raise ValueError("Incorrect type of pixel values. "
|
| 363 |
+
f"Got type: {type(pixel_values)}")
|
| 364 |
+
|
| 365 |
+
return LlavaNextVideoPixelInputs(
|
| 366 |
+
type="pixel_values_videos",
|
| 367 |
+
data=pixel_values,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
def _select_image_features(self, image_features: torch.Tensor, *,
|
| 371 |
+
strategy: str) -> torch.Tensor:
|
| 372 |
+
if strategy == "default":
|
| 373 |
+
return image_features[:, 1:]
|
| 374 |
+
elif strategy == "full":
|
| 375 |
+
return image_features
|
| 376 |
+
|
| 377 |
+
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
| 378 |
+
|
| 379 |
+
def _video_pixels_to_features(
|
| 380 |
+
self,
|
| 381 |
+
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
|
| 382 |
+
pixel_values: torch.Tensor,
|
| 383 |
+
) -> torch.Tensor:
|
| 384 |
+
|
| 385 |
+
# NOTE: we skip the step to select the vision feature layer since
|
| 386 |
+
# this is already done inside the vision tower
|
| 387 |
+
image_features = vision_tower(pixel_values)
|
| 388 |
+
image_features = self._select_image_features(
|
| 389 |
+
image_features,
|
| 390 |
+
strategy=self.config.vision_feature_select_strategy,
|
| 391 |
+
)
|
| 392 |
+
image_features = self.vision_resampler(image_features)
|
| 393 |
+
image_features = self.multi_modal_projector(image_features)
|
| 394 |
+
return image_features
|
| 395 |
+
|
| 396 |
+
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
|
| 397 |
+
assert self.vision_tower is not None
|
| 398 |
+
|
| 399 |
+
video_pixels = inputs["data"]
|
| 400 |
+
|
| 401 |
+
if isinstance(video_pixels, torch.Tensor):
|
| 402 |
+
# TODO: support multiple videos per input
|
| 403 |
+
b, num_videos, num_frames, c, h, w = video_pixels.shape
|
| 404 |
+
assert (num_videos == 1)
|
| 405 |
+
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
|
| 406 |
+
h, w)
|
| 407 |
+
stacked_embeddings = self._video_pixels_to_features(
|
| 408 |
+
self.vision_tower, stacked_pixels)
|
| 409 |
+
return stacked_embeddings.view(b, num_frames,
|
| 410 |
+
*stacked_embeddings.shape[1:])
|
| 411 |
+
|
| 412 |
+
elif is_list_of(video_pixels, torch.Tensor):
|
| 413 |
+
frames_per_videos = [v.shape[0] for v in video_pixels]
|
| 414 |
+
stacked_pixels = torch.cat(video_pixels, dim=0)
|
| 415 |
+
stacked_embeddings = self._video_pixels_to_features(
|
| 416 |
+
self.vision_tower, stacked_pixels)
|
| 417 |
+
return torch.split(stacked_embeddings, frames_per_videos, dim=0)
|
| 418 |
+
|
| 419 |
+
else:
|
| 420 |
+
raise ValueError(
|
| 421 |
+
f"Unsupported type of video input {type(video_pixels)}")
|
| 422 |
+
|
| 423 |
+
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
| 424 |
+
video_input = self._parse_and_validate_video_input(**kwargs)
|
| 425 |
+
if video_input is None:
|
| 426 |
+
return None
|
| 427 |
+
vision_embeddings = self._process_video_pixels(video_input)
|
| 428 |
+
return vision_embeddings
|
| 429 |
+
|
| 430 |
+
def get_input_embeddings(
|
| 431 |
+
self,
|
| 432 |
+
input_ids: torch.Tensor,
|
| 433 |
+
multimodal_embeddings: Optional[NestedTensors] = None,
|
| 434 |
+
) -> torch.Tensor:
|
| 435 |
+
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
| 436 |
+
if multimodal_embeddings is not None:
|
| 437 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 438 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 439 |
+
self.config.video_token_index)
|
| 440 |
+
return inputs_embeds
|
| 441 |
+
|
| 442 |
+
def forward(
|
| 443 |
+
self,
|
| 444 |
+
input_ids: torch.Tensor,
|
| 445 |
+
positions: torch.Tensor,
|
| 446 |
+
kv_caches: List[torch.Tensor],
|
| 447 |
+
attn_metadata: AttentionMetadata,
|
| 448 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 449 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 450 |
+
**kwargs: object,
|
| 451 |
+
) -> Union[torch.Tensor, IntermediateTensors]:
|
| 452 |
+
"""Run forward pass for LlaVA-NeXT-Video.
|
| 453 |
+
Args:
|
| 454 |
+
input_ids: Flattened (concatenated) input_ids corresponding to a
|
| 455 |
+
batch.
|
| 456 |
+
pixel_values_videos: Pixels in each frames for each input videos.
|
| 457 |
+
"""
|
| 458 |
+
if intermediate_tensors is not None:
|
| 459 |
+
inputs_embeds = None
|
| 460 |
+
|
| 461 |
+
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
| 462 |
+
# condition is for v0 compatibility.
|
| 463 |
+
elif inputs_embeds is None:
|
| 464 |
+
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
| 465 |
+
inputs_embeds = self.get_input_embeddings(input_ids,
|
| 466 |
+
vision_embeddings)
|
| 467 |
+
input_ids = None
|
| 468 |
+
|
| 469 |
+
hidden_states = self.language_model.model(input_ids,
|
| 470 |
+
positions,
|
| 471 |
+
kv_caches,
|
| 472 |
+
attn_metadata,
|
| 473 |
+
intermediate_tensors,
|
| 474 |
+
inputs_embeds=inputs_embeds)
|
| 475 |
+
|
| 476 |
+
return hidden_states
|
| 477 |
+
|
| 478 |
+
def compute_logits(
|
| 479 |
+
self,
|
| 480 |
+
hidden_states: torch.Tensor,
|
| 481 |
+
sampling_metadata: SamplingMetadata,
|
| 482 |
+
) -> Optional[torch.Tensor]:
|
| 483 |
+
return self.language_model.compute_logits(hidden_states,
|
| 484 |
+
sampling_metadata)
|
| 485 |
+
|
| 486 |
+
def sample(
|
| 487 |
+
self,
|
| 488 |
+
logits: torch.Tensor,
|
| 489 |
+
sampling_metadata: SamplingMetadata,
|
| 490 |
+
) -> Optional[SamplerOutput]:
|
| 491 |
+
return self.language_model.sample(logits, sampling_metadata)
|
| 492 |
+
|
| 493 |
+
def load_weights(self, weights: Iterable[Tuple[str,
|
| 494 |
+
torch.Tensor]]) -> Set[str]:
|
| 495 |
+
loader = AutoWeightsLoader(
|
| 496 |
+
self,
|
| 497 |
+
# This model doesn't support images for now
|
| 498 |
+
ignore_unexpected_prefixes=["image_newline"],
|
| 499 |
+
)
|
| 500 |
+
return loader.load_weights(weights)
|