from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from .config import LiteVit5Config

class LiteVit5ForConditionalGeneration(PreTrainedModel):
"""
LiteVit5 model for vision-to-text generation tasks.
Combines SigLIP vision encoder with T5 seq2seq decoder for image-to-text tasks.
"""
config_class = LiteVit5Config
base_model_prefix = "litevit5"
    def __init__(self, config):
        super().__init__(config)

        # Vision encoder (frozen): SigLIP2 base with 16px patches at 512px
        # resolution yields a 32x32 = 1024-token grid of width 768
        self.vision_model = SiglipVisionModel.from_pretrained(
            "google/siglip2-base-patch16-512",
            dtype=torch.float16,
        )
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Load the decoder and lm_head from CodeT5; its encoder is discarded
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5-base",
            dtype=torch.float16,
        )
        self.seq2seq_decoder = seq2seq_model.decoder
        self.seq2seq_lm_head = seq2seq_model.lm_head
        # Bound method reused as-is: prepends decoder_start_token_id and drops
        # the last token, exactly as T5 builds decoder inputs from labels
        self._shift_right = seq2seq_model._shift_right
        # T5 scales decoder states by d_model ** -0.5 before a tied lm_head;
        # record what is needed to replicate that in forward() and generate()
        self.model_dim = seq2seq_model.config.d_model
        self.tie_word_embeddings = seq2seq_model.config.tie_word_embeddings

        # Vision processing layers: a 2x2 strided conv halves the stitched
        # 64x64 quarter-view grid back to 32x32 before fusion
        self.downsampler = nn.Conv2d(768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16)
        self.fuse = nn.Linear(768 * 2, 768).half()
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True)
        self.linear_projection = nn.Linear(768, 768).half()

        self.post_init()
def get_encoder(self):
"""Return the vision encoder for the model."""
return self.vision_model
def get_decoder(self):
"""Return the seq2seq decoder."""
return self.seq2seq_decoder
def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Encode image inputs into vision features.
Args:
pixel_values: Input images of shape [B*5, 3, 512, 512] (5 views per sample)
Returns:
Encoded vision features of shape [B, 1024, 768]
"""
        # Ensure pixel_values are float16 to match the frozen vision tower
        pixel_values = pixel_values.half()
        scale = 5  # number of views per sample (4 quarter views + 1 full view)
        if pixel_values.size(0) % scale != 0:
            raise ValueError(f"Expected a multiple of {scale} views, got {pixel_values.size(0)} images")
        batch_size = pixel_values.size(0) // scale
        num_patches = 32  # 512px / 16px patches per side
# Get vision embeddings
with torch.no_grad():
vision_model_outputs = self.vision_model(pixel_values=pixel_values)
vision_hidden_states = vision_model_outputs.last_hidden_state # [B*5, 1024, 768]
# Reshape to separate views
vision_hidden_states = vision_hidden_states.view(batch_size, scale, *vision_hidden_states.shape[1:]) # [B, 5, 1024, 768]
        # Process quarter views; assumed order: top-left, top-right,
        # bottom-left, bottom-right, each a 32x32 patch grid
        quarters = vision_hidden_states[:, :4]  # [B, 4, 1024, 768]
        quarters = quarters.view(batch_size, 4, num_patches, num_patches, -1)  # [B, 4, 32, 32, 768]

        # Stitch the four quarter grids into one 64x64 grid covering the full image
        upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)  # [B, 32, 64, 768]
        lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)  # [B, 32, 64, 768]
        pooled_image = torch.cat([upper, lower], dim=1)  # [B, 64, 64, 768]
        pooled_image = pooled_image.permute(0, 3, 1, 2)  # [B, 768, 64, 64]
# Downsample
pooled32 = self.downsampler(pooled_image) # [B, 768, 32, 32]
pooled_tok = pooled32.flatten(2).transpose(1, 2) # [B, 1024, 768]
# Full image features
full_image = vision_hidden_states[:, 4] # [B, 1024, 768]
# Fuse quarter and full views
concat = torch.cat([pooled_tok, full_image], dim=-1) # [B, 1024, 1536]
fused = self.fuse(concat) # [B, 1024, 768]
# Add positional encoding and project
fused = fused + self.pos_embedding
vision_hidden_states = self.linear_projection(fused) # [B, 1024, 768]
return vision_hidden_states
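
    # NOTE: illustrative sketch, not part of the training pipeline. The name
    # `make_five_views` and the 1024x1024 source size are assumptions; it shows
    # one way to produce the [B*5, 3, 512, 512] layout `_encode_vision` expects,
    # with quarters ordered top-left, top-right, bottom-left, bottom-right to
    # match the stitching above.
    @staticmethod
    def make_five_views(images: torch.Tensor) -> torch.Tensor:
        """Split [B, 3, H, W] images into 4 quarter views plus 1 full view."""
        b, c, h, w = images.shape
        views = [
            images[:, :, : h // 2, : w // 2],  # top-left
            images[:, :, : h // 2, w // 2 :],  # top-right
            images[:, :, h // 2 :, : w // 2],  # bottom-left
            images[:, :, h // 2 :, w // 2 :],  # bottom-right
            images,                            # full view
        ]
        # Resize every view to SigLIP's 512x512 input resolution
        views = [
            torch.nn.functional.interpolate(
                v.float(), size=(512, 512), mode="bilinear", align_corners=False
            )
            for v in views
        ]
        # [B, 5, 3, 512, 512] -> [B*5, 3, 512, 512]: 5 consecutive views per
        # sample, matching the view(batch_size, scale, ...) reshape above
        return torch.stack(views, dim=1).reshape(b * 5, c, 512, 512).half()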
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """
        Forward pass for the model.

        Args:
            pixel_values: Vision input images, [B*5, 3, 512, 512]
            input_ids: Decoder input token IDs (shifted right before use)
            labels: Target token IDs for training; used to build decoder inputs
                when neither decoder_input_ids nor input_ids is given
            decoder_input_ids: Decoder input IDs (used during generation)
            encoder_outputs: Precomputed vision features; skips re-encoding
            past_key_values: Cached key/values for efficient generation
            attention_mask: Attention mask for decoder inputs
        Returns:
            Seq2SeqLMOutput with loss, logits, and generation-related outputs
        """
        # Encode images, unless precomputed features were passed in
        if encoder_outputs is not None:
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_hidden_states = self._encode_vision(pixel_values)

        # Prepare decoder input IDs: shift right for teacher forcing, or fall
        # back to a single decoder_start_token_id per sample
        if decoder_input_ids is None:
            if input_ids is not None:
                decoder_input_ids = self._shift_right(input_ids)
            elif labels is not None:
                decoder_input_ids = self._shift_right(labels)
            else:
                decoder_input_ids = torch.full(
                    (encoder_hidden_states.shape[0], 1),
                    self._get_decoder_start_token_id(),
                    dtype=torch.long,
                    device=encoder_hidden_states.device,
                )
        # Pass through decoder
        decoder_outputs = self.seq2seq_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        )

        sequence_output = decoder_outputs[0]
        # Replicate T5's rescaling of decoder states before a tied lm_head
        if self.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.seq2seq_lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            # Compute the loss in float32 for numerical stability under fp16
            loss = loss_fct(lm_logits.float().view(-1, lm_logits.size(-1)), labels.view(-1))
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
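
    # Usage sketch (hypothetical names): a training step with teacher forcing.
    # `views` is [B*5, 3, 512, 512]; `labels` holds -100 at padded positions so
    # the loss ignores them:
    #
    #   out = model(pixel_values=views, labels=labels)
    #   out.loss.backward()
    #   optimizer.step()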
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
"""Prepare inputs for generation."""
# Cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is already defined
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache,
}
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
):
"""Encode pixel values to get encoder outputs."""
# Encode images if not already done
if "encoder_outputs" not in model_kwargs:
encoder_outputs = self._encode_vision(inputs_tensor)
model_kwargs["encoder_outputs"] = (encoder_outputs,)
return model_kwargs
def generate(
self,
pixel_values: torch.Tensor,
max_length: int = 1024,
num_beams: int = 1,
temperature: float = 1.0,
do_sample: bool = False,
**kwargs
) -> torch.LongTensor:
"""
Generate text from image inputs.
Args:
pixel_values: Input images [B*5, 3, 512, 512]
max_length: Maximum generation length
num_beams: Number of beams for beam search (1 = greedy) TODO: Not implemented
temperature: Sampling temperature
do_sample: Whether to use sampling
Returns:
Generated token sequences
"""
# Encode vision inputs
encoder_hidden_states = self._encode_vision(pixel_values)
batch_size = pixel_values.shape[0] // 5
# Start with decoder_start_token_id
decoder_input_ids = torch.full(
(batch_size, 1),
self._get_decoder_start_token_id(),
dtype=torch.long,
device=pixel_values.device
)
generated_tokens = []
past_key_values = None
for step in range(max_length):
with torch.no_grad():
# Get decoder outputs
decoder_outputs = self.seq2seq_decoder(
input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = decoder_outputs.past_key_values
# Get logits and generate next token
hidden_states = decoder_outputs[0][:, -1:, :]
lm_logits = self.seq2seq_lm_head(hidden_states)
# Apply temperature
if temperature != 1.0:
lm_logits = lm_logits / temperature
# Get next token
if do_sample:
probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)
# Append to generated tokens
generated_tokens.append(next_token)
decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
# Check for EOS
if (next_token == self.config.eos_token_id).all():
break
return decoder_input_ids
    def _get_decoder_start_token_id(self) -> int:
        """Get the decoder start token ID, falling back to the pad token (T5 convention)."""
        # Explicit None check: `or` would wrongly skip a valid ID of 0, and
        # T5's decoder_start_token_id is exactly 0
        if self.config.decoder_start_token_id is not None:
            return self.config.decoder_start_token_id
        return self.config.pad_token_id
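

# Minimal usage sketch (assumption: LiteVit5Config provides CodeT5-compatible
# pad/eos/decoder-start token IDs by default). Pretrained SigLIP2 and CodeT5
# weights are downloaded on first run; float16 inference assumes a GPU.
if __name__ == "__main__":
    model = LiteVit5ForConditionalGeneration(LiteVit5Config())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # One sample = 5 views (4 quarter crops + 1 full view), each 512x512
    views = torch.randn(5, 3, 512, 512, dtype=torch.float16, device=device)
    tokens = model.generate(views, max_length=32)
    print(tokens.shape)  # [1, <=33], including the start token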