deepvision-prompt-builder / plugins /caption_generator.py
Salman Abjam
Initial deployment: DeepVision Prompt Builder v0.1.0
eb5a9e1
"""
Caption Generator Plugin
Generates descriptive captions for images using BLIP-2.
"""
from typing import Dict, Any
from pathlib import Path
import numpy as np
from PIL import Image
from loguru import logger
from plugins.base import BasePlugin, PluginMetadata
class CaptionGeneratorPlugin(BasePlugin):
    """
    Generate captions for images using BLIP-2.

    Creates natural language descriptions of image content. The model is
    loaded lazily on the first ``analyze()`` call (or an explicit
    ``initialize()``); if BLIP-2 cannot be loaded, the smaller BLIP base
    captioning model is used instead.
    """

    def __init__(self):
        """Initialize CaptionGeneratorPlugin without loading any model."""
        super().__init__()
        self.model = None
        self.processor = None
        self.max_length = 50
        # Device the model and its inputs are placed on. Only "cpu" is used
        # by this plugin today.
        self.device = "cpu"

    @property
    def metadata(self) -> PluginMetadata:
        """Return plugin metadata."""
        return PluginMetadata(
            name="caption_generator",
            version="0.1.0",
            description="Generates image captions using BLIP-2",
            author="AI Dev Collective",
            requires=["transformers", "torch"],
            category="captioning",
            priority=20,
        )

    def _load_model(
        self,
        processor_cls: Any,
        model_cls: Any,
        model_name: str
    ) -> None:
        """
        Load a processor/model pair and mark the plugin initialized.

        The attributes are only assigned after both objects loaded, so a
        failure part-way through leaves the plugin unchanged.

        Args:
            processor_cls: transformers processor class exposing ``from_pretrained``
            model_cls: transformers model class exposing ``from_pretrained``
            model_name: Hugging Face model identifier
        """
        processor = processor_cls.from_pretrained(model_name)
        model = model_cls.from_pretrained(model_name)
        # Inference only: disable dropout etc.
        model.eval()
        model.to(self.device)
        self.processor = processor
        self.model = model
        self._initialized = True

    def initialize(self) -> None:
        """
        Initialize the plugin and load the captioning model.

        Tries BLIP-2 (``Salesforce/blip2-opt-2.7b``) first and falls back to
        ``Salesforce/blip-image-captioning-base`` if anything in that attempt
        fails.

        Raises:
            Exception: The error from the fallback load, if both loads fail.
        """
        try:
            # Import here to avoid loading transformers if the plugin is unused
            from transformers import (
                Blip2Processor,
                Blip2ForConditionalGeneration,
            )
            logger.info("Loading BLIP-2 model...")
            # Use a smaller BLIP-2 checkpoint for faster inference
            self._load_model(
                Blip2Processor,
                Blip2ForConditionalGeneration,
                "Salesforce/blip2-opt-2.7b",
            )
            logger.info(
                f"BLIP-2 model loaded successfully on {self.device}"
            )
        except Exception as e:
            # Not fatal yet: a fallback model is attempted below, so this is
            # a warning rather than an error.
            logger.warning(f"Failed to initialize CaptionGeneratorPlugin: {e}")
            try:
                logger.info("Trying smaller BLIP model...")
                from transformers import (
                    BlipProcessor,
                    BlipForConditionalGeneration,
                )
                self._load_model(
                    BlipProcessor,
                    BlipForConditionalGeneration,
                    "Salesforce/blip-image-captioning-base",
                )
                logger.info("BLIP base model loaded successfully")
            except Exception as fallback_error:
                logger.error(f"Fallback also failed: {fallback_error}")
                raise

    def _generate_caption(
        self,
        image: Image.Image,
        max_length: int = 50
    ) -> str:
        """
        Generate caption for image.

        Args:
            image: PIL Image
            max_length: Maximum generation length passed to ``generate``

        Returns:
            Generated caption string with surrounding whitespace removed
        """
        import torch

        # Prepare inputs on the same device as the model
        inputs = self.processor(
            images=image,
            return_tensors="pt"
        ).to(self.device)

        # Beam-search decoding; no gradients are needed for inference
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=5,
                early_stopping=True,
            )

        caption = self.processor.decode(
            generated_ids[0],
            skip_special_tokens=True
        )
        return caption.strip()

    @staticmethod
    def _to_pil_image(media: np.ndarray) -> Image.Image:
        """
        Convert a numpy image array to an 8-bit PIL Image.

        Floating-point (and boolean) arrays whose maximum is <= 1 are treated
        as normalized to [0, 1] and scaled to [0, 255]. Integer arrays are
        taken as already in pixel units. Values are clipped to [0, 255] before
        the uint8 cast so they saturate instead of wrapping around. A trailing
        singleton channel axis (H, W, 1) is dropped, because PIL cannot build
        an image from that shape.

        Raises:
            ValueError: If the array is empty.
        """
        if media.size == 0:
            raise ValueError("Cannot caption an empty image array")

        array = media
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]

        if array.dtype == np.bool_:
            array = array.astype(np.uint8) * 255
        elif np.issubdtype(array.dtype, np.floating) and array.max() <= 1:
            array = array * 255.0

        return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))

    def analyze(
        self,
        media: Any,
        media_path: Path
    ) -> Dict[str, Any]:
        """
        Generate caption for the image.

        Args:
            media: PIL Image or numpy array
            media_path: Path to image file (not used by this plugin)

        Returns:
            On success, a dictionary with ``caption``, ``word_count``,
            ``character_count``, ``max_length`` and ``status == "success"``.
            On failure, a dictionary with ``error`` and ``status == "failed"``.
        """
        try:
            # Lazily load the model on first use (or after cleanup())
            if not self._initialized:
                self.initialize()

            if not self.validate_input(media):
                return {"error": "Invalid input type", "status": "failed"}

            if isinstance(media, np.ndarray):
                image = self._to_pil_image(media)
            else:
                image = media

            caption = self._generate_caption(image, self.max_length)

            result = {
                "caption": caption,
                "word_count": len(caption.split()),
                "character_count": len(caption),
                "max_length": self.max_length,
                "status": "success",
            }
            logger.debug(f"Caption generated: '{caption[:50]}...'")
            return result

        except Exception as e:
            # Plugin boundary: report failures as data instead of raising
            logger.error(f"Caption generation failed: {e}")
            return {
                "error": str(e),
                "status": "failed"
            }

    def cleanup(self) -> None:
        """
        Release model resources.

        The plugin is marked uninitialized, so a later ``analyze()`` reloads
        the model instead of calling a released (None) processor.
        """
        self.model = None
        self.processor = None
        self._initialized = False
        logger.info("CaptionGeneratorPlugin cleanup complete")