from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from .config import LiteVit5Config


class LiteVit5ForConditionalGeneration(PreTrainedModel):
    """
    LiteVit5 model for vision-to-text generation tasks.

    Combines a SigLIP vision encoder with a T5 seq2seq decoder for
    image-to-text tasks.
    """

    config_class = LiteVit5Config
    base_model_prefix = "litevit5"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Vision model (frozen)
        self.vision_model = SiglipVisionModel.from_pretrained(
            "google/siglip2-base-patch16-512", dtype=torch.float16
        )
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Load seq2seq decoder and lm_head from CodeT5
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5-base", dtype=torch.float16
        )
        self.seq2seq_decoder = seq2seq_model.decoder
        self.seq2seq_lm_head = seq2seq_model.lm_head
        self._shift_right = seq2seq_model._shift_right

        # Vision processing layers
        self.downsampler = nn.Conv2d(
            768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16
        )
        self.fuse = nn.Linear(768 * 2, 768).half()
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True
        )
        self.linear_projection = nn.Linear(768, 768).half()

        self.post_init()

    def get_encoder(self):
        """Return the vision encoder for the model."""
        return self.vision_model

    def get_decoder(self):
        """Return the seq2seq decoder."""
        return self.seq2seq_decoder

    def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode image inputs into vision features.

        Args:
            pixel_values: Input images of shape [B*5, 3, 512, 512]
                (5 views per sample)

        Returns:
            Encoded vision features of shape [B, 1024, 768]
        """
        # Ensure pixel_values are float16
        pixel_values = pixel_values.half()
        batch_size = pixel_values.size(0) // 5
        scale = 5  # Number of views (4 quarter views + 1 full view)
        patches_per_side = 32  # 512 / 16 patches along each spatial axis

        # Get vision embeddings
        with torch.no_grad():
            vision_model_outputs = self.vision_model(pixel_values=pixel_values)
            vision_hidden_states = vision_model_outputs.last_hidden_state  # [B*5, 1024, 768]

        # Reshape to separate views
        vision_hidden_states = vision_hidden_states.view(
            batch_size, scale, *vision_hidden_states.shape[1:]
        )  # [B, 5, 1024, 768]

        # Process quarter views
        quarters = vision_hidden_states[:, :4]  # [B, 4, 1024, 768]
        quarters = quarters.view(
            batch_size, 4, patches_per_side, patches_per_side, -1
        )  # [B, 4, 32, 32, 768]

        # Combine quarter views into full image
        upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)  # [B, 32, 64, 768]
        lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)  # [B, 32, 64, 768]
        pooled_image = torch.cat([upper, lower], dim=1)  # [B, 64, 64, 768]
        pooled_image = pooled_image.permute(0, 3, 1, 2)  # [B, 768, 64, 64]

        # Downsample
        pooled32 = self.downsampler(pooled_image)  # [B, 768, 32, 32]
        pooled_tok = pooled32.flatten(2).transpose(1, 2)  # [B, 1024, 768]

        # Full image features
        full_image = vision_hidden_states[:, 4]  # [B, 1024, 768]

        # Fuse quarter and full views
        concat = torch.cat([pooled_tok, full_image], dim=-1)  # [B, 1024, 1536]
        fused = self.fuse(concat)  # [B, 1024, 768]

        # Add positional encoding and project
        fused = fused + self.pos_embedding
        vision_hidden_states = self.linear_projection(fused)  # [B, 1024, 768]

        return vision_hidden_states
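
    # View-stitching sketch (assumption: the dataloader emits the five views
    # per sample in this order; nothing in this module enforces it):
    #
    #   index 0: top-left quarter     index 1: top-right quarter
    #   index 2: bottom-left quarter  index 3: bottom-right quarter
    #   index 4: full image, resized to 512x512
    #
    # Each 512x512 view yields a 32x32 patch grid from SigLIP. The four
    # quarter grids are tiled into one 64x64 grid, downsampled back to 32x32
    # (1024 tokens) by the strided conv, and fused token-by-token with the
    # 1024 full-image tokens before positional encoding and projection.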

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """
        Forward pass for the model.

        Args:
            pixel_values: Vision input images
            input_ids: Decoder input token IDs (shifted right internally)
            labels: Target token IDs for training
            decoder_input_ids: Decoder input IDs (used during generation)
            past_key_values: Cached key values for efficient generation
            attention_mask: Attention mask for decoder inputs
            encoder_outputs: Precomputed vision features (used during generation)
            use_cache: Whether the decoder should return cached key values

        Returns:
            Seq2SeqLMOutput with loss, logits, and generation-related outputs
        """
        # Encode images, reusing precomputed features during generation
        if encoder_outputs is not None:
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_hidden_states = self._encode_vision(pixel_values)
        batch_size = encoder_hidden_states.size(0)

        # Prepare decoder input IDs: prefer explicit IDs, then shifted labels
        # (teacher forcing), then shifted input_ids, then a lone start token
        if decoder_input_ids is None:
            if labels is not None:
                decoder_input_ids = self._shift_right(labels)
            elif input_ids is not None:
                decoder_input_ids = self._shift_right(input_ids)
            else:
                decoder_input_ids = torch.full(
                    (batch_size, 1),
                    self._get_decoder_start_token_id(),
                    dtype=torch.long,
                    device=encoder_hidden_states.device,
                )

        # Pass through decoder
        decoder_outputs = self.seq2seq_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
        )

        sequence_output = decoder_outputs[0]
        lm_logits = self.seq2seq_lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Prepare inputs for generation."""
        # With cached key values, only the last token needs to be fed
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs carries the vision features
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _prepare_encoder_decoder_kwargs_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        model_kwargs,
        model_input_name: Optional[str] = None,
    ):
        """Encode pixel values to get encoder outputs."""
        # Encode images if not already done; forward() unpacks the tuple
        if "encoder_outputs" not in model_kwargs:
            encoder_outputs = self._encode_vision(inputs_tensor)
            model_kwargs["encoder_outputs"] = (encoder_outputs,)
        return model_kwargs
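
    # Minimal training-step sketch (hedged): `model`, `dataloader`, and
    # `optimizer` are assumed to exist outside this module, and label padding
    # positions are assumed to be set to -100 so the loss ignores them.
    #
    #     for batch in dataloader:
    #         outputs = model(
    #             pixel_values=batch["pixel_values"],  # [B*5, 3, 512, 512]
    #             labels=batch["labels"],              # [B, T], pads -> -100
    #         )
    #         outputs.loss.backward()
    #         optimizer.step()
    #         optimizer.zero_grad()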

    def generate(
        self,
        pixel_values: torch.Tensor,
        max_length: int = 1024,
        num_beams: int = 1,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate text from image inputs.

        Args:
            pixel_values: Input images [B*5, 3, 512, 512]
            max_length: Maximum generation length
            num_beams: Number of beams for beam search; only num_beams=1
                (greedy or sampling) is implemented
            temperature: Sampling temperature
            do_sample: Whether to use sampling

        Returns:
            Generated token sequences, including the decoder start token
        """
        if num_beams != 1:
            raise NotImplementedError("Beam search is not implemented; use num_beams=1.")

        # Encode vision inputs
        encoder_hidden_states = self._encode_vision(pixel_values)
        batch_size = pixel_values.shape[0] // 5

        # Start with decoder_start_token_id
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self._get_decoder_start_token_id(),
            dtype=torch.long,
            device=pixel_values.device,
        )

        past_key_values = None
        finished = torch.zeros(batch_size, dtype=torch.bool, device=pixel_values.device)

        for _ in range(max_length):
            with torch.no_grad():
                # With cached key values, only the last token is fed
                decoder_outputs = self.seq2seq_decoder(
                    input_ids=(
                        decoder_input_ids
                        if past_key_values is None
                        else decoder_input_ids[:, -1:]
                    ),
                    encoder_hidden_states=encoder_hidden_states,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = decoder_outputs.past_key_values

                # Get logits for the last position
                hidden_states = decoder_outputs[0][:, -1:, :]
                lm_logits = self.seq2seq_lm_head(hidden_states)

                # Apply temperature
                if temperature != 1.0:
                    lm_logits = lm_logits / temperature

                # Get next token
                if do_sample:
                    probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)

                decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

                # Stop once every sequence has produced an EOS at some step,
                # not only when all sequences emit EOS simultaneously
                finished |= next_token.squeeze(1) == self.config.eos_token_id
                if finished.all():
                    break

        return decoder_input_ids

    def _get_decoder_start_token_id(self) -> int:
        """Get the decoder start token ID, falling back to the pad token ID."""
        # Explicit None check: `or` would skip a valid ID of 0, and T5-family
        # models use 0 as the decoder start token
        if self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        return self.config.pad_token_id
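

# Minimal inference sketch (hedged): assumes LiteVit5Config supplies token IDs
# compatible with CodeT5's tokenizer and that the pretrained checkpoints above
# are reachable; random pixel values stand in for a real five-view batch.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = LiteVit5ForConditionalGeneration(LiteVit5Config()).eval()
    # One sample = five 512x512 views (four quarters + the full image)
    pixel_values = torch.randn(5, 3, 512, 512, dtype=torch.float16)
    token_ids = model.generate(pixel_values=pixel_values, max_length=16)
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    print(tokenizer.batch_decode(token_ids, skip_special_tokens=True))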