|
|
from typing import List, Optional, Tuple, Union |
|
|
from dataclasses import dataclass |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
LlavaNextForConditionalGeneration, |
|
|
LlavaNextModel, |
|
|
) |
|
|
from transformers.models.llava_next.modeling_llava_next import ( |
|
|
LlavaNextCausalLMOutputWithPast, |
|
|
LlavaNextPreTrainedModel, |
|
|
LlavaNextMultiModalProjector, |
|
|
get_anyres_image_grid_shape, |
|
|
image_size_to_num_patches, |
|
|
unpad_image, |
|
|
LlavaNextModelOutputWithPast |
|
|
) |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from transformers.processing_utils import Unpack |
|
|
from transformers.utils import TransformersKwargs, can_return_tuple, logging |
|
|
from accelerate import init_empty_weights |
|
|
from transformers import Blip2QFormerConfig, Blip2QFormerModel |
|
|
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
|
|
from .configuration import Granite4VisionConfig, Granite4VisionConfigNaflex |
|
|
from .downsampling import BilinearDownsampler, QFormerDownsampler, WindowQFormerDownsampler |
|
|
import math |
|
|
import numpy as np |
|
|
from fractions import Fraction |
|
|
from transformers.modeling_utils import flash_attention_forward |
|
|
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import HybridMambaAttentionDynamicCache |
|
|
IGNORE_INDEX = -100 |
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@dataclass
class Granite4VisionModelOutputWithPast(LlavaNextModelOutputWithPast):
    r"""
    Output container of the Granite4 Vision base model.

    Extends the LLaVA-NeXT base-model output with one additional optional field, `balancing_loss`.

    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        A [`~cache_utils.Cache`] instance. For more details, see the
        [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Holds pre-computed hidden-states (keys and values of the self-attention blocks) that can be fed back
        through the `past_key_values` input to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        Image hidden states produced by the vision encoder after projecting the last hidden state.
    balancing_loss (`torch.FloatTensor`, *optional*, defaults to `None`):
        Extra loss slot added by this subclass.
        NOTE(review): presumably an auxiliary balancing loss; nothing in this file populates it -- confirm
        with the producers of this output.
    """

    balancing_loss: Optional[torch.FloatTensor] = None
|
|
|
|
|
@dataclass
class Granite4VisionCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
    r"""
    Output container of the Granite4 Vision causal-LM head model.

    Extends the LLaVA-NeXT causal-LM output with one additional optional field, `balancing_loss`.

    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        A [`~cache_utils.Cache`] instance. For more details, see the
        [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Holds pre-computed hidden-states (keys and values of the self-attention blocks) that can be fed back
        through the `past_key_values` input to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        Image hidden states produced by the vision encoder after projecting the last hidden state.
    balancing_loss (`torch.FloatTensor`, *optional*, defaults to `None`):
        Extra loss slot added by this subclass; the head model copies it from the base model's output.
        NOTE(review): presumably an auxiliary balancing loss -- confirm with the producers of this output.
    """

    balancing_loss: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
class ParamWrapper(nn.Module):
    """Minimal module that holds a single value under the attribute name `param`.

    When `param` is an `nn.Parameter`, it is registered on this module, so a parent that stores the wrapper
    as attribute `x` exposes the parameter under the state-dict key `x.param`.
    """

    def __init__(self, param):
        """Store `param` on the module.

        Args:
            param: The object to wrap (an `nn.Parameter` in this file's usage).
        """
        nn.Module.__init__(self)
        # Regular attribute assignment so nn.Module's registration logic applies to Parameters.
        self.param = param
|
|
|
|
|
class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
    """Granite4 Vision model with a language-modeling head.

    Wraps a `Granite4VisionModel` (vision tower + projector + language model) and an `lm_head`. When the
    config names `pretrained_vision_tower` / `pretrained_language_model`, the corresponding weights are
    loaded at construction time and the config fields are then cleared.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        """Build the model and optionally load pretrained sub-model weights.

        Args:
            config: Model configuration. `config.vision_config` / `config.text_config` are replaced in place
                by the configs resolved from the pretrained checkpoints when those are specified.
        """
        if config.pretrained_vision_tower:
            config.vision_config = AutoConfig.from_pretrained(
                config.pretrained_vision_tower, **config.vision_config.to_dict()
            )
            # A composite checkpoint config carries the vision part under `.vision_config`; keep only that.
            config.vision_config = (
                config.vision_config.vision_config
                if hasattr(config.vision_config, "vision_config")
                else config.vision_config
            )
        if config.pretrained_language_model:
            config.text_config = AutoConfig.from_pretrained(
                config.pretrained_language_model, **config.text_config.to_dict()
            )

        # Deliberately bypass LlavaNextForConditionalGeneration.__init__ so that the
        # Granite-specific base model is built instead of the stock LLaVA-NeXT one.
        LlavaNextPreTrainedModel.__init__(self, config)

        self.model = Granite4VisionModel(config)

        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

        if config.pretrained_vision_tower:
            self._load_pretrained_vision_tower(config)
            # Cleared so the loading step is not repeated for a config that went through it once.
            config.pretrained_vision_tower = ""

        if config.pretrained_language_model:
            self._load_pretrained_language_model(config)
            config.pretrained_language_model = ""

        self.post_init()

    def _load_pretrained_vision_tower(self, config):
        """Load pretrained vision tower weights into `self.model.vision_tower`.

        The checkpoint is loaded on CPU in bfloat16 and copied in with `strict=False`; keys of the local
        vision tower that the checkpoint does not provide are reported through the logger.
        """
        logger.info(f"Loading vision tower from: {config.pretrained_vision_tower}")
        vision_tower = AutoModel.from_pretrained(
            config.pretrained_vision_tower,
            attn_implementation="flash_attention_2",
            device_map="cpu",
            dtype=torch.bfloat16,
        )
        self.model.vision_tower = self.model.vision_tower.to(torch.bfloat16)
        load_result = self.model.vision_tower.load_state_dict(vision_tower.state_dict(), strict=False)
        if load_result.missing_keys:
            logger.warning(f"Vision tower keys missing from the pretrained checkpoint: {load_result.missing_keys}")
        self.model.vision_tower.config._attn_implementation = "flash_attention_2"

        self.config.vision_config = (
            self.model.vision_tower.config.vision_config
            if hasattr(self.model.vision_tower.config, "vision_config")
            else self.model.vision_tower.config
        )

    def _load_pretrained_language_model(self, config):
        """Load a pretrained causal LM and adopt its decoder and `lm_head`.

        The embedding matrix is enlarged when the image token index does not fit the checkpoint's vocabulary.
        """
        logger.info(f"Loading language model from: {config.pretrained_language_model}")
        language_model = AutoModelForCausalLM.from_pretrained(
            config.pretrained_language_model,
            device_map="cpu",
            attn_implementation="flash_attention_2",
            dtype=torch.bfloat16,
        )
        if self.config.image_token_index >= language_model.config.vocab_size:
            language_model.resize_token_embeddings(self.config.image_token_index + 1)

        self.model.language_model = language_model.model
        self.lm_head = language_model.lm_head

        self.config.text_config = self.model.language_model.config

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        spatial_shapes: Optional[torch.LongTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Granite4VisionCausalLMOutputWithPast]:
        r"""
        Run the base model, project the hidden states to vocabulary logits and optionally compute the loss.

        Options left as `None` (`output_attentions`, `output_hidden_states`, `vision_feature_layer`,
        `vision_feature_select_strategy`) fall back to the values stored in `self.config`.

        logits_to_keep (`int` or `torch.Tensor`, defaults to `0`):
            If an `int`, logits are computed for the last `logits_to_keep` positions (`0` keeps all positions);
            otherwise the value is used directly as an index along the sequence dimension.
        spatial_shapes, pixel_attention_mask (*optional*):
            Forwarded unchanged to the base model, which uses them to select its naflex image path.
        labels (`torch.LongTensor`, *optional*):
            When given, `self.loss_function` is evaluated on the (scaled) logits.

        Returns:
            A `Granite4VisionCausalLMOutputWithPast` holding loss, logits, cache, hidden states, attentions,
            image hidden states and the base model's `balancing_loss`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer
        if vision_feature_select_strategy is None:
            vision_feature_select_strategy = self.config.vision_feature_select_strategy

        outputs = self.model(
            input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            spatial_shapes=spatial_shapes,
            pixel_attention_mask=pixel_attention_mask,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int selects the trailing positions; slice(-0, None) == slice(0, None) keeps everything.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits / self.config.text_config.logits_scaling

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )

        return Granite4VisionCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            balancing_loss=outputs.balancing_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Assemble the forward kwargs for one generation step.

        Delegates to the parent implementation, post-processes the result with
        `prepare_inputs_for_generation_granite_moe` when this class or the language model has "moe" in its
        class name, and attaches the image inputs only on the first step (`cache_position` starting at 0,
        or no `cache_position` at all).
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        uses_moe = (
            "moe" in self.__class__.__name__.lower()
            or "moe" in self.language_model.__class__.__name__.lower()
        )
        if uses_moe:
            model_inputs = self.prepare_inputs_for_generation_granite_moe(**model_inputs)

        # Image inputs are only needed while the prompt itself is processed. `cache_position` defaults
        # to None, so guard the indexing instead of failing with a TypeError.
        if cache_position is None or cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes

        return model_inputs

    def prepare_inputs_for_generation_granite_moe(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Adapt generation inputs for a language model that uses `HybridMambaAttentionDynamicCache`.

        Trims `input_ids` to the positions not yet covered by the cache, creates the hybrid cache on the
        first step when caching is requested, and derives `position_ids` from `attention_mask` when they are
        not supplied. Extra `kwargs` are passed through without overriding the computed entries.

        Returns:
            dict: Keyword arguments for the model's `forward`.
        """
        empty_past_kv = past_key_values is None or (
            isinstance(past_key_values, DynamicCache) and past_key_values[0][0] is None
        )

        if not empty_past_kv:
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]
        elif use_cache:
            # `input_ids` may be None when generation starts from `inputs_embeds`; take the batch size
            # from whichever input is available.
            batch_source = input_ids if input_ids is not None else inputs_embeds
            past_key_values = HybridMambaAttentionDynamicCache(
                self.model.language_model.config, batch_source.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # Positions count the attended tokens; masked slots receive the filler value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # Embeddings are only usable on the first step; afterwards token ids drive the model.
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
            }
        )

        # Forward any remaining kwargs, never replacing what was computed above.
        for key, value in kwargs.items():
            model_inputs.setdefault(key, value)

        return model_inputs
|
|
|
|
|
|
|
|
class Granite4VisionModel(LlavaNextPreTrainedModel):
    """Granite4 Vision base model: vision tower, multimodal projector, optional downsampler and language model.

    Image features are computed from `pixel_values`, optionally downsampled, packed into a flat token
    sequence and scattered into the text embeddings at the image-placeholder positions before the
    language model runs.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        """Instantiate all sub-modules from `config` (no pretrained weights are loaded here)."""
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(config)

        self.downsampler = None
        self.downsample_rate = config.downsample_rate
        if config.downsample_rate is not None:
            if config.downsample_method in ["interpolate", "bilinear"]:
                self.downsampler = BilinearDownsampler(config)
            elif config.downsample_method == "qformer":
                self.downsampler = QFormerDownsampler(config)
            elif config.downsample_method == "window_qformer":
                self.downsampler = WindowQFormerDownsampler(config)
            else:
                logger.warning(
                    f"Unknown downsample_method {config.downsample_method!r}; no downsampler will be used."
                )

        # Always recorded so later code never depends on the newline option being enabled.
        self.model_type = config.model_type

        self.image_newline = None
        if config.use_image_newline_parameter:
            embed_std = 1 / math.sqrt(config.text_config.hidden_size)
            image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
            if self.model_type in ["gpt_vision", "granite4_vision"]:
                # For these model types the parameter lives under `image_newline.param`.
                self.image_newline = ParamWrapper(image_newline)
            else:
                self.image_newline = image_newline

        self.vocab_size = config.text_config.vocab_size

        self.language_model = AutoModel.from_config(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Replace the language model."""
        self.language_model = decoder

    def get_decoder(self):
        """Return the language model."""
        return self.language_model

    def _get_image_newline(self):
        """Return the newline embedding tensor, unwrapping `ParamWrapper`; `None` when the option is disabled."""
        newline = self.image_newline
        if isinstance(newline, ParamWrapper):
            return newline.param
        return newline

    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a flat sequence of visual vectors.

        Args:
            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, each containing the visual features of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`list[torch.Tensor]`, one entry per image, each of shape `(feat_len, embed_dim)`)
            feature_lens (`torch.LongTensor` of shape `(num_images,)`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # First patch is the base (whole-image) view; the rest are the high-resolution tiles.
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if self.downsampler is not None:
                    # The downsampler changes the per-tile token grid by `downsample_rate`.
                    ds_rate = Fraction(self.downsample_rate)
                    height = int(height * ds_rate)
                    width = int(width * ds_rate)

                if (
                    image_feature.numel() % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # Stitch the tiles into one (embed_dim, H_total, W_total) map, then drop the padding.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline vector at the end of every feature row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
        return new_image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        """
        Obtains image hidden states from the vision tower, applies the multimodal projection (and the
        optional downsampler) and packs the result per image.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`
                or `(total_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (list[`torch.Tensor`]): One packed feature tensor per image, as produced by
            `pack_image_features`.
        Raises:
            ValueError: If `pixel_values` is neither 4- nor 5-dimensional.
        """
        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer
        if vision_feature_select_strategy is None:
            vision_feature_select_strategy = self.config.vision_feature_select_strategy

        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # Drop the padding patches of each image and stack the real ones along the batch axis.
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)

        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        if vision_feature_select_strategy == "default":
            # "default" drops the first token of every patch sequence.
            selected_image_feature = selected_image_feature[:, 1:]

        image_features = self.multi_modal_projector(selected_image_feature)
        if self.downsampler is not None:
            image_features = self.downsampler(image_features)
        if image_features.shape[0] != sum(image_num_patches):
            # torch.split below will raise; record the inputs that led to the mismatch first.
            logger.error(
                "Number of patch features does not match the expected patch count: "
                f"pixel_values shape={tuple(pixel_values.shape)}, image_sizes={image_sizes}, "
                f"image_num_patches={image_num_patches}"
            )
        image_features = torch.split(image_features, image_num_patches, dim=0)

        image_features, feature_lens = self.pack_image_features(
            image_features,
            image_sizes,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_newline=self._get_image_newline(),
        )
        return image_features

    def get_image_features_naflex(
        self,
        pixel_values: torch.FloatTensor,
        spatial_shapes,
        pixel_attention_mask,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
    ):
        """
        Obtains image hidden states from a vision tower that takes `spatial_shapes` and
        `pixel_attention_mask`, and applies the multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor`): Image input forwarded to the vision tower.
            spatial_shapes: Forwarded to the vision tower as `spatial_shapes`.
            pixel_attention_mask: Forwarded to the vision tower as `pixel_attention_mask`.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                Hidden-state index (or indices, concatenated on the last dim) to use; defaults to the config value.
        Returns:
            `torch.Tensor`: The projected features (not packed, no newline, no downsampling).
        Raises:
            NotImplementedError: If a downsampler or an image-newline parameter is configured.
        """
        # Validate before running the (expensive) vision tower; explicit raises survive `python -O`.
        if self.downsampler is not None:
            raise NotImplementedError("downsampler not supported for naflex yet")
        if self.image_newline is not None:
            raise NotImplementedError("newline not supported for naflex yet")

        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer

        image_features = self.vision_tower(
            pixel_values,
            spatial_shapes=spatial_shapes,
            pixel_attention_mask=pixel_attention_mask,
            output_hidden_states=True,
        )

        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        return self.multi_modal_projector(selected_image_feature)

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without token ids, locate placeholders by comparing against the image token's embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )
        return special_image_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, Granite4VisionModelOutputWithPast]:
        r"""
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
            If `"full"`, the full vision features are used.
        spatial_shapes, pixel_attention_mask (*optional*):
            When both are given, image features are computed with `get_image_features_naflex`;
            otherwise `get_image_features` is used.

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is provided.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer
        if vision_feature_select_strategy is None:
            vision_feature_select_strategy = self.config.vision_feature_select_strategy

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Bound up front so the return statement is safe when `pixel_values` is an empty batch.
        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            if spatial_shapes is not None and pixel_attention_mask is not None:
                image_features = self.get_image_features_naflex(
                    pixel_values,
                    spatial_shapes,
                    pixel_attention_mask,
                    vision_feature_layer=vision_feature_layer,
                )
            else:
                image_features = self.get_image_features(
                    pixel_values,
                    image_sizes,
                    vision_feature_layer=vision_feature_layer,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                )
            # The packed path yields one tensor per image; the naflex path already yields a tensor.
            if isinstance(image_features, (list, tuple)):
                image_features = torch.cat(image_features, dim=0)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        elif torch.is_grad_enabled():
            # Text-only step with gradients enabled: run the vision path on dummy data, contributing zero.
            self.run_dummy_encoder_forward(inputs_embeds, vision_feature_layer, vision_feature_select_strategy)

        try:
            outputs = self.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                cache_position=cache_position,
                **kwargs,
            )
        except Exception:
            # Log compact diagnostics (shapes, not full tensors) and re-raise with the original traceback.
            logger.error(
                "Language model forward failed: "
                f"inputs_embeds shape={tuple(inputs_embeds.shape)}, "
                f"input_ids shape={None if input_ids is None else tuple(input_ids.shape)}, "
                f"attention_mask shape={None if attention_mask is None else tuple(attention_mask.shape)}, "
                f"position_ids shape={None if position_ids is None else tuple(position_ids.shape)}, "
                f"extra kwargs={sorted(kwargs)}"
            )
            raise

        return Granite4VisionModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )

    def run_dummy_encoder_forward(self, inputs_embeds, vision_feature_layer, vision_feature_select_strategy):
        """Run the vision path on all-zero dummy input and add its output, multiplied by 0, to `inputs_embeds`.

        `inputs_embeds` is modified in place (first position of the first sequence); numerically nothing
        changes, but the vision modules become part of the autograd graph.
        NOTE(review): the dummy shapes below are hard-coded; presumably they match the deployed vision
        configs (384px tiles / 16x16 naflex patches) -- confirm before using other configurations.
        """
        if isinstance(self.config.vision_config, Siglip2VisionConfig):
            logger.warning_once("no pixel values, using dummy data to get grads - naflex mode")
            dummy_pixel_values = torch.zeros((1, 256, 768), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            dummy_spatial_shapes = torch.tensor([[16, 16]], device=inputs_embeds.device)
            dummy_pixel_attention_mask = torch.ones((1, 256), device=inputs_embeds.device)
            other_embeds = self.get_image_features_naflex(
                dummy_pixel_values,
                dummy_spatial_shapes,
                dummy_pixel_attention_mask,
                vision_feature_layer=vision_feature_layer,
            )
        else:
            logger.warning_once("no pixel values, using dummy data to get grads")
            dummy_data = torch.zeros(
                (3, 3, 384, 384), dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )
            dummy_sizes = torch.tensor([[768, 384]], device=inputs_embeds.device)
            other_embeds = self.get_image_features(
                dummy_data,
                dummy_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

        # Zero contribution: keeps the values of `inputs_embeds` unchanged.
        other_embeds = other_embeds[0][:1] * 0
        inputs_embeds[0, :1] = inputs_embeds[0, :1] + other_embeds
|
|
|
|
|
class Granite4VisionForConditionalGenerationNaflex(Granite4VisionForConditionalGeneration):
    """Variant of `Granite4VisionForConditionalGeneration` bound to `Granite4VisionConfigNaflex`.

    Only the configuration class differs; all behaviour is inherited from the parent class.
    """

    config_class = Granite4VisionConfigNaflex