from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from .config import LiteVit5Config


class LiteVit5ForConditionalGeneration(PreTrainedModel):
    """
    LiteVit5 model for vision-to-text generation.

    Combines a frozen SigLIP vision encoder with a CodeT5 (T5-style seq2seq)
    decoder for image-to-text tasks.
    """

    config_class = LiteVit5Config
    base_model_prefix = "litevit5"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Frozen SigLIP vision tower (fp16, inference only).
        self.vision_model = SiglipVisionModel.from_pretrained(
            "google/siglip2-base-patch16-512",
            dtype=torch.float16,
        )
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Reuse the CodeT5 decoder stack and LM head; the T5 text encoder is
        # replaced by the vision features above.
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5-base",
            dtype=torch.float16,
        )
        self.seq2seq_decoder = seq2seq_model.decoder
        self.seq2seq_lm_head = seq2seq_model.lm_head
        self._shift_right = seq2seq_model._shift_right

        # Fusion head: a 2x2 strided conv downsamples the stitched quarter
        # views, which are then concatenated with the full-image tokens and
        # projected back to the decoder width (768).
        self.downsampler = nn.Conv2d(768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16)
        self.fuse = nn.Linear(768 * 2, 768).half()
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True)
        self.linear_projection = nn.Linear(768, 768).half()

        self.post_init()

    def get_encoder(self):
        """Return the vision encoder."""
        return self.vision_model

    def get_decoder(self):
        """Return the seq2seq decoder."""
        return self.seq2seq_decoder

    def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode image inputs into vision features.

        Args:
            pixel_values: Input images of shape [B*5, 3, 512, 512]
                (4 quarter crops plus 1 full view per sample)

        Returns:
            Encoded vision features of shape [B, 1024, 768]
        """
        pixel_values = pixel_values.half()

        num_views = 5  # 4 quarter crops + 1 full image
        batch_size = pixel_values.size(0) // num_views
        num_patches = 32  # 512 / 16 patches per side

        # The vision tower is frozen, so skip gradient tracking.
        with torch.no_grad():
            vision_model_outputs = self.vision_model(pixel_values=pixel_values)
            vision_hidden_states = vision_model_outputs.last_hidden_state

        # [B*5, 1024, 768] -> [B, 5, 1024, 768]
        vision_hidden_states = vision_hidden_states.view(batch_size, num_views, *vision_hidden_states.shape[1:])

        # Views 0-3 are the quarter crops; restore their 2D patch grids.
        quarters = vision_hidden_states[:, :4]
        quarters = quarters.view(batch_size, 4, num_patches, num_patches, -1)

        # Stitch the quarters back into a 64x64 grid (top-left, top-right,
        # bottom-left, bottom-right).
        upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)
        lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)
        pooled_image = torch.cat([upper, lower], dim=1)
        pooled_image = pooled_image.permute(0, 3, 1, 2)  # [B, 768, 64, 64]

        # The 2x2 strided conv brings the stitched grid back to 32x32 tokens.
        pooled32 = self.downsampler(pooled_image)
        pooled_tok = pooled32.flatten(2).transpose(1, 2)  # [B, 1024, 768]

        # View 4 is the full image at native resolution.
        full_image = vision_hidden_states[:, 4]

        # Fuse high-res and full-image tokens channel-wise.
        concat = torch.cat([pooled_tok, full_image], dim=-1)  # [B, 1024, 1536]
        fused = self.fuse(concat)

        # Learned positional embedding, then project into the decoder's
        # cross-attention space.
        fused = fused + self.pos_embedding
        vision_hidden_states = self.linear_projection(fused)

        return vision_hidden_states

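    # --- Illustrative helper (not part of the original model) --------------
    # A minimal sketch of how the [B*5, 3, 512, 512] input expected by
    # _encode_vision could be assembled from a single [3, H, W] image: four
    # quarter crops plus the full view, each resized to 512x512. The crop
    # order matches the stitching above; the plain bilinear resize is an
    # assumption and should be matched to the real training preprocessing.
    @staticmethod
    def _make_five_views(image: torch.Tensor, size: int = 512) -> torch.Tensor:
        import torch.nn.functional as F

        _, h, w = image.shape
        half_h, half_w = h // 2, w // 2
        crops = [
            image[:, :half_h, :half_w],  # top-left quarter
            image[:, :half_h, half_w:],  # top-right quarter
            image[:, half_h:, :half_w],  # bottom-left quarter
            image[:, half_h:, half_w:],  # bottom-right quarter
            image,                       # full image
        ]
        views = [
            F.interpolate(c.unsqueeze(0).float(), size=(size, size),
                          mode="bilinear", align_corners=False)
            for c in crops
        ]
        return torch.cat(views, dim=0)  # [5, 3, size, size]
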
    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """
        Forward pass for the model.

        Args:
            pixel_values: Vision input images
            input_ids: Target token IDs; shifted right internally to build
                the decoder inputs
            labels: Target token IDs for the loss (also used to build the
                decoder inputs if input_ids is not given)
            decoder_input_ids: Explicit decoder input IDs (used during generation)
            past_key_values: Cached key/values for incremental decoding
            attention_mask: Attention mask for the decoder inputs

        Returns:
            Seq2SeqLMOutput with loss, logits, and generation-related outputs
        """
        encoder_hidden_states = self._encode_vision(pixel_values)

        if decoder_input_ids is None:
            if input_ids is not None:
                # Teacher forcing: decoder inputs are the right-shifted targets.
                decoder_input_ids = self._shift_right(input_ids)
            elif labels is not None:
                # Standard T5 behavior: derive decoder inputs from the labels.
                decoder_input_ids = self._shift_right(labels)
            else:
                # No targets at all: start from the decoder start token.
                start_token_id = self._get_decoder_start_token_id()
                decoder_input_ids = torch.full(
                    (pixel_values.shape[0] // 5, 1),
                    start_token_id,
                    dtype=torch.long,
                    device=pixel_values.device,
                )

        decoder_outputs = self.seq2seq_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        )

        sequence_output = decoder_outputs[0]
        lm_logits = self.seq2seq_lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Prepare inputs for generation."""
        # With a cache, only the most recent token needs to be fed.
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ):
        """Encode the pixel values once so generation can reuse the encoder outputs."""
        if "encoder_outputs" not in model_kwargs:
            encoder_outputs = self._encode_vision(inputs_tensor)
            model_kwargs["encoder_outputs"] = (encoder_outputs,)

        return model_kwargs

    def generate(
        self,
        pixel_values: torch.Tensor,
        max_length: int = 1024,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate text from image inputs.

        Args:
            pixel_values: Input images [B*5, 3, 512, 512]
            max_length: Maximum number of generated tokens
            num_beams: Number of beams for beam search (1 = greedy).
                TODO: beam search is not implemented; values > 1 are ignored.
            temperature: Sampling temperature
            do_sample: Whether to sample instead of taking the argmax

        Returns:
            Generated token sequences, including the decoder start token
        """
        encoder_hidden_states = self._encode_vision(pixel_values)
        batch_size = pixel_values.shape[0] // 5

        # Every sequence starts from the decoder start token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self._get_decoder_start_token_id(),
            dtype=torch.long,
            device=pixel_values.device,
        )

        past_key_values = None

        for step in range(max_length):
            with torch.no_grad():
                # With a cache, only the newest token is fed to the decoder.
                decoder_outputs = self.seq2seq_decoder(
                    input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
                    encoder_hidden_states=encoder_hidden_states,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

            past_key_values = decoder_outputs.past_key_values

            # Only the last position is needed for next-token prediction.
            hidden_states = decoder_outputs[0][:, -1:, :]
            lm_logits = self.seq2seq_lm_head(hidden_states)

            if temperature != 1.0:
                lm_logits = lm_logits / temperature

            if do_sample:
                probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)

            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

            # Note: stops only when every sequence in the batch emits EOS at
            # the same step; finished sequences keep generating until then.
            if (next_token == self.config.eos_token_id).all():
                break

        return decoder_input_ids

    def _get_decoder_start_token_id(self) -> int:
        """Get the decoder start token ID, falling back to the pad token."""
        # Explicit None check: `or` would wrongly skip a valid ID of 0
        # (T5-style models use token 0 as the decoder start).
        if self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        return self.config.pad_token_id
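

# --- Smoke-test sketch (illustrative only) ----------------------------------
# Assumes LiteVit5Config is constructible with defaults and carries the usual
# token IDs (decoder_start_token_id / pad_token_id / eos_token_id), and that a
# CUDA device is available for the fp16 weights. Constructing the model
# downloads the SigLIP and CodeT5 checkpoints.
if __name__ == "__main__":
    config = LiteVit5Config()
    model = LiteVit5ForConditionalGeneration(config).cuda()

    # One sample = five 512x512 views (4 quarter crops + the full image),
    # built here with the illustrative _make_five_views helper above.
    image = torch.rand(3, 1024, 1024)
    pixel_values = model._make_five_views(image).half().cuda()

    tokens = model.generate(pixel_values, max_length=32)
    print(tokens.shape)  # [1, <=33]: start token plus generated tokens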