# llava-7b / modeling_vlm.py
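"""LLaVA-style vision-language model wrapper.

Combines a pretrained vision encoder (loaded via AutoModel), a connector that
projects visual features into the language-model embedding space, and a Llama
backbone, so image features can be spliced into the token-embedding sequence
before the language model runs.
"""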
from typing import Any

try:
    from typing import override  # Python 3.12+
except ImportError:  # older interpreters: fall back to typing_extensions
    from typing_extensions import override
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor, Tensor
from transformers import AutoModel, LlamaForCausalLM, LlamaModel, PreTrainedModel
from .configuration_vlm import VLMConfig
from .connectors import Connector, connector_map
class VLM(LlamaModel):
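    """Llama decoder extended with a vision encoder and a connector.

    Both extra modules are built from ``config.vision_config`` and
    ``config.connector_config``; the text stack itself is inherited from
    ``LlamaModel``.
    """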
config_class = VLMConfig
@override
def __init__(self, config):
super().__init__(config)
self.vision_model = self._build_vision_model(config)
self.connector = self._build_connector(config)
def _build_vision_model(self, config: Any) -> PreTrainedModel:
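        """Load the vision encoder named by ``config.vision_config.hf_name``.

        If the loaded model wraps its encoder in a ``vision_model`` attribute
        (as CLIP-style models do), that submodule is returned instead.
        """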
vision_config = config.vision_config
visual_encoder: PreTrainedModel = AutoModel.from_pretrained(
vision_config.hf_name,
trust_remote_code=True,
)
if getattr(visual_encoder, "vision_model", None):
visual_encoder = visual_encoder.vision_model # pyright: ignore
return visual_encoder
def _build_connector(self, config: Any) -> Connector:
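        """Build the connector that maps the vision hidden size to the LM hidden size.

        The concrete class is looked up in ``connector_map`` by
        ``connector_config.type``; unknown types raise a ``ValueError``.
        """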
connector_config = config.connector_config
connector_class = connector_map.get(connector_config.type)
if not connector_class:
raise ValueError(f"Unsupported connector type: {connector_config.type}")
return connector_class(
connector_config,
self.config.vision_config.hidden_size,
config.hidden_size,
)
class VLMForCausalLM(LlamaForCausalLM):
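    """Causal-LM head on top of :class:`VLM` with multimodal input preparation.

    ``forward`` and ``generate`` accept pixel ``images`` and splice the encoded
    image features into the token embeddings in place of image placeholder
    tokens.
    """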
config_class = VLMConfig
@override
def __init__(self, config):
super().__init__(config)
self.model = VLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
@override
def forward(
self: Any,
input_ids: Tensor | None = None,
inputs_embeds: Tensor | None = None,
attention_mask: Tensor | None = None,
position_ids: LongTensor | None = None,
past_key_values: list[FloatTensor] | None = None,
labels: LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
images: FloatTensor | None = None,
image_sizes: list[list[int]] | None = None,
return_dict: bool | None = None,
    ) -> Any:  # a CausalLMOutputWithPast (or tuple) from LlamaForCausalLM.forward
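        """Merge image features into the embeddings, then run the language model.

        When ``inputs_embeds`` is not supplied, image features are encoded and
        merged with the text embeddings via
        ``prepare_inputs_labels_for_multimodal`` before delegating to
        ``LlamaForCausalLM.forward``. ``image_sizes`` is accepted for interface
        compatibility but is not used here.
        """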
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
)
)
return super().forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@override
def generate(
self: Any,
inputs: Tensor | None = None,
images: FloatTensor | None = None,
image_sizes: list[list[int]] | None = None,
**kwargs: Any,
) -> Any:
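        """Generate from text and, optionally, image inputs.

        If ``images`` are given, token ids and image features are merged into
        ``inputs_embeds`` first; otherwise the text embeddings are used
        directly. Passing ``inputs_embeds`` explicitly is not supported.
        """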
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(_, position_ids, attention_mask, _, inputs_embeds, _) = (
self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
)
)
else:
inputs_embeds = self.get_input_embeddings()(inputs)
return super().generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs,
)
@override
def prepare_inputs_for_generation(
self: Any,
input_ids: Tensor,
past_key_values: list[FloatTensor] | None = None,
inputs_embeds: Tensor | None = None,
**kwargs: Any,
):
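        """Build the per-step generation inputs, re-attaching ``images`` and
        ``image_sizes`` (when given) so they reach ``forward`` on every
        decoding step."""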
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
**kwargs,
)
inputs.pop("cache_position")
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
def encode_images(self: Any, images: list[Tensor] | Tensor) -> list[Tensor] | Tensor:
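        """Encode images with the vision tower and project them with the connector.

        Hidden states are taken from the layer selected by
        ``config.vision_config.output_layer``; the leading CLS token is dropped
        unless ``use_cls_token`` is set. Accepts either a batched tensor or a
        list of per-image tensors.
        """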
        if isinstance(images, list):
image_features: list[Tensor] | Tensor = []
for image in images:
outputs = self.model.vision_model(
image.unsqueeze(0),
output_hidden_states=True,
)
                hidden_states: Tensor = outputs.hidden_states[
                    self.config.vision_config.output_layer
                ].to(image.dtype)
if not self.config.vision_config.use_cls_token:
image_features.append(hidden_states[:, 1:])
else:
image_features.append(hidden_states)
else:
outputs = self.model.vision_model(
images,
output_hidden_states=True,
)
hidden_states = outputs.hidden_states[self.config.vision_config.output_layer].to(
images.dtype
)
if not self.config.vision_config.use_cls_token:
image_features = hidden_states[:, 1:]
else:
image_features = hidden_states
image_features = self.model.connector(image_features)
return image_features
def unpad_image(self: Any, tensor: Tensor, original_size: tuple[int, int]) -> Tensor:
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
            original_size (tuple): The original size of the PIL image as (width, height).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
def prepare_inputs_labels_for_multimodal(
self: Any,
input_ids: Tensor | None = None,
position_ids: LongTensor | None = None,
attention_mask: Tensor | None = None,
past_key_values: list[FloatTensor] | None = None,
labels: LongTensor | None = None,
images: FloatTensor | None = None,
) -> tuple[
Tensor | None,
LongTensor | None,
Tensor | None,
list[FloatTensor] | None,
Tensor | None,
LongTensor | None,
]:
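        """Replace image placeholder tokens with encoded image features.

        Each sequence is split at ``config.image_token_index`` positions, the
        text pieces are embedded, and the image features are spliced in
        between. Labels covering image positions are set to
        ``config.ignore_index``. The variable-length results are truncated to
        ``config.max_seq_length``, re-padded according to
        ``config.padding_side``, and returned with rebuilt attention masks and
        position ids; ``input_ids`` comes back as ``None`` because the model is
        fed ``inputs_embeds`` instead.
        """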
vision_model = self.model.vision_model
if vision_model is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if isinstance(images, list) or images.ndim == 5:
if isinstance(images, list):
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat(list(images), dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features: tuple[Tensor, ...] = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(
0,
input_ids.shape[1],
dtype=torch.long,
device=input_ids.device,
)
if labels is None:
labels = torch.full_like(input_ids, self.config.ignore_index)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask, strict=False)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask, strict=False)
]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == self.config.image_token_index).sum()
if num_images == 0:
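                # No image tokens in this sample: append a zero-length slice of the image
                # features so the vision tower still participates in the graph, and keep
                # the text embeddings and labels unchanged.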
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_input_embeddings()(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
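            # Boundaries around each image token: a leading -1 and a trailing sequence
            # length act as sentinels so the text spans between consecutive image tokens
            # can be sliced out uniformly below.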
image_token_indices = (
[-1]
+ torch.where(cur_input_ids == self.config.image_token_index)[0].tolist()
+ [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(
cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
)
cur_labels_noim.append(
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
)
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_input_embeddings()(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
self.config.ignore_index,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = self.config.max_seq_length
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
self.config.ignore_index,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
)
position_ids = torch.zeros(
(batch_size, max_len),
dtype=position_ids.dtype,
device=position_ids.device,
)
for i, (cur_new_embed, cur_new_labels) in enumerate(
zip(new_input_embeds, new_labels, strict=False)
):
cur_len = cur_new_embed.shape[0]
if self.config.padding_side == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0,
cur_len,
dtype=position_ids.dtype,
device=position_ids.device,
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0,
cur_len,
dtype=position_ids.dtype,
device=position_ids.device,
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
AutoModel.register(VLMConfig, VLMForCausalLM)
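# Usage sketch (repo id and tensor names are illustrative; assumes the checkpoint's
# remote code resolves to this module):
#
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("Leonardo6/llava-7b", trust_remote_code=True)
#   output_ids = model.generate(inputs=input_ids, images=pixel_values, max_new_tokens=64)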