File size: 35,271 Bytes
e9f0a60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 |
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
LlavaNextForConditionalGeneration,
LlavaNextModel,
)
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextPreTrainedModel,
LlavaNextMultiModalProjector,
get_anyres_image_grid_shape,
image_size_to_num_patches,
unpad_image,
LlavaNextModelOutputWithPast
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging
from accelerate import init_empty_weights
from transformers import Blip2QFormerConfig, Blip2QFormerModel
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
from .configuration import Granite4VisionConfig, Granite4VisionConfigNaflex
from .downsampling import BilinearDownsampler, QFormerDownsampler, WindowQFormerDownsampler
import math
import numpy as np
from fractions import Fraction
from transformers.modeling_utils import flash_attention_forward
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import HybridMambaAttentionDynamicCache
# Label value for positions that should not contribute to the loss; -100 is also the default
# `ignore_index` of `torch.nn.CrossEntropyLoss`.
# NOTE(review): not referenced by the code in this chunk — confirm it is used elsewhere.
IGNORE_INDEX = -100

# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
@dataclass
class Granite4VisionModelOutputWithPast(LlavaNextModelOutputWithPast):
    r"""
    Base-model output of Granite4Vision: the LLaVA-NeXT model output plus an optional `balancing_loss`.

    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features produced by the vision encoder after projecting the last hidden state.
        NOTE(review): the inherited LLaVA-NeXT description gives the size as
        `(batch_size, num_images, sequence_length, hidden_size)`, but the model in this file stores the
        concatenated projected features here, so the actual shape may differ — confirm.
    balancing_loss (`torch.FloatTensor`, *optional*):
        Optional auxiliary balancing loss; defaults to `None`.
        NOTE(review): not populated by the base-model forward in this file — confirm where it is set.
    """

    balancing_loss: Optional[torch.FloatTensor] = None
@dataclass
class Granite4VisionCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
    r"""
    Causal-LM output of Granite4Vision: the LLaVA-NeXT causal-LM output plus an optional `balancing_loss`.

    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Image features produced by the vision encoder after projecting the last hidden state.
        NOTE(review): the inherited LLaVA-NeXT description gives the size as
        `(batch_size, num_images, sequence_length, hidden_size)`; the features forwarded by this file's
        model are concatenated, so the actual shape may differ — confirm.
    balancing_loss (`torch.FloatTensor`, *optional*):
        Optional auxiliary balancing loss copied from the base-model output; defaults to `None`.
    """

    balancing_loss: Optional[torch.FloatTensor] = None
class ParamWrapper(nn.Module):
    """Tiny module whose only content is the object passed in, stored as ``self.param``.

    Wrapping a bare parameter in a module makes it addressable as a sub-module, which the
    image-newline code in this file relies on for LoRA training from scratch.
    """

    def __init__(self, param):
        nn.Module.__init__(self)
        # Assigning an nn.Parameter here registers it under the name "param".
        self.param = param
class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
    """Causal-LM wrapper around :class:`Granite4VisionModel` with a separate ``lm_head``.

    Follows the LLaVA-NeXT generation API. When ``config.pretrained_vision_tower`` or
    ``config.pretrained_language_model`` is set, the corresponding sub-configs are resolved from
    those checkpoints and their weights are loaded during construction.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        # Resolve sub-configs from the pretrained checkpoints (if any) before the modules are built,
        # so the freshly created sub-modules match the checkpoint architectures.
        if config.pretrained_vision_tower:
            config.vision_config = AutoConfig.from_pretrained(
                config.pretrained_vision_tower, **config.vision_config.to_dict()
            )
            # Some checkpoint configs nest the vision part under `vision_config`; unwrap it.
            config.vision_config = (
                config.vision_config.vision_config
                if hasattr(config.vision_config, "vision_config")
                else config.vision_config
            )
        if config.pretrained_language_model:
            config.text_config = AutoConfig.from_pretrained(
                config.pretrained_language_model, **config.text_config.to_dict()
            )

        # Initialise the pretrained-model machinery directly (instead of running
        # LlavaNextForConditionalGeneration.__init__) and build our own sub-modules.
        LlavaNextPreTrainedModel.__init__(self, config)
        self.model = Granite4VisionModel(config)
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

        # Load pretrained weights, then clear the checkpoint paths in the config
        # (presumably so a saved config does not trigger another load when re-instantiated).
        if config.pretrained_vision_tower:
            self._load_pretrained_vision_tower(config)
            config.pretrained_vision_tower = ""
        if config.pretrained_language_model:
            self._load_pretrained_language_model(config)
            config.pretrained_language_model = ""
        self.post_init()

    def _load_pretrained_vision_tower(self, config):
        """Copy the weights of ``config.pretrained_vision_tower`` into ``self.model.vision_tower``.

        The checkpoint is loaded on CPU in bfloat16 with flash-attention-2. The existing tower is cast to
        bfloat16, and the weights are copied with ``strict=False``; any missing keys are logged.
        ``self.config.vision_config`` is then refreshed from the tower's config.
        """
        logger.info("Loading vision tower from: %s", config.pretrained_vision_tower)
        pretrained_tower = AutoModel.from_pretrained(
            config.pretrained_vision_tower,
            attn_implementation="flash_attention_2",
            device_map="cpu",
            dtype=torch.bfloat16,
        )
        self.model.vision_tower = self.model.vision_tower.to(torch.bfloat16)
        load_result = self.model.vision_tower.load_state_dict(pretrained_tower.state_dict(), strict=False)
        if load_result.missing_keys:
            logger.warning("Vision tower weights missing from checkpoint: %s", load_result.missing_keys)
        # Release the temporary copy as soon as its weights have been copied.
        del pretrained_tower
        self.model.vision_tower.config._attn_implementation = "flash_attention_2"
        # TODO(Avihu): consider assigning the loaded tower directly instead of copying its state dict;
        # kept as a copy until it is confirmed nothing depends on the original module.
        tower_config = self.model.vision_tower.config
        self.config.vision_config = (
            tower_config.vision_config if hasattr(tower_config, "vision_config") else tower_config
        )

    def _load_pretrained_language_model(self, config):
        """Replace the language model and LM head with those of ``config.pretrained_language_model``.

        The embeddings are resized if the image placeholder token id does not fit in the checkpoint's
        vocabulary, and ``self.config.text_config`` is refreshed from the loaded model.
        """
        logger.info("Loading language model from: %s", config.pretrained_language_model)
        language_model = AutoModelForCausalLM.from_pretrained(
            config.pretrained_language_model,
            device_map="cpu",
            attn_implementation="flash_attention_2",
            dtype=torch.bfloat16,
        )
        # Make sure the image placeholder token id has a row in the embedding matrix.
        if self.config.image_token_index >= language_model.config.vocab_size:
            language_model.resize_token_embeddings(self.config.image_token_index + 1)
        # Keep the decoder body and the LM head as separate attributes of this model.
        self.model.language_model = language_model.model
        self.lm_head = language_model.lm_head
        self.config.text_config = self.model.language_model.config

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        spatial_shapes: Optional[torch.LongTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Granite4VisionCausalLMOutputWithPast]:
        r"""
        Run the multimodal base model and project its last hidden state to vocabulary logits.

        Besides the LLaVA-NeXT arguments, ``spatial_shapes`` and ``pixel_attention_mask`` are forwarded to
        the base model, which uses the naflex image path when both are given.

        Returns:
            `Granite4VisionCausalLMOutputWithPast`: logits for the positions selected by ``logits_to_keep``.
            They are divided by ``text_config.logits_scaling`` when the language config defines it.
            ``loss`` is computed only when ``labels`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        outputs = self.model(
            input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            spatial_shapes=spatial_shapes,
            pixel_attention_mask=pixel_attention_mask,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all of them); a tensor is used as indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Granite language models scale their logits down. Language models without the attribute
        # previously crashed here; now they get unscaled logits.
        logits_scaling = getattr(self.config.text_config, "logits_scaling", None)
        if logits_scaling is not None:
            logits = logits / logits_scaling

        loss = None
        if labels is not None:
            # NOTE(Avihu): logits are intentionally NOT upcast to float before the loss; upcasting made
            # little difference and requires more memory.
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )

        return Granite4VisionCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            balancing_loss=outputs.balancing_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build per-step generation inputs.

        Image inputs are only forwarded on the first (prefill) step. If this class or its language
        model has "moe" in its class name, the hybrid-cache preparation below is applied on top.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # MoE language models (detected by class name) require special cache handling.
        is_moe = (
            "moe" in self.__class__.__name__.lower()
            or "moe" in self.language_model.__class__.__name__.lower()
        )
        if is_moe:
            model_inputs = self.prepare_inputs_for_generation_granite_moe(**model_inputs)

        # In the cached decoding stage the input ids no longer contain the image token, so pixel
        # values must only be passed on the first step.
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes

        return model_inputs

    # NOTE(Avihu): would have used the GraniteMoeSharedForCausalLM method, but that object is not kept
    # anymore (the language model and the LM head are split).
    def prepare_inputs_for_generation_granite_moe(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Prepare generation inputs for language models that use `HybridMambaAttentionDynamicCache`.

        If there is no usable cache and ``use_cache`` is set, a fresh hybrid cache is created. Otherwise
        ``input_ids`` (and the derived ``position_ids``) are sliced to the not-yet-processed tokens.
        Any extra ``kwargs`` not already set are passed through unchanged.
        """
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`.
        # NOTE(Avihu): in transformers v4 `past_key_values` may already be an empty DynamicCache; treat it
        # like "no cache". Depending on the transformers version, indexing an empty DynamicCache may
        # return `None` keys or raise, so both are handled.
        empty_past_kv = past_key_values is None
        if not empty_past_kv and isinstance(past_key_values, DynamicCache):
            try:
                empty_past_kv = past_key_values[0][0] is None
            except (IndexError, KeyError):
                empty_past_kv = True

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        # (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        elif use_cache:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.model.language_model.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation (padding positions get a dummy value of 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step.
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
            }
        )

        # Forward any remaining kwargs that were not set above.
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        return model_inputs
class Granite4VisionModel(LlavaNextPreTrainedModel):
    """Vision tower + multimodal projector (+ optional downsampler) feeding a language model.

    Projected image features are scattered into the embedding positions of the image placeholder
    tokens before the language model is run.
    """

    config_class = Granite4VisionConfig

    def __init__(self, config: Granite4VisionConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(config)

        # Optional token downsampler, applied after the projector in `get_image_features`.
        self.downsampler = None
        self.downsample_rate = config.downsample_rate
        if config.downsample_rate is not None:
            if config.downsample_method in ["interpolate", "bilinear"]:
                self.downsampler = BilinearDownsampler(config)
            elif config.downsample_method == "qformer":
                self.downsampler = QFormerDownsampler(config)
            elif config.downsample_method == "window_qformer":
                self.downsampler = WindowQFormerDownsampler(config)

        # Always record the model type. It was previously only set when an image-newline parameter
        # existed, which broke any later reader of `self.model_type`.
        self.model_type = config.model_type
        self.image_newline = None
        if config.use_image_newline_parameter:
            embed_std = 1 / math.sqrt(config.text_config.hidden_size)
            image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
            if self.model_type in ["gpt_vision", "granite4_vision"]:
                # This hack allows LoRA training from scratch, so image_newline can be listed among the
                # modules to keep.
                self.image_newline = ParamWrapper(image_newline)
            else:
                self.image_newline = image_newline

        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModel.from_config(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def _get_image_newline_tensor(self):
        """Return the image-newline tensor, unwrapping a `ParamWrapper` if used (or `None` if disabled)."""
        if isinstance(self.image_newline, ParamWrapper):
            return self.image_newline.param
        return self.image_newline

    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.
        Returns:
            image_features (`list[torch.Tensor]`, one `(feat_len, embed_dim)` tensor per image)
            feature_lens (`torch.LongTensor`)
                token length of each image in image_features
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # Index 0 holds the base-image features; the remaining entries are the grid tiles.
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if self.downsampler is not None:
                    # The downsampler shrinks each tile's token grid by `downsample_rate` in both dimensions.
                    ds_rate = Fraction(self.downsample_rate)
                    height = int(height * ds_rate)
                    width = int(width * ds_rate)

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # Stitch the tiles back into one spatial grid, then remove the padding added by anyres resizing.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline embedding at the end of every row of the grid.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
        return new_image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
               The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (list[`torch.Tensor`]): one packed feature tensor per image (see `pack_image_features`).
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        if vision_feature_select_strategy == "default":
            # Drop the first token (CLS) of every patch.
            selected_image_feature = selected_image_feature[:, 1:]

        image_features = self.multi_modal_projector(selected_image_feature)
        if self.downsampler is not None:
            # training this with peft+deepspeed had this issue+fix:
            # https://github.com/deepspeedai/DeepSpeed/issues/7203#issuecomment-3007490737
            image_features = self.downsampler(image_features)

        expected_patches = sum(image_num_patches)
        if image_features.shape[0] != expected_patches:
            # torch.split below will fail; log the context that explains why.
            logger.error(
                "Image feature patches (%d) do not match the expected total (%d): "
                "pixel_values shape=%s, image_sizes=%s, patches per image=%s",
                image_features.shape[0],
                expected_patches,
                tuple(pixel_values.shape),
                image_sizes,
                image_num_patches,
            )
        image_features = torch.split(image_features, image_num_patches, dim=0)

        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
        image_features, _ = self.pack_image_features(
            image_features,
            image_sizes,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_newline=self._get_image_newline_tensor(),
        )
        return image_features

    def get_image_features_naflex(
        self,
        pixel_values: torch.FloatTensor,
        spatial_shapes,
        pixel_attention_mask,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
    ):
        """Run the vision tower on the Siglip2 naflex inputs and project the selected hidden states.

        Unlike `get_image_features`, this returns the projector output as a single tensor (not a
        per-image list), without downsampling or newline embeddings.

        Raises:
            NotImplementedError: if a downsampler or an image-newline parameter is configured.
        """
        # NOTE(Avihu): downsampling and newline embeddings are more complex for naflex at the moment.
        # Checked up front so we fail before running the vision tower.
        if self.downsampler is not None:
            raise NotImplementedError("downsampler not supported for naflex yet")
        if self.image_newline is not None:
            raise NotImplementedError("newline not supported for naflex yet")

        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        # TODO(Avihu): Hack! This relies on a manually edited siglip2 naflex modeling file that supports
        # pad-free inputs. Until there is a better solution, consult the author before running this path.
        # Note: siglip gets a stacked tensor.
        image_features = self.vision_tower(
            pixel_values,
            spatial_shapes=spatial_shapes,
            pixel_attention_mask=pixel_attention_mask,
            output_hidden_states=True,
        )
        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)
        return self.multi_modal_projector(selected_image_feature)

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without ids, find the placeholder positions by comparing against the image token's embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )
        return special_image_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, Granite4VisionModelOutputWithPast]:
        r"""
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
            If `"full"`, the full vision features are used.
        spatial_shapes / pixel_attention_mask (*optional*):
            When both are given, images go through the naflex path (`get_image_features_naflex`).

        `return_dict` is accepted for API compatibility; this method always builds the dataclass output
        (tuple conversion is presumably handled by `@can_return_tuple`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Stays None when no images are processed (e.g. `pixel_values` given but empty).
        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            if spatial_shapes is not None and pixel_attention_mask is not None:
                # naflex setup
                image_features = self.get_image_features_naflex(
                    pixel_values,
                    spatial_shapes,
                    pixel_attention_mask,
                    vision_feature_layer=vision_feature_layer,
                )
            else:
                image_features = self.get_image_features(
                    pixel_values,
                    image_sizes,
                    vision_feature_layer=vision_feature_layer,
                    vision_feature_select_strategy=vision_feature_select_strategy,
                )
            if isinstance(image_features, torch.Tensor):
                # The naflex path returns one tensor (presumably (batch, seq, hidden)); torch.cat cannot take a
                # bare tensor, so flatten all leading dims into the token dim instead.
                image_features = image_features.reshape(-1, image_features.shape[-1])
            else:
                image_features = torch.cat(image_features, dim=0)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        elif torch.is_grad_enabled():
            # Text-only batch during training: touch the vision modules so they still get (zero) gradients.
            self.run_dummy_encoder_forward(inputs_embeds, vision_feature_layer, vision_feature_select_strategy)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return Granite4VisionModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )

    def run_dummy_encoder_forward(self, inputs_embeds, vision_feature_layer, vision_feature_select_strategy):
        """Run the image encoder on all-zero dummy inputs and add a zeroed result to ``inputs_embeds``.

        Adding ``features * 0`` to the first position of the first sequence leaves the values unchanged
        but connects the vision modules to the graph, so they receive gradients on text-only batches.
        ``inputs_embeds`` is modified in place.
        """
        if isinstance(self.config.vision_config, Siglip2VisionConfig):
            logger.debug("no pixel values, using dummy data to get grads - naflex mode")
            # NOTE(review): assumes 256 flattened patches of 768 values on a 16x16 grid — confirm against
            # the Siglip2 patch size in use.
            dummy_pixel_values = torch.zeros((1, 256, 768), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            dummy_spatial_shapes = torch.tensor([[16, 16]], device=inputs_embeds.device)
            dummy_pixel_attention_mask = torch.ones((1, 256), device=inputs_embeds.device)
            other_embeds = self.get_image_features_naflex(
                dummy_pixel_values,
                dummy_spatial_shapes,
                dummy_pixel_attention_mask,
                vision_feature_layer=vision_feature_layer,
            )
        else:
            logger.debug("no pixel values, using dummy data to get grads")
            # NOTE(review): assumes a 384px vision tower and grid pinpoints that map a 768x384 image to
            # 3 patches (base + 2 tiles) — confirm against the config.
            dummy_data = torch.zeros((3, 3, 384, 384), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            dummy_sizes = torch.tensor([[768, 384]], device=inputs_embeds.device)
            other_embeds = self.get_image_features(
                dummy_data,
                dummy_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
        other_embeds = other_embeds[0][:1] * 0  # adding zeros tensor
        inputs_embeds[0, :1] = inputs_embeds[0, :1] + other_embeds
class Granite4VisionForConditionalGenerationNaflex(Granite4VisionForConditionalGeneration):
    """Same model as `Granite4VisionForConditionalGeneration`, bound to `Granite4VisionConfigNaflex`."""

    config_class = Granite4VisionConfigNaflex