from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image
from typing import Optional, List, Tuple


class ImageCaptioner:
    """Caption PIL images with a BLIP-2 model loaded from the HuggingFace hub."""

    # Text prefix used when the caller supplies no (or an empty) prompt.
    DEFAULT_PROMPT = "a photo of"

    def __init__(
        self,
        model_name: str = "Salesforce/blip2-opt-2.7b",
        device: Optional[str] = None,
    ):
        """
        Initialize BLIP-2 model for image captioning

        Args:
            model_name: HuggingFace model identifier for BLIP-2
            device: Torch device string to run on. When None, "cuda" is
                used if available, otherwise "cpu".
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Half precision on CUDA devices (including "cuda:N"), full precision
        # elsewhere. Kept on the instance so inputs can be cast to match.
        self.dtype = (
            torch.float16 if str(self.device).startswith("cuda") else torch.float32
        )
        self.processor = Blip2Processor.from_pretrained(model_name)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=self.dtype,
        ).to(self.device)

    def generate_caption(
        self,
        image: Image.Image,
        prompt: Optional[str] = None,
        max_length: int = 50,
        num_beams: int = 5,
        *,
        min_length: int = 10,
        do_sample: bool = True,
        top_p: float = 0.9,
        repetition_penalty: float = 1.5,
    ) -> str:
        """
        Generate caption for an image

        Args:
            image: PIL Image to caption
            prompt: Optional prompt to guide caption generation; a falsy
                value falls back to DEFAULT_PROMPT
            max_length: Maximum length of generated caption
            num_beams: Number of beams for beam search
            min_length: Minimum generation length (clamped to max_length)
            do_sample: Sample tokens instead of pure beam search. With the
                default True, captions are stochastic and vary between calls.
            top_p: Nucleus-sampling threshold (used when do_sample is True)
            repetition_penalty: Penalty applied to repeated tokens

        Returns:
            Generated caption string
        """
        return self._caption_batch(
            [image],
            prompt,
            max_length=max_length,
            num_beams=num_beams,
            min_length=min_length,
            do_sample=do_sample,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )[0]

    def batch_generate_captions(
        self,
        images: List[Image.Image],
        prompt: Optional[str] = None,
        batch_size: int = 4,
        *,
        max_length: int = 50,
        num_beams: int = 5,
        min_length: int = 10,
        do_sample: bool = True,
        top_p: float = 0.9,
        repetition_penalty: float = 1.5,
    ) -> List[str]:
        """
        Generate captions for multiple images

        Images are sent to the model in real batches of up to ``batch_size``
        images per processor/generate call.

        Args:
            images: List of PIL Images
            prompt: Optional prompt for all images
            batch_size: Number of images to process at once (must be >= 1)
            max_length, num_beams, min_length, do_sample, top_p,
            repetition_penalty: Generation options, same meaning and
                defaults as in generate_caption

        Returns:
            List of generated captions, in the same order as ``images``

        Raises:
            ValueError: If batch_size is less than 1
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        captions: List[str] = []
        for start in range(0, len(images), batch_size):
            captions.extend(
                self._caption_batch(
                    images[start:start + batch_size],
                    prompt,
                    max_length=max_length,
                    num_beams=num_beams,
                    min_length=min_length,
                    do_sample=do_sample,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )
            )
        return captions

    def _caption_batch(
        self,
        images: List[Image.Image],
        prompt: Optional[str],
        *,
        max_length: int,
        num_beams: int,
        min_length: int,
        do_sample: bool,
        top_p: float,
        repetition_penalty: float,
    ) -> List[str]:
        """Run one processor + generate pass over ``images``; one caption per image."""
        text = prompt if prompt else self.DEFAULT_PROMPT

        # Every image gets the same prompt. Identical strings tokenize to the
        # same length, so the batch needs no padding. ``dtype`` casts only the
        # floating-point tensors (pixel values) to the model's precision;
        # integer token ids are just moved to the device.
        inputs = self.processor(
            images=images,
            text=[text] * len(images),
            return_tensors="pt",
        ).to(device=self.device, dtype=self.dtype)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                # Never ask for a minimum longer than the allowed maximum.
                min_length=min(min_length, max_length),
                do_sample=do_sample,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

        decoded = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return [caption.strip() for caption in decoded]