|
|
""" |
|
|
Caption Generator Plugin |
|
|
|
|
|
Generates descriptive captions for images using BLIP-2. |
|
|
""" |
|
|
|
|
|
from typing import Dict, Any |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from loguru import logger |
|
|
|
|
|
from plugins.base import BasePlugin, PluginMetadata |
|
|
|
|
|
|
|
|
class CaptionGeneratorPlugin(BasePlugin):
    """
    Generate captions for images using BLIP-2.

    Creates natural language descriptions of image content. If the large
    BLIP-2 checkpoint cannot be loaded, falls back to the smaller BLIP
    base captioning model.
    """

    def __init__(self):
        """Initialize CaptionGeneratorPlugin."""
        super().__init__()
        self.model = None        # loaded lazily in initialize()
        self.processor = None    # HF processor paired with the loaded model
        self.max_length = 50     # max caption length (tokens) passed to generate()

    @property
    def metadata(self) -> PluginMetadata:
        """Return plugin metadata."""
        return PluginMetadata(
            name="caption_generator",
            version="0.1.0",
            description="Generates image captions using BLIP-2",
            author="AI Dev Collective",
            requires=["transformers", "torch"],
            category="captioning",
            priority=20,
        )

    def _load_model(self, processor_cls, model_cls, model_name: str, device: str = "cpu") -> None:
        """
        Load a processor/model pair and prepare it for inference.

        Shared by the BLIP-2 and BLIP-base code paths (previously duplicated).

        Args:
            processor_cls: Hugging Face processor class with from_pretrained.
            model_cls: Hugging Face model class with from_pretrained.
            model_name: Checkpoint identifier on the HF hub.
            device: Target device for the model (default: CPU).
        """
        self.processor = processor_cls.from_pretrained(model_name)
        self.model = model_cls.from_pretrained(model_name)
        self.model.eval()  # inference only; disable dropout etc.
        self.model.to(device)
        self._initialized = True

    def initialize(self) -> None:
        """
        Initialize the plugin and load the BLIP-2 model.

        Raises:
            Exception: Propagated only if both BLIP-2 and the BLIP base
                fallback fail to load.
        """
        try:
            from transformers import (
                Blip2Processor,
                Blip2ForConditionalGeneration,
            )

            logger.info("Loading BLIP-2 model...")
            device = "cpu"
            self._load_model(
                Blip2Processor,
                Blip2ForConditionalGeneration,
                "Salesforce/blip2-opt-2.7b",
                device,
            )
            logger.info(
                f"BLIP-2 model loaded successfully on {device}"
            )

        except Exception as e:
            logger.error(f"Failed to initialize CaptionGeneratorPlugin: {e}")

            # Best-effort fallback: the base BLIP captioner is far smaller
            # and more likely to fit in memory.
            try:
                logger.info("Trying smaller BLIP model...")
                from transformers import BlipProcessor, BlipForConditionalGeneration

                self._load_model(
                    BlipProcessor,
                    BlipForConditionalGeneration,
                    "Salesforce/blip-image-captioning-base",
                )
                logger.info("BLIP base model loaded successfully")

            except Exception as fallback_error:
                logger.error(f"Fallback also failed: {fallback_error}")
                raise

    def _generate_caption(
        self,
        image: Image.Image,
        max_length: int = 50
    ) -> str:
        """
        Generate caption for image.

        Args:
            image: PIL Image
            max_length: Maximum caption length

        Returns:
            Generated caption string
        """
        import torch

        inputs = self.processor(
            images=image,
            return_tensors="pt"
        )

        # Beam search gives more fluent captions than greedy decoding.
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=5,
                early_stopping=True
            )

        caption = self.processor.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        return caption.strip()

    def analyze(
        self,
        media: Any,
        media_path: Path
    ) -> Dict[str, Any]:
        """
        Generate caption for the image.

        Args:
            media: PIL Image or numpy array
            media_path: Path to image file

        Returns:
            Dictionary with caption, word/character counts and status,
            or an error dictionary on failure.
        """
        try:
            # Lazy initialization on first use.
            if not self._initialized:
                self.initialize()

            if not self.validate_input(media):
                return {"error": "Invalid input type"}

            if isinstance(media, np.ndarray):
                # Fix: media.max() raises ValueError on an empty array;
                # fail early with a clear message instead.
                if media.size == 0:
                    return {"error": "Empty image array", "status": "failed"}
                # Float arrays in [0, 1] are rescaled to uint8 range.
                arr = (
                    (media * 255).astype(np.uint8) if media.max() <= 1
                    else media.astype(np.uint8)
                )
                image = Image.fromarray(arr)
            else:
                image = media

            # Fix: BLIP processors expect 3-channel RGB input; grayscale
            # ("L") or "RGBA" images would otherwise fail downstream.
            if image.mode != "RGB":
                image = image.convert("RGB")

            caption = self._generate_caption(image, self.max_length)

            word_count = len(caption.split())

            result = {
                "caption": caption,
                "word_count": word_count,
                "character_count": len(caption),
                "max_length": self.max_length,
                "status": "success",
            }

            logger.debug(f"Caption generated: '{caption[:50]}...'")

            return result

        except Exception as e:
            logger.error(f"Caption generation failed: {e}")
            return {
                "error": str(e),
                "status": "failed"
            }

    def cleanup(self) -> None:
        """Clean up model resources."""
        # Dropping the references is sufficient; `del` before reassignment
        # was redundant.
        self.model = None
        self.processor = None

        logger.info("CaptionGeneratorPlugin cleanup complete")
|
|
|