# semantic_image_search/modules/image_captioner.py
# NOTE(review): the lines below were page residue from the HuggingFace Space
# file view (author: Chamin09, verified commit dad7590,
# message: "Update modules/image_captioner.py") — kept here as a comment so
# the module parses.
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image
from typing import Optional, List, Tuple
class ImageCaptioner:
    """Image captioning via a BLIP-2 vision-language model.

    Loads the model once at construction (float16 on GPU, float32 on CPU)
    and exposes single-image and batched caption generation.
    """

    def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b"):
        """
        Initialize BLIP-2 model for image captioning.

        Args:
            model_name: HuggingFace model identifier for BLIP-2.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 halves GPU memory; CPU kernels generally need fp32.
        # Stored so inputs can be cast to the same dtype as the weights —
        # a float32 pixel_values tensor fed to a float16 model raises a
        # dtype-mismatch error on CUDA.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = Blip2Processor.from_pretrained(model_name)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=self.dtype,
        ).to(self.device)
        # Inference only: disable dropout etc. for deterministic outputs.
        self.model.eval()

    def generate_caption(
        self,
        image: Image.Image,
        prompt: Optional[str] = None,
        max_length: int = 50,
        num_beams: int = 5
    ) -> str:
        """
        Generate a caption for a single image.

        Args:
            image: PIL Image to caption.
            prompt: Optional prompt to guide caption generation
                (defaults to "a photo of").
            max_length: Maximum length of the generated caption.
            num_beams: Number of beams for beam search.

        Returns:
            Generated caption string.
        """
        # BatchFeature.to(device, dtype) casts only floating-point tensors
        # (pixel_values) to the model dtype; input_ids stay integral.
        inputs = self.processor(
            images=image,
            text=prompt if prompt else "a photo of",
            return_tensors="pt"
        ).to(self.device, self.dtype)
        return self._generate_and_decode(inputs, max_length, num_beams)[0]

    def batch_generate_captions(
        self,
        images: List[Image.Image],
        prompt: Optional[str] = None,
        batch_size: int = 4
    ) -> List[str]:
        """
        Generate captions for multiple images using true batched inference.

        Args:
            images: List of PIL Images.
            prompt: Optional prompt applied to all images.
            batch_size: Number of images to process per forward pass.

        Returns:
            List of generated captions, in the same order as `images`.
        """
        text = prompt if prompt else "a photo of"
        captions: List[str] = []
        # Run the model on real batches (one generate() call per chunk)
        # instead of looping single images — this is where the batching
        # speedup actually comes from.
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            inputs = self.processor(
                images=batch,
                text=[text] * len(batch),
                return_tensors="pt"
            ).to(self.device, self.dtype)
            captions.extend(self._generate_and_decode(inputs, 50, 5))
        return captions

    def _generate_and_decode(self, inputs, max_length: int, num_beams: int) -> List[str]:
        """Run deterministic beam-search generation and decode all outputs."""
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                min_length=10,
                # Deterministic beam search: the original combined
                # num_beams with do_sample=True/top_p, which makes
                # captions non-reproducible across runs — undesirable
                # for a search index. top_p is ignored without sampling.
                do_sample=False,
                repetition_penalty=1.5
            )
        return [
            text.strip()
            for text in self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )
        ]