# modeling_litevit5.py — LiteVit5 model definition (litvit5 repo, commit 244d6df)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from .config import LiteVit5Config
class LiteVit5ForConditionalGeneration(PreTrainedModel):
    """
    LiteVit5 model for vision-to-text generation tasks.

    Combines a frozen SigLIP vision encoder with a CodeT5 seq2seq decoder for
    image-to-text generation. Each sample is encoded from 5 views (4 quarter
    crops + 1 full view) stacked along the batch axis, so ``pixel_values``
    always has shape [B*5, 3, 512, 512].
    """
    config_class = LiteVit5Config
    base_model_prefix = "litevit5"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Vision encoder, frozen and kept in fp16.
        # NOTE(review): weights are always pulled from the hub here rather
        # than from `config` — confirm this is intended when reloading
        # fine-tuned checkpoints.
        self.vision_model = SiglipVisionModel.from_pretrained(
            "google/siglip2-base-patch16-512",
            dtype=torch.float16
        )
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False
        # Reuse CodeT5's decoder, LM head, and label-shifting helper.
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5-base",
            dtype=torch.float16
        )
        self.seq2seq_decoder = seq2seq_model.decoder
        self.seq2seq_lm_head = seq2seq_model.lm_head
        self._shift_right = seq2seq_model._shift_right
        # Vision-feature processing layers (all fp16 to match the encoder).
        # 2x2 strided conv halves the stitched 64x64 patch grid back to 32x32.
        self.downsampler = nn.Conv2d(768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16)
        # Fuses concatenated [pooled-quarters | full-view] token features.
        self.fuse = nn.Linear(768 * 2, 768).half()
        # Learned positional embedding for the 1024 fused vision tokens.
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True)
        self.linear_projection = nn.Linear(768, 768).half()
        self.post_init()

    def get_encoder(self):
        """Return the vision encoder for the model."""
        return self.vision_model

    def get_decoder(self):
        """Return the seq2seq decoder."""
        return self.seq2seq_decoder

    def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode image inputs into vision features for decoder cross-attention.

        Args:
            pixel_values: Input images of shape [B*5, 3, 512, 512]
                (per sample: 4 quarter views followed by 1 full view).

        Returns:
            Encoded vision features of shape [B, 1024, 768].
        """
        # The frozen vision tower runs in fp16.
        pixel_values = pixel_values.half()
        batch_size = pixel_values.size(0) // 5
        scale = 5  # Number of views (4 quarter views + 1 full view)
        num_patches = 32  # 512px image / 16px patches => 32x32 patch grid
        # Vision tower is frozen, so no gradients are needed here.
        with torch.no_grad():
            vision_model_outputs = self.vision_model(pixel_values=pixel_values)
            vision_hidden_states = vision_model_outputs.last_hidden_state  # [B*5, 1024, 768]
        # Separate the 5 views per sample.
        vision_hidden_states = vision_hidden_states.view(batch_size, scale, *vision_hidden_states.shape[1:])  # [B, 5, 1024, 768]
        # Re-assemble the 4 quarter views into one 64x64 patch grid
        # (view order assumed: upper-left, upper-right, lower-left,
        # lower-right — TODO confirm against the data pipeline).
        quarters = vision_hidden_states[:, :4]  # [B, 4, 1024, 768]
        quarters = quarters.view(batch_size, 4, num_patches, num_patches, -1)  # [B, 4, 32, 32, 768]
        upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)  # [B, 32, 64, 768]
        lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)  # [B, 32, 64, 768]
        pooled_image = torch.cat([upper, lower], dim=1)  # [B, 64, 64, 768]
        pooled_image = pooled_image.permute(0, 3, 1, 2)  # [B, 768, 64, 64]
        # Downsample so the token count matches the full view.
        pooled32 = self.downsampler(pooled_image)  # [B, 768, 32, 32]
        pooled_tok = pooled32.flatten(2).transpose(1, 2)  # [B, 1024, 768]
        # Full-image view.
        full_image = vision_hidden_states[:, 4]  # [B, 1024, 768]
        # Fuse quarter-derived and full-view tokens.
        concat = torch.cat([pooled_tok, full_image], dim=-1)  # [B, 1024, 1536]
        fused = self.fuse(concat)  # [B, 1024, 768]
        # Add learned positional encoding and project into decoder space.
        fused = fused + self.pos_embedding
        vision_hidden_states = self.linear_projection(fused)  # [B, 1024, 768]
        return vision_hidden_states

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Seq2SeqLMOutput:
        """
        Forward pass for the model.

        Args:
            pixel_values: Vision input images, shape [B*5, 3, 512, 512].
            input_ids: Target token IDs; shifted right to build decoder inputs.
            labels: Target token IDs for loss computation. Also used to derive
                decoder inputs (T5-style) when neither ``input_ids`` nor
                ``decoder_input_ids`` is given.
            decoder_input_ids: Explicit decoder input IDs (used in generation).
            past_key_values: Cached key/values for efficient generation.
            attention_mask: Attention mask for decoder inputs.

        Returns:
            Seq2SeqLMOutput with loss, logits, and generation-related outputs.
        """
        # Encode images into cross-attention memory.
        encoder_hidden_states = self._encode_vision(pixel_values)
        # Derive decoder inputs when not given explicitly.
        if decoder_input_ids is None:
            if input_ids is not None:
                # Teacher forcing: prepend start token / shift targets right.
                decoder_input_ids = self._shift_right(input_ids)
            elif labels is not None:
                # Standard T5 training path when only labels are supplied.
                # Previously this fell through to a single start token and the
                # loss computation failed on a shape mismatch.
                decoder_input_ids = self._shift_right(labels)
            else:
                # Generation start: one start token per sample
                # (pixel_values carries 5 views per sample).
                decoder_input_ids = torch.full(
                    (pixel_values.shape[0] // 5, 1),
                    self._get_decoder_start_token_id(),
                    dtype=torch.long,
                    device=pixel_values.device
                )
        # Pass through decoder with vision features as cross-attention memory.
        decoder_outputs = self.seq2seq_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        )
        sequence_output = decoder_outputs[0]
        lm_logits = self.seq2seq_lm_head(sequence_output)
        loss = None
        if labels is not None:
            # Token-level cross-entropy; -100 labels are ignored (padding).
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Prepare inputs for HuggingFace's generation loop."""
        # With a cache, only the newest token needs to be fed to the decoder.
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        return {
            "input_ids": None,  # encoder_outputs is already defined
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ):
        """Encode pixel values once to get encoder outputs for generation."""
        if "encoder_outputs" not in model_kwargs:
            encoder_outputs = self._encode_vision(inputs_tensor)
            # NOTE(review): forward() ignores `encoder_outputs` and re-encodes
            # from pixel_values, so this cache is only effective if forward is
            # extended to consume it — verify before relying on HF generate().
            model_kwargs["encoder_outputs"] = (encoder_outputs,)
        return model_kwargs

    def generate(
        self,
        pixel_values: torch.Tensor,
        max_length: int = 1024,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs
    ) -> torch.LongTensor:
        """
        Generate text from image inputs with greedy or sampled decoding.

        Args:
            pixel_values: Input images [B*5, 3, 512, 512].
            max_length: Maximum number of tokens to generate.
            num_beams: Number of beams for beam search (1 = greedy).
                TODO: beam search is not implemented; this argument is ignored.
            temperature: Sampling temperature (applied when != 1.0).
            do_sample: Sample from the softmax distribution instead of argmax.

        Returns:
            Generated token sequences (including the start token),
            shape [B, <= max_length + 1].
        """
        # Encode vision inputs once; reused for every decoding step.
        encoder_hidden_states = self._encode_vision(pixel_values)
        batch_size = pixel_values.shape[0] // 5
        # Every sequence starts with the decoder start token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self._get_decoder_start_token_id(),
            dtype=torch.long,
            device=pixel_values.device
        )
        past_key_values = None
        for step in range(max_length):
            with torch.no_grad():
                # With a KV cache, only the newest token is fed each step.
                decoder_outputs = self.seq2seq_decoder(
                    input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
                    encoder_hidden_states=encoder_hidden_states,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = decoder_outputs.past_key_values
                # Project the last position to vocabulary logits.
                hidden_states = decoder_outputs[0][:, -1:, :]
                lm_logits = self.seq2seq_lm_head(hidden_states)
                if temperature != 1.0:
                    lm_logits = lm_logits / temperature
                # Pick the next token (sampled or greedy).
                if do_sample:
                    probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)
                decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
                # Stop once every sequence in the batch has emitted EOS.
                if (next_token == self.config.eos_token_id).all():
                    break
        return decoder_input_ids

    def _get_decoder_start_token_id(self) -> int:
        """
        Return the decoder start token ID, falling back to the pad token.

        Uses an explicit ``is not None`` check: T5-family models use
        ``decoder_start_token_id == 0``, which the previous ``or`` fallback
        treated as missing and silently replaced with ``pad_token_id``.
        """
        if self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        return self.config.pad_token_id