Spaces:
Sleeping
Sleeping
| from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
| import torch | |
| from PIL import Image | |
| from typing import Optional, List, Tuple | |
class ImageCaptioner:
    """Generate natural-language captions for images with a BLIP-2 model."""

    # Default text used to condition generation when the caller gives no prompt.
    DEFAULT_PROMPT = "a photo of"

    def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b"):
        """
        Initialize BLIP-2 model for image captioning.

        Args:
            model_name: HuggingFace model identifier for BLIP-2
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 halves GPU memory use; CPU kernels generally require fp32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = Blip2Processor.from_pretrained(model_name)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=self.dtype,
        ).to(self.device)
        self.model.eval()  # inference only: disables dropout etc.

    def _caption_batch(
        self,
        images: List[Image.Image],
        prompt: Optional[str],
        max_length: int,
        num_beams: int,
    ) -> List[str]:
        """Run one generation pass over a non-empty list of images.

        Shared by generate_caption and batch_generate_captions so both
        paths use identical preprocessing and decoding.
        """
        text = prompt if prompt else self.DEFAULT_PROMPT
        # BUGFIX: cast floating-point inputs (pixel_values) to the model's
        # dtype — with a plain .to(device) they stay fp32 and crash against
        # an fp16 model on CUDA. BatchFeature.to only casts float tensors,
        # so input_ids correctly remain integer.
        inputs = self.processor(
            images=images,
            text=[text] * len(images),
            return_tensors="pt",
        ).to(self.device, self.dtype)

        with torch.no_grad():
            # BUGFIX: the original passed do_sample=True/top_p together with
            # num_beams=5, which triggers nondeterministic beam-multinomial
            # sampling. Plain beam search gives stable, repeatable captions.
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                min_length=10,
                repetition_penalty=1.5,
            )

        decoded = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )
        return [caption.strip() for caption in decoded]

    def generate_caption(
        self,
        image: Image.Image,
        prompt: Optional[str] = None,
        max_length: int = 50,
        num_beams: int = 5,
    ) -> str:
        """
        Generate a caption for a single image.

        Args:
            image: PIL Image to caption
            prompt: Optional prompt to guide caption generation
            max_length: Maximum length of the generated caption (tokens)
            num_beams: Number of beams for beam search

        Returns:
            Generated caption string
        """
        return self._caption_batch([image], prompt, max_length, num_beams)[0]

    def batch_generate_captions(
        self,
        images: List[Image.Image],
        prompt: Optional[str] = None,
        batch_size: int = 4,
    ) -> List[str]:
        """
        Generate captions for multiple images.

        Args:
            images: List of PIL Images (may be empty; returns [] then)
            prompt: Optional prompt applied to every image
            batch_size: Number of images processed in one forward pass

        Returns:
            List of generated captions, in input order
        """
        captions: List[str] = []
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            # BUGFIX: run the whole chunk through the model in one forward
            # pass. The original called generate_caption per image, so
            # batch_size had no effect on throughput.
            captions.extend(
                self._caption_batch(chunk, prompt, max_length=50, num_beams=5)
            )
        return captions