|
|
|
|
|
|
|
|
import math |
|
|
from collections.abc import Iterable, Mapping, Sequence |
|
|
from typing import Any, Literal, Optional, TypedDict |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import BatchFeature, Gemma3Config, Gemma3Processor |
|
|
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs |
|
|
|
|
|
import vllm.envs as envs |
|
|
from vllm.config import VllmConfig |
|
|
from vllm.logger import init_logger |
|
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm |
|
|
from vllm.model_executor.models.module_mapping import MultiModelKeys |
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata |
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY |
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, |
|
|
MultiModalKwargs) |
|
|
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, |
|
|
MultiModalDataItems) |
|
|
|
|
|
from vllm.multimodal.processing import (BaseMultiModalProcessor, |
|
|
BaseProcessingInfo, BoundPromptUpdate, |
|
|
PlaceholderFeaturesInfo, |
|
|
PromptReplacement, PromptTargetMatch, |
|
|
PromptUpdate, PromptUpdateDetails, |
|
|
find_mm_placeholders, |
|
|
replace_token_matches) |
|
|
|
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder |
|
|
from vllm.sequence import IntermediateTensors |
|
|
|
|
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, |
|
|
SupportsMultiModal, SupportsPP) |
|
|
from .siglip import SiglipVisionModel |
|
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, |
|
|
maybe_prefix, merge_multimodal_embeddings) |
|
|
|
|
|
# Module-level logger named after this module's import path.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
class Gemma3ImagePixelInputs(TypedDict):
    """Pixel-value image inputs for Gemma3, flattened across the batch."""

    # Discriminator for the input kind; always "pixel_values".
    type: Literal["pixel_values"]

    pixel_values: torch.Tensor
    """
    Shape: `(num_patches_total, num_channels, height, width)`

    `num_patches_total` is the total number of patches
    over each image over each prompt in the batch.
    """

    # Number of patches contributed by each image (base image + crops).
    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""


# Pixel values are the only supported image input kind for Gemma3
# (image_embeds are explicitly rejected during input parsing).
Gemma3ImageInputs = Gemma3ImagePixelInputs
|
|
|
|
|
|
|
|
class Gemma3ProcessingInfo(BaseProcessingInfo):
    """Provides HF config/processor access and image token accounting for
    Gemma3, including pan-and-scan crop estimation."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Gemma3Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None` means no hard cap on the number of images per prompt.
        return {"image": None}

    def _resolve_image_kwargs(
        self,
        processor: Gemma3Processor,
        keys: set[str],
    ) -> dict[str, Any]:
        """Resolve each requested image-processing kwarg.

        Prefers the value set directly on the image processor, falling
        back to the processor's merged default ``images_kwargs``.
        """
        image_processor = processor.image_processor
        kwargs = processor._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
        )

        images_kwargs = kwargs["images_kwargs"]

        def _resolve_kw(key: str):
            val = getattr(image_processor, key)
            if val is None:
                val = images_kwargs[key]

            return val

        return {k: _resolve_kw(k) for k in keys}

    def get_num_crops(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> int:
        """Return the number of pan-and-scan crops for an image of the
        given size (0 when pan-and-scan is disabled or not triggered).

        NOTE(review): this appears to mirror the crop-selection logic of
        the HF Gemma3 image processor so that token counts stay in sync
        with preprocessing — keep the two in lockstep.
        """
        if processor is None:
            processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {
                "do_pan_and_scan", "pan_and_scan_min_crop_size",
                "pan_and_scan_max_num_crops",
                "pan_and_scan_min_ratio_to_activate"
            })

        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
        pan_and_scan_min_crop_size = images_kwargs[
            "pan_and_scan_min_crop_size"]
        pan_and_scan_max_num_crops = images_kwargs[
            "pan_and_scan_max_num_crops"]
        pan_and_scan_min_ratio_to_activate = images_kwargs[
            "pan_and_scan_min_ratio_to_activate"]

        if not do_pan_and_scan:
            return 0

        if envs.VLLM_USE_V1:
            logger.warning_once(
                "`do_pan_and_scan=True` has suboptimal results on V1 "
                "because of the simplified attention pattern being used.")

        # Landscape (or square) images are split along the width,
        # portrait images along the height.
        if image_width >= image_height:
            # Aspect ratio below the activation threshold: no crops.
            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
                return 0

            # Crop count limited by both the minimum crop size and the
            # (rounded) aspect ratio.
            num_crops_w = min(
                int(math.floor(image_width / pan_and_scan_min_crop_size)),
                int(math.floor(image_width / image_height + 0.5)),
            )

            # At least 2 crops, capped at the configured maximum.
            num_crops_w = max(2, num_crops_w)
            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
            num_crops_h = 1
        else:
            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_h = min(
                int(math.floor(image_height / pan_and_scan_min_crop_size)),
                int(math.floor(image_height / image_width + 0.5)),
            )

            num_crops_h = max(2, num_crops_h)
            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
            num_crops_w = 1

        crop_size_w = int(math.ceil(image_width / num_crops_w))
        crop_size_h = int(math.ceil(image_height / num_crops_h))

        # If the resulting crops would be smaller than the minimum crop
        # size, pan-and-scan is not applied at all.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return 0

        return num_crops_w * num_crops_h

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> PromptUpdateDetails[str]:
        """Build the prompt-replacement text for one image.

        With no crops this is just the full image-token sequence; with
        pan-and-scan crops it is a sentence embedding one sequence for
        the base image plus one per crop.
        """
        if processor is None:
            processor = self.get_hf_processor()

        boi_token = processor.boi_token

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        if num_crops == 0:
            image_text = boi_token
        else:
            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
            image_text = (
                f"Here is the original image {boi_token} and here are some "
                f"crops to help you see better {crops_image_tokens}")

        # Expand every boi token into the full per-image token sequence.
        repl_full = image_text.replace(boi_token,
                                       processor.full_image_sequence)

        tokenizer = processor.tokenizer
        vocab = tokenizer.get_vocab()
        image_token_id = vocab[tokenizer.image_token]

        # Only the image tokens themselves are embedding positions;
        # surrounding text is kept as ordinary prompt tokens.
        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> int:
        """Total image tokens for one image: one full sequence for the
        base image plus one per pan-and-scan crop."""
        if processor is None:
            processor = self.get_hf_processor()

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        image_seq_len = processor.image_seq_length

        return (num_crops + 1) * image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that maximizes image-token count."""
        processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {"pan_and_scan_max_num_crops"})
        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]

        # An extreme (tall, thin) aspect ratio triggers the maximum
        # number of pan-and-scan crops.
        return ImageSize(height=50 * max_num_crops, width=50)
|
|
|
|
|
|
|
|
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
    """Builds dummy text and image data for Gemma3 profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One <start_of_image> placeholder per requested dummy image.
        image_count = mm_counts.get("image", 0)
        boi = self.info.get_hf_processor().boi_token
        return boi * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)

        # Profile with the image size that yields the most image tokens
        # so memory usage reflects the worst case.
        max_size = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=max_size.width,
                                              height=max_size.height,
                                              num_images=image_count)
        return {"image": dummy_images}
|
|
|
|
|
|
|
|
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
    """Multimodal processor for Gemma3.

    Beyond base HF processing, it records the pan-and-scan crop count
    per image and compensates for the merging of consecutive newline
    tokens around image replacements.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
        )

        # Attach a per-image crop count so `_get_mm_fields_config` can
        # split the flattened `pixel_values` into per-image groups
        # (base image + crops).
        if (images := mm_data.get("images")) is not None:
            parsed_images = (self._get_data_parser().parse_mm_data({
                "image":
                images
            }).get_items("image", ImageProcessorItems))
            image_sizes = [
                parsed_images.get_image_size(i)
                for i in range(len(parsed_images))
            ]
            hf_processor = self.info.get_hf_processor(**mm_kwargs)

            num_crops = [
                self.info.get_num_crops(image_width=size.width,
                                        image_height=size.height,
                                        processor=hf_processor)
                for size in image_sizes
            ]
            processed_outputs["num_crops"] = torch.tensor(num_crops)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_crops = hf_inputs.get("num_crops", torch.empty(0))

        # Each image contributes `num_crops + 1` patches (its crops plus
        # the base image) to the flattened `pixel_values`.
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_crops + 1),
            num_crops=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.boi_token

        def get_replacement_gemma3(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)

            # The replacement text depends on the image's size because
            # the pan-and-scan crop count varies with aspect ratio.
            image_size = images.get_image_size(item_idx)
            return self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_gemma3,
            )
        ]

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """Apply prompt replacements, then merge adjacent newline tokens.

        After substitution, newline tokens can end up adjacent to each
        other; the tokenizer would have encoded that text as a single
        multi-newline token, so merge such pairs to match direct
        tokenization of the updated prompt.
        """
        token_ids = super()._apply_token_matches(
            prompt,
            mm_matches,
            mm_item_counts,
        )

        # The vocab has dedicated entries for runs of 1-4 newlines.
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        newline_1 = vocab["\n"]
        newline_2 = vocab["\n\n"]
        newline_3 = vocab["\n\n\n"]
        newline_4 = vocab["\n\n\n\n"]

        # "\n" + "\n\n" -> "\n\n\n"
        token_ids = replace_token_matches(
            token_ids,
            [newline_1, newline_2],
            [newline_3],
        )
        # "\n\n" + "\n" -> "\n\n\n"
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_1],
            [newline_3],
        )
        # "\n\n" + "\n\n" -> "\n\n\n\n"
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_2],
            [newline_4],
        )

        return token_ids

    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Locate placeholders in token ids that had newlines merged.

        `_apply_token_matches` collapses newline pairs into single
        tokens, which would break exact matching of the replacement
        sequences. Expand the merged tokens back, search on the expanded
        ids, then map the found start indices onto `new_token_ids`.
        """
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        newline_1 = vocab["\n"]
        newline_2 = vocab["\n\n"]
        newline_3 = vocab["\n\n\n"]
        newline_4 = vocab["\n\n\n\n"]

        def get_repl_toks(tok: int) -> list[int]:
            # Inverse (one canonical split) of the merges performed in
            # `_apply_token_matches`.
            if tok == newline_3:
                return [newline_1, newline_2]
            if tok == newline_4:
                return [newline_2, newline_2]

            return [tok]

        # Expanded token sequence plus, for every expanded position, the
        # index of the token it came from.
        repl_token_ids = list[int]()
        repl_orig_idxs = list[int]()
        for orig_idx, orig_tok in enumerate(new_token_ids):
            repl_toks = get_repl_toks(orig_tok)
            repl_token_ids.extend(repl_toks)
            repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))

        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
                                     mm_item_counts)

        # Translate start indices from the expanded sequence back to
        # positions in the merged (actual) token sequence.
        return {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=repl_orig_idxs[p.start_idx],
                    tokens=p.tokens,
                    is_embed=p.is_embed,
                ) for p in placeholders
            ]
            for modality, placeholders in repls.items()
        }
|
|
|
|
|
|
|
|
class Gemma3MultiModalProjector(nn.Module):
    """Maps SigLIP patch features into the language model's embedding
    space: average-pool the patch grid, RMS-normalize, then apply a
    learned linear projection."""

    def __init__(self, config: Gemma3Config):
        super().__init__()

        vision_hidden = config.vision_config.hidden_size

        # Learned projection from vision hidden size to text hidden size.
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_hidden, config.text_config.hidden_size))

        self.mm_soft_emb_norm = GemmaRMSNorm(
            vision_hidden, eps=config.vision_config.layer_norm_eps)

        self.patches_per_image = int(config.vision_config.image_size //
                                     config.vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        # Pooling factor so that each side of the patch grid collapses
        # to `tokens_per_side` output tokens.
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
                                     stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        batch, _, hidden = vision_outputs.shape
        side = self.patches_per_image

        # (B, patches, hidden) -> (B, hidden, side, side)
        patch_grid = (vision_outputs.transpose(1, 2).reshape(
            batch, hidden, side, side).contiguous())

        # Pool the patch grid, then restore (B, tokens, hidden).
        pooled = self.avg_pool(patch_grid).flatten(2).transpose(1, 2)

        normed = self.mm_soft_emb_norm(pooled)

        projected = torch.matmul(normed, self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)
|
|
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
                                        info=Gemma3ProcessingInfo,
                                        dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsLoRA):
    """Gemma3 vision-language model: SigLIP vision tower + projector
    feeding a Gemma3 causal language model."""

    # Mapping of fused modules to their stacked sub-projections, used by
    # LoRA / weight loading.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config
        # Window size for local attention layers; None disables the
        # local-mask construction in `prepare_attn_masks`.
        self.sliding_window = getattr(config.text_config,
                                      "interleaved_sliding_window", None)

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = Gemma3MultiModalProjector(config)

        # The text backbone is instantiated via the model registry so it
        # shares the standard Gemma3ForCausalLM implementation.
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma3ForCausalLM"],
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def dtype(self):
        # Dtype of the model parameters (used for attention masks).
        return next(self.parameters()).dtype

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that each image patch is (3, image_size, image_size)."""
        image_size = self.config.vision_config.image_size
        expected_dims = (3, image_size, image_size)
        if data.shape[1:] != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch is "
                f"{expected_dims}. You supplied {tuple(data.shape)}.")
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
        """Extract and validate image tensors from forward kwargs.

        Returns None when the request has no images. Raises ValueError
        on malformed inputs; image_embeds are explicitly unsupported.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_crops = kwargs.pop("num_crops", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Gemma3 does not support image_embeds."
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(num_crops, (torch.Tensor, list)):
            raise ValueError("Incorrect type of num_crops. "
                             f"Got type: {type(num_crops)}")

        # Flatten the batch dimension into a single patch axis.
        pixel_values = flatten_bn(pixel_values, concat=True)
        num_crops = flatten_bn(num_crops, concat=True)

        return Gemma3ImagePixelInputs(
            type="pixel_values",
            pixel_values=self._validate_pixel_values(pixel_values),
            # Each image yields its crops plus the base image.
            num_patches=num_crops + 1,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Thin wrapper around the vision tower forward pass.
        return vision_tower(pixel_values)

    def _process_image_input(
        self,
        image_input: Gemma3ImageInputs,
    ) -> list[torch.Tensor]:
        """Run vision tower + projector and regroup embeddings per image.

        Returns one flattened embedding tensor per image, covering that
        image's base patch and all of its crops.
        """
        assert self.vision_tower is not None

        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]

        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )
        image_embeds = self.multi_modal_projector(image_features)

        # Split the patch axis back per image, then flatten each image's
        # (patches, tokens) into a single token axis.
        return [
            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-image embeddings, or None if no images were given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed text tokens and scatter image embeddings into the
        positions holding the image placeholder token."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> IntermediateTensors:
        # Non-first pipeline-parallel ranks receive hidden states via
        # `intermediate_tensors` and compute no input embeddings here.
        if intermediate_tensors is not None:
            inputs_embeds = None

        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            if vision_embeddings is not None:
                # Image requests need custom attention masks (see
                # `prepare_attn_masks`), passed down via kwargs.
                kwargs = self.prepare_attn_masks(
                    input_ids,
                    positions,
                    mask_dtype=self.dtype,
                    **kwargs,
                )
            # Embeddings are supplied directly, so drop the token ids.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds,
                                                  **kwargs)

        return hidden_states

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        """Build per-sequence global and local attention masks.

        The global mask is causal except that image-token positions may
        attend to each other bidirectionally; the local mask further
        restricts attention to the sliding window. Masks are returned
        via kwargs for the language model to consume.
        """
        kwargs["has_images"] = True

        # `positions == 0` marks the first token of each sequence in the
        # flattened batch; use those boundaries to recover seq lengths.
        start_idices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_idices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_idices[i].item()
            if i < num_seqs - 1:
                end_idx = start_idices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))

            # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Unmask pairs where both query and key are image tokens.
            # NOTE(review): this covers ALL image tokens in the sequence,
            # not just those of the same image — matches the "simplified
            # attention pattern" warned about in get_num_crops.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = (input_token_ids == self.config.image_token_index)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            if self.sliding_window is not None:
                # Mask out keys more than `sliding_window` positions
                # before the query; inside the window reuse the global
                # (causal + image) mask.
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask,
                                             diagonal=-self.sliding_window)
                local_attn_mask = torch.where(local_attn_mask == 0,
                                              global_attn_mask, float("-inf"))
                local_attn_masks.append(local_attn_mask)
        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Delegated to the text backbone (logit scale applied there).
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; returns the set of loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower")
|
|
|