# Captioning/src/models/model.py
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import (
    GPT2LMHeadModel,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    BlipProcessor,
    BlipForConditionalGeneration
)
# -----------------------------------------------------------------------------
# 1. Custom ResNet + GPT-2 (Training from Scratch)
# -----------------------------------------------------------------------------
class ResNetEncoder(nn.Module):
    """Frozen ResNet-50 backbone that projects pooled features to the GPT-2 embedding size."""

    def __init__(self, embed_dim=768):
        super(ResNetEncoder, self).__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the final classification layer; keep everything up to global average pooling.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # Freeze the backbone so only the projection head is trained.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.projection = nn.Linear(2048, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)                   # (B, 2048, 1, 1)
        features = features.view(features.size(0), -1)   # (B, 2048)
        features = self.projection(features)             # (B, embed_dim)
        features = self.bn(features)
        return features

class ResNetGPT2(nn.Module):
    """Custom captioner: a ResNet-50 image embedding prepended to GPT-2 token embeddings."""

    def __init__(self, max_seq_len=40):
        super(ResNetGPT2, self).__init__()
        self.encoder = ResNetEncoder(embed_dim=768)
        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
        self.max_seq_len = max_seq_len

    def forward(self, images, input_ids, attention_mask):
        image_embeds = self.encoder(images)                   # (B, 768)
        token_embeds = self.gpt2.transformer.wte(input_ids)   # (B, T, 768)
        # Prepend the image embedding as the first "token" of the sequence.
        inputs_embeds = torch.cat((image_embeds.unsqueeze(1), token_embeds), dim=1)
        batch_size = images.shape[0]
        # Extend the attention mask to cover the image position.
        ones = torch.ones(batch_size, 1).to(images.device)
        attention_mask = torch.cat((ones, attention_mask), dim=1)
        outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs.logits                                 # (B, 1 + T, vocab)

    def generate_caption(self, image, tokenizer, max_length=20, temperature=1.0):
        """Greedy autoregressive decoding from a single image tensor of shape (C, H, W)."""
        self.eval()
        with torch.no_grad():
            image_embed = self.encoder(image.unsqueeze(0))    # (1, 768)
            inputs_embeds = image_embed.unsqueeze(1)          # (1, 1, 768)
            generated_tokens = []
            for _ in range(max_length):
                outputs = self.gpt2(inputs_embeds=inputs_embeds)
                # Temperature scaling is kept for API symmetry; it does not change the argmax.
                logits = outputs.logits[:, -1, :] / temperature
                next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
                if next_token.item() == tokenizer.eos_token_id:
                    break
                generated_tokens.append(next_token.item())
                # Feed the embedding of the predicted token back in for the next step.
                next_token_embed = self.gpt2.transformer.wte(next_token)
                inputs_embeds = torch.cat((inputs_embeds, next_token_embed), dim=1)
        return tokenizer.decode(generated_tokens, skip_special_tokens=True)

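# Usage sketch (illustrative, not part of the deployed pipeline): how ResNetGPT2
# is expected to be driven during training. The GPT-2 tokenizer setup and the
# dummy batch below are assumptions; the real dataloader lives elsewhere.
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   model = ResNetGPT2()
#   images = torch.randn(4, 3, 224, 224)
#   enc = tokenizer(["a dog running on grass"] * 4, padding="max_length",
#                   truncation=True, max_length=model.max_seq_len,
#                   return_tensors="pt")
#   logits = model(images, enc.input_ids, enc.attention_mask)
#   # logits: (4, 1 + max_seq_len, vocab). Position 0 corresponds to the image
#   # embedding, so shift by one position before computing the caption loss.
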
# -----------------------------------------------------------------------------
# 2. ViT + GPT-2 (Pre-trained SOTA 1)
# -----------------------------------------------------------------------------
class ViTGPT2Captioner(nn.Module):
    """Pretrained ViT encoder + GPT-2 decoder (nlpconnect/vit-gpt2-image-captioning)."""

    def __init__(self):
        super().__init__()
        print("Loading ViT-GPT2 model...")
        self.model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

    def generate_caption(self, image, **kwargs):
        """Caption a single PIL image using beam search."""
        self.eval()
        with torch.no_grad():
            pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.model.device)
            output_ids = self.model.generate(pixel_values, max_length=20, num_beams=4)
            preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

# -----------------------------------------------------------------------------
# 3. BLIP (Pre-trained SOTA 2 - Best)
# -----------------------------------------------------------------------------
class BLIPCaptioner(nn.Module):
    """Pretrained BLIP captioner (Salesforce/blip-image-captioning-large)."""

    def __init__(self):
        super().__init__()
        print("Loading BLIP model (Salesforce/blip-image-captioning-large)...")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

    def generate_caption(self, image, **kwargs):
        """Caption a single PIL image using beam search with a repetition penalty."""
        self.eval()
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)
            output_ids = self.model.generate(**inputs, max_length=50, num_beams=5,
                                             repetition_penalty=1.2, min_length=5)
            caption = self.processor.decode(output_ids[0], skip_special_tokens=True)
        return caption

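# Usage sketch (assumption): both pretrained captioners take a PIL image directly
# and share the generate_caption interface; "example.jpg" is a placeholder path.
#
#   from PIL import Image
#   captioner = BLIPCaptioner()
#   caption = captioner.generate_caption(Image.open("example.jpg").convert("RGB"))
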
# -----------------------------------------------------------------------------
# Factory
# -----------------------------------------------------------------------------
def get_model(config):
    """Instantiate a captioning model based on config.MODEL_TYPE."""
    if config.MODEL_TYPE == "resnet_gpt2":
        return ResNetGPT2()
    elif config.MODEL_TYPE == "vit_gpt2":
        return ViTGPT2Captioner()
    elif config.MODEL_TYPE == "blip":
        return BLIPCaptioner()
    else:
        raise ValueError(f"Unknown model type: {config.MODEL_TYPE}")