# Snapshot pasted from repository view — author: AsherKnight, commit c2c5906.
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
class ImageCaptioning:
    """Generate image captions with BLIP and GIT, then pick the best via CLIP.

    Two captioning models (BLIP and GIT) each propose a caption; CLIP then
    scores every candidate against the image, and the highest-scoring caption
    wins.  Each captioner also reports a sequence log-probability (negative
    language-modeling loss) as a secondary confidence signal.
    """

    def __init__(self):
        """Load BLIP, GIT and CLIP onto the best available device.

        NOTE: downloads pretrained weights from the Hugging Face hub on first
        use — requires network access.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load BLIP captioner.
        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ).to(self.device)
        # Load GIT captioner.
        self.git_processor = AutoProcessor.from_pretrained("microsoft/git-base")
        self.git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(self.device)
        # Load CLIP for ranking candidate captions against the image.
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        # Inference only: switch off dropout/batch-norm training behavior so
        # generation and scoring are deterministic.
        self.blip_model.eval()
        self.git_model.eval()
        self.clip_model.eval()

    def generate_caption_blip(self, image):
        """Caption *image* with BLIP.

        Args:
            image: a PIL image (or anything the BLIP processor accepts).

        Returns:
            (caption, logprob) — the decoded caption string and its
            sequence-level log-probability (higher is better).
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.blip_model.generate(**inputs)
        caption = self.blip_processor.decode(output[0], skip_special_tokens=True)
        return caption, self.compute_logprob(self.blip_model, inputs, output, self.blip_processor)

    def generate_caption_git(self, image):
        """Caption *image* with GIT.

        Args:
            image: a PIL image (or anything the GIT processor accepts).

        Returns:
            (caption, logprob) — the decoded caption string and its
            sequence-level log-probability (higher is better).
        """
        inputs = self.git_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated_ids = self.git_model.generate(**inputs)
        caption = self.git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return caption, self.compute_logprob(self.git_model, inputs, generated_ids, self.git_processor)

    def generate_caption_clip(self, image):
        """Generate candidate captions and rank them with CLIP.

        Returns:
            (best_caption, best_score) — the caption text whose CLIP
            image-text similarity is highest, and its softmax-normalized
            score across the candidate set.
        """
        # Step 1: generate caption candidates — each is a (text, logprob) pair.
        caption_blip = self.generate_caption_blip(image)
        caption_git = self.generate_caption_git(image)
        candidates = [caption_blip, caption_git]
        # CLIP scores only the text portion of each candidate.
        captions_only = [c[0] for c in candidates]
        # Step 2: score all candidates against the image in one CLIP pass.
        inputs = self.clip_processor(
            text=captions_only, images=image, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        scores = outputs.logits_per_image[0]  # shape: (num_captions,)
        scores = scores.softmax(dim=0)  # normalize to a distribution over candidates
        best_idx = scores.argmax().item()
        best_caption = candidates[best_idx]
        best_score = scores[best_idx].item()
        return best_caption[0], best_score  # caption text and its CLIP score

    def compute_logprob(self, model, inputs, generated_ids, processor):
        """Score a generated caption under *model* as -loss (higher is better).

        Re-tokenizes the decoded caption and feeds it back through the model
        with itself as labels, so the returned value is the negative mean
        language-modeling loss of the caption given the image.

        Args:
            model: the captioning model that produced *generated_ids*.
            inputs: the processor output for the image (e.g. pixel_values).
            generated_ids: token ids produced by ``model.generate``.
            processor: the matching processor, used to decode/re-tokenize.

        Returns:
            float: negative loss; higher means the model finds its own
            caption more likely.
        """
        # Decode the generated tokens back to text.
        caption_text = processor.decode(generated_ids[0], skip_special_tokens=True)
        # Re-tokenize the caption to obtain input_ids usable as labels.
        text_inputs = processor(text=caption_text, return_tensors="pt").to(self.device)
        labels = text_inputs["input_ids"]
        # Combine the image features with the caption tokens for the loss pass.
        model_inputs = {**inputs, "input_ids": text_inputs["input_ids"]}
        with torch.no_grad():
            outputs = model(**model_inputs, labels=labels)
        return -outputs.loss.item()  # negate so that higher is better

    def get_best_caption(self, image):
        """Convenience wrapper: run BLIP + GIT, let CLIP pick the winner."""
        caption, score = self.generate_caption_clip(image)
        return caption, score