# NOTE: exported from a Hugging Face Space whose status page showed "Runtime error".
from PIL import Image
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)
class ImageCaptioning:
    """Ensemble image captioner.

    Generates caption candidates with two captioning models (BLIP and GIT),
    then ranks the candidates by CLIP image-text similarity and returns the
    best-matching caption.
    """

    def __init__(self):
        """Load BLIP, GIT, and CLIP models onto the available device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load BLIP captioner.
        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(self.device)
        # Load GIT captioner.
        self.git_processor = AutoProcessor.from_pretrained("microsoft/git-base")
        self.git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(self.device)
        # Load CLIP for ranking the candidate captions.
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        # Inference only: switch to eval mode so dropout/batch-norm layers are
        # disabled.  Without this, compute_logprob() scores are nondeterministic.
        self.blip_model.eval()
        self.git_model.eval()
        self.clip_model.eval()

    def generate_caption_blip(self, image):
        """Caption `image` with BLIP.

        Args:
            image: a PIL image (or anything the BLIP processor accepts).

        Returns:
            (caption_text, logprob_score) — see compute_logprob for the score.
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.blip_model.generate(**inputs)
        caption = self.blip_processor.decode(output[0], skip_special_tokens=True)
        return caption, self.compute_logprob(self.blip_model, inputs, output, self.blip_processor)

    def generate_caption_git(self, image):
        """Caption `image` with GIT.

        Args:
            image: a PIL image (or anything the GIT processor accepts).

        Returns:
            (caption_text, logprob_score) — see compute_logprob for the score.
        """
        inputs = self.git_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated_ids = self.git_model.generate(**inputs)
        caption = self.git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return caption, self.compute_logprob(self.git_model, inputs, generated_ids, self.git_processor)

    def generate_caption_clip(self, image):
        """Generate candidates with BLIP and GIT, then pick the best via CLIP.

        Returns:
            (best_caption_text, clip_softmax_score) where the score is the
            softmax-normalized CLIP image-text logit of the winning caption.
        """
        # Step 1: generate caption candidates; each is (text, logprob).
        caption_blip = self.generate_caption_blip(image)
        caption_git = self.generate_caption_git(image)
        candidates = [caption_blip, caption_git]
        # Extract text only for CLIP scoring.
        captions_only = [c[0] for c in candidates]
        # Step 2: score all candidates against the image with CLIP.
        inputs = self.clip_processor(text=captions_only, images=image, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        scores = outputs.logits_per_image[0]  # shape: (num_captions,)
        # Softmax is monotonic per-row, so argmax is unchanged; it only makes
        # the returned score a probability-like value in [0, 1].
        scores = scores.softmax(dim=0)
        best_idx = scores.argmax().item()
        best_caption = candidates[best_idx]
        best_score = scores[best_idx].item()
        return best_caption[0], best_score  # caption text and its CLIP score

    def compute_logprob(self, model, inputs, generated_ids, processor):
        """Score a generated caption by its (negated) language-model loss.

        Args:
            model: the captioning model that produced `generated_ids`.
            inputs: the image-side processor outputs used for generation
                (e.g. pixel_values), already on `self.device`.
            generated_ids: token ids returned by `model.generate`.
            processor: the matching processor, used to decode/re-encode.

        Returns:
            float: negative mean cross-entropy loss of the caption tokens;
            higher (closer to 0) means the model finds the caption more likely.
        """
        # Decode then re-tokenize rather than feeding generated_ids back
        # directly, so special generation tokens are normalized by the
        # tokenizer's own encoding path.
        caption_text = processor.decode(generated_ids[0], skip_special_tokens=True)
        text_inputs = processor(text=caption_text, return_tensors="pt").to(self.device)
        labels = text_inputs["input_ids"]
        # Merge the image inputs with ALL text-side tensors.  The original
        # code copied only input_ids, dropping the tokenizer's attention_mask;
        # keeping it ensures any padding tokens are masked out of the loss.
        model_inputs = {**inputs, **text_inputs}
        with torch.no_grad():
            outputs = model(**model_inputs, labels=labels)
        return -outputs.loss.item()  # higher is better

    def get_best_caption(self, image):
        """Convenience wrapper: run BLIP + GIT and return the CLIP-best caption.

        Returns:
            (caption_text, clip_score) from generate_caption_clip.
        """
        caption, score = self.generate_caption_clip(image)
        return caption, score