"""
Caption Generator Plugin
Generates descriptive captions for images using BLIP-2.
"""
from typing import Dict, Any
from pathlib import Path

import numpy as np
from PIL import Image
from loguru import logger

from plugins.base import BasePlugin, PluginMetadata
class CaptionGeneratorPlugin(BasePlugin):
    """
    Generate captions for images using BLIP-2.

    Creates natural language descriptions of image content. The model is
    loaded lazily in :meth:`initialize`; if BLIP-2 cannot be loaded, the
    plugin falls back to the much smaller BLIP base captioning model.
    """

    def __init__(self):
        """Initialize CaptionGeneratorPlugin with no model loaded yet."""
        super().__init__()
        # Model/processor are loaded lazily in initialize() so importing the
        # plugin does not pull in the (large) transformers weights.
        self.model = None
        self.processor = None
        # Maximum generation length passed to model.generate().
        self.max_length = 50

    def metadata(self) -> PluginMetadata:
        """Return plugin metadata."""
        return PluginMetadata(
            name="caption_generator",
            version="0.1.0",
            description="Generates image captions using BLIP-2",
            author="AI Dev Collective",
            requires=["transformers", "torch"],
            category="captioning",
            priority=20,
        )

    def initialize(self) -> None:
        """Load the captioning model.

        Tries BLIP-2 (Salesforce/blip2-opt-2.7b) first; on any failure it
        falls back to BLIP base (Salesforce/blip-image-captioning-base).

        Raises:
            Exception: if the fallback model also fails to load (the
                original BLIP-2 error is logged beforehand).
        """
        try:
            # Import here to avoid loading transformers if the plugin is
            # never used.
            from transformers import (
                Blip2Processor,
                Blip2ForConditionalGeneration,
            )

            logger.info("Loading BLIP-2 model...")
            # Use smaller BLIP-2 model for faster inference.
            model_name = "Salesforce/blip2-opt-2.7b"
            self.processor = Blip2Processor.from_pretrained(model_name)
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                model_name
            )
            # Inference only: disable dropout etc.
            self.model.eval()
            # CPU only for now (GPU support can be added later).
            device = "cpu"
            self.model.to(device)
            self._initialized = True
            logger.info(
                f"BLIP-2 model loaded successfully on {device}"
            )
        except Exception as e:
            logger.error(f"Failed to initialize CaptionGeneratorPlugin: {e}")
            # Fallback: try the smaller BLIP base model.
            try:
                logger.info("Trying smaller BLIP model...")
                from transformers import BlipProcessor, BlipForConditionalGeneration

                model_name = "Salesforce/blip-image-captioning-base"
                self.processor = BlipProcessor.from_pretrained(model_name)
                self.model = BlipForConditionalGeneration.from_pretrained(
                    model_name
                )
                self.model.eval()
                self.model.to("cpu")
                self._initialized = True
                logger.info("BLIP base model loaded successfully")
            except Exception as fallback_error:
                logger.error(f"Fallback also failed: {fallback_error}")
                raise

    @staticmethod
    def _to_pil_image(media: Any) -> Image.Image:
        """Coerce ``media`` (PIL Image or numpy array) to an RGB PIL Image.

        Args:
            media: PIL Image or numpy array (HxW or HxWxC).

        Returns:
            RGB ``PIL.Image.Image`` suitable for the BLIP processor.
        """
        if isinstance(media, np.ndarray):
            # Rescale normalized arrays (values in [0, 1]) to 0-255.
            # BUGFIX: exclude uint8 input from rescaling — an all-dark uint8
            # image whose max pixel happens to be <= 1 must not be multiplied
            # by 255.
            if (
                media.dtype != np.uint8
                and media.size > 0
                and media.max() <= 1
            ):
                media = (media * 255).astype(np.uint8)
            else:
                media = media.astype(np.uint8)
            image = Image.fromarray(media)
        else:
            image = media
        # BLIP processors expect 3-channel RGB input; convert grayscale /
        # RGBA / palette images.
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def _generate_caption(
        self,
        image: Image.Image,
        max_length: int = 50
    ) -> str:
        """
        Generate a caption for the given image.

        Args:
            image: PIL Image (RGB).
            max_length: Maximum caption length (generation tokens).

        Returns:
            Generated caption string, stripped of surrounding whitespace.
        """
        import torch

        # Prepare pixel inputs for the model.
        inputs = self.processor(
            images=image,
            return_tensors="pt"
        )
        # Beam search; no gradients needed at inference time.
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=5,
                early_stopping=True
            )
        # Decode token ids back to text.
        caption = self.processor.decode(
            generated_ids[0],
            skip_special_tokens=True
        )
        return caption.strip()

    def analyze(
        self,
        media: Any,
        media_path: Path
    ) -> Dict[str, Any]:
        """
        Generate a caption for the image.

        Args:
            media: PIL Image or numpy array.
            media_path: Path to the image file (unused; kept for the
                plugin interface).

        Returns:
            On success: dict with ``caption``, ``word_count``,
            ``character_count``, ``max_length`` and ``status="success"``.
            On failure: dict with ``error`` and ``status="failed"``.
        """
        try:
            # Lazily load the model on first use (or after cleanup()).
            if not self._initialized:
                self.initialize()
            # Validate input type via the base-class hook.
            if not self.validate_input(media):
                return {"error": "Invalid input type"}
            # Normalize input to an RGB PIL image.
            image = self._to_pil_image(media)
            # Generate the caption.
            caption = self._generate_caption(image, self.max_length)
            # Simple caption statistics.
            word_count = len(caption.split())
            result = {
                "caption": caption,
                "word_count": word_count,
                "character_count": len(caption),
                "max_length": self.max_length,
                "status": "success",
            }
            logger.debug(f"Caption generated: '{caption[:50]}...'")
            return result
        except Exception as e:
            logger.error(f"Caption generation failed: {e}")
            return {
                "error": str(e),
                "status": "failed"
            }

    def cleanup(self) -> None:
        """Release model resources and allow re-initialization later."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        # BUGFIX: reset the flag so a later analyze() reloads the model
        # instead of calling generate() on a None model.
        self._initialized = False
        logger.info("CaptionGeneratorPlugin cleanup complete")