"""
Caption Model Module
Manages BLIP and GIT models for image caption generation.
Handles model loading, inference, and memory management.
"""

import gc
from typing import Dict, Optional, Tuple

import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
)

from config import model_config


class CaptionModelError(Exception):
    """Custom exception for caption model errors"""
    pass


class CaptionModel:
    """
    Base class for caption generation models.

    Provides the common interface (load / generate_caption / unload) plus the
    shared loading and preprocessing helpers used by BLIPModel and GITModel.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """
        Initialize caption model (weights are not loaded until load()).

        Args:
            model_name: HuggingFace model identifier
            device: Requested device ("cuda" or "cpu"); falls back to "cpu"
                when CUDA is unavailable or anything else is requested.
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.processor = None
        self.model = None
        self._is_loaded = False

    def _get_device(self, requested_device: str) -> str:
        """
        Determine available device.

        Args:
            requested_device: Requested device (cuda/cpu)

        Returns:
            str: "cuda" only if it was requested and is available, else "cpu"
        """
        if requested_device == "cuda" and torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _torch_dtype(self) -> torch.dtype:
        """Weight dtype used when loading: float16 on CUDA, float32 on CPU."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _load_components(self, processor_cls, model_cls, label: str) -> bool:
        """
        Shared loader for processor + model.

        Args:
            processor_cls: Class exposing ``from_pretrained`` for the processor
            model_cls: Class exposing ``from_pretrained`` for the model
            label: Human-readable model name used in progress messages

        Returns:
            bool: True if the model is loaded (or already was), False on error
        """
        if self._is_loaded:
            # Avoid re-downloading / re-allocating weights that are already resident.
            return True
        try:
            print(f"Loading {label} model on {self.device}...")

            # Load processor
            self.processor = processor_cls.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR,
            )

            # Load model
            self.model = model_cls.from_pretrained(
                self.model_name,
                cache_dir=model_config.MODEL_CACHE_DIR,
                torch_dtype=self._torch_dtype(),
            ).to(self.device)

            # Set to evaluation mode
            self.model.eval()

            self._is_loaded = True
            print(f"✓ {label} model loaded successfully on {self.device}")
            return True
        except Exception as e:
            print(f"Error loading {label} model: {e}")
            # Release whatever was loaded before the failure (e.g. the processor,
            # or weights if .to()/eval() failed) so a failed load doesn't pin memory.
            self.unload()
            return False

    def _prepare_pixel_values(self, image: Image.Image) -> torch.Tensor:
        """
        Preprocess an image into ``pixel_values`` ready for ``generate``.

        The processor's float output is cast to the model's dtype as well as
        moved to its device: on CUDA the weights are float16, and feeding the
        processor's (typically float32) tensor would cause a dtype mismatch.

        Args:
            image: PIL Image

        Returns:
            torch.Tensor: pixel values on ``self.device`` in ``self.model.dtype``
        """
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs["pixel_values"].to(self.device, dtype=self.model.dtype)

    def load(self) -> bool:
        """
        Load model into memory.

        Returns:
            bool: True if successful
        """
        raise NotImplementedError("Subclass must implement load()")

    def generate_caption(
        self,
        image: Image.Image,
        max_length: int = 50,
        num_beams: int = 3
    ) -> str:
        """
        Generate caption for image.

        Args:
            image: PIL Image
            max_length: Maximum caption length
            num_beams: Number of beams for beam search

        Returns:
            str: Generated caption
        """
        raise NotImplementedError("Subclass must implement generate_caption()")

    def unload(self) -> None:
        """Drop model/processor references and free cached GPU memory."""
        self.model = None
        self.processor = None
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()
        self._is_loaded = False

    def is_loaded(self) -> bool:
        """Check if model is loaded"""
        return self._is_loaded

    def get_info(self) -> dict:
        """Get model information"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "is_loaded": self._is_loaded
        }


class BLIPModel(CaptionModel):
    """
    BLIP (Bootstrapping Language-Image Pre-training) model.
    Fast and efficient model for image captioning.
    """

    def __init__(self, device: str = "cuda"):
        """Initialize BLIP model with defaults from model_config."""
        super().__init__(model_config.BLIP_MODEL_NAME, device)
        self.max_length = model_config.BLIP_MAX_LENGTH
        self.num_beams = model_config.BLIP_NUM_BEAMS

    def load(self) -> bool:
        """
        Load BLIP model and processor.

        Returns:
            bool: True if successful
        """
        return self._load_components(
            BlipProcessor, BlipForConditionalGeneration, "BLIP"
        )

    def generate_caption(
        self,
        image: Image.Image,
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None
    ) -> str:
        """
        Generate caption using BLIP.

        Args:
            image: PIL Image
            max_length: Maximum caption length (None -> config default)
            num_beams: Number of beams for beam search (None -> config default)

        Returns:
            str: Generated caption

        Raises:
            CaptionModelError: If the model is not loaded or generation fails
        """
        if not self._is_loaded:
            raise CaptionModelError("BLIP model not loaded")

        if max_length is None:
            max_length = self.max_length
        if num_beams is None:
            num_beams = self.num_beams

        try:
            pixel_values = self._prepare_pixel_values(image)

            with torch.no_grad():
                output_ids = self.model.generate(
                    pixel_values=pixel_values,
                    max_length=max_length,
                    num_beams=num_beams,
                    early_stopping=True
                )

            caption = self.processor.decode(
                output_ids[0],
                skip_special_tokens=True
            )
            return caption.strip()
        except Exception as e:
            raise CaptionModelError(f"BLIP caption generation failed: {e}") from e


class GITModel(CaptionModel):
    """
    GIT (Generative Image-to-text Transformer) model.
    More detailed and accurate captions compared to BLIP.
    """

    def __init__(self, device: str = "cuda"):
        """Initialize GIT model with defaults from model_config."""
        super().__init__(model_config.GIT_MODEL_NAME, device)
        self.max_length = model_config.GIT_MAX_LENGTH
        self.num_beams = model_config.GIT_NUM_BEAMS

    def load(self) -> bool:
        """
        Load GIT model and processor.

        Returns:
            bool: True if successful
        """
        return self._load_components(AutoProcessor, AutoModelForCausalLM, "GIT")

    def generate_caption(
        self,
        image: Image.Image,
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None
    ) -> str:
        """
        Generate caption using GIT.

        Args:
            image: PIL Image
            max_length: Maximum caption length (None -> config default)
            num_beams: Number of beams for beam search (None -> config default)

        Returns:
            str: Generated caption

        Raises:
            CaptionModelError: If the model is not loaded or generation fails
        """
        if not self._is_loaded:
            raise CaptionModelError("GIT model not loaded")

        if max_length is None:
            max_length = self.max_length
        if num_beams is None:
            num_beams = self.num_beams

        try:
            pixel_values = self._prepare_pixel_values(image)

            with torch.no_grad():
                output_ids = self.model.generate(
                    pixel_values=pixel_values,
                    max_length=max_length,
                    num_beams=num_beams,
                    early_stopping=True
                )

            caption = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True
            )[0]
            return caption.strip()
        except Exception as e:
            raise CaptionModelError(f"GIT caption generation failed: {e}") from e


class CaptionModelManager:
    """
    Manager for both BLIP and GIT models.
    Provides a unified interface and handles model lifecycle.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize model manager (models are created but not loaded).

        Args:
            device: Device to use (cuda/cpu); model_config.DEVICE if None
        """
        self.device = device or model_config.DEVICE

        # Initialize models
        self.blip_model = BLIPModel(self.device)
        self.git_model = GITModel(self.device)

        # Names of models successfully loaded through this manager
        self._loaded_models = set()

    def _models(self) -> Dict[str, CaptionModel]:
        """Name -> model mapping; insertion order fixes caption/output order."""
        return {"blip": self.blip_model, "git": self.git_model}

    def load_all_models(self) -> Tuple[bool, bool]:
        """
        Load both models.

        Returns:
            Tuple[bool, bool]: (blip_success, git_success)
        """
        blip_success = self.load_model("blip")
        git_success = self.load_model("git")
        return blip_success, git_success

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific model.

        Args:
            model_name: Model to load ("blip" or "git", case-insensitive)

        Returns:
            bool: True if successful

        Raises:
            ValueError: If model_name is not a known model
        """
        key = model_name.lower()
        model = self._models().get(key)
        if model is None:
            raise ValueError(f"Unknown model: {model_name}")

        success = model.load()
        if success:
            self._loaded_models.add(key)
        return success

    def generate_captions(
        self,
        image: Image.Image
    ) -> Dict[str, str]:
        """
        Generate captions from all loaded models.

        A failure in one model does not stop the others; its entry holds an
        "Error: ..." string instead of a caption.

        Args:
            image: PIL Image

        Returns:
            Dict[str, str]: Captions keyed by model name ("blip", "git")
        """
        captions = {}
        for key, model in self._models().items():
            if key not in self._loaded_models:
                continue
            try:
                captions[key] = model.generate_caption(image)
            except Exception as e:
                captions[key] = f"Error: {str(e)}"
        return captions

    def unload_all_models(self) -> None:
        """Unload all models from memory"""
        self.blip_model.unload()
        self.git_model.unload()
        self._loaded_models.clear()

    def get_status(self) -> dict:
        """Get status of all models"""
        return {
            "device": self.device,
            "blip": {
                "loaded": self.blip_model.is_loaded(),
                "info": self.blip_model.get_info()
            },
            "git": {
                "loaded": self.git_model.is_loaded(),
                "info": self.git_model.get_info()
            },
            "loaded_models": list(self._loaded_models)
        }


# Singleton instance
_model_manager = None


def get_model_manager() -> CaptionModelManager:
    """Get singleton CaptionModelManager instance (created on first call)."""
    global _model_manager
    if _model_manager is None:
        _model_manager = CaptionModelManager()
    return _model_manager


if __name__ == "__main__":
    # Test the caption models
    print("=" * 60)
    print("CAPTION MODELS - TEST MODE")
    print("=" * 60)

    # Initialize manager
    manager = CaptionModelManager()
    print("\n✓ Model manager initialized")
    print(f"  Device: {manager.device}")

    print("\n" + "=" * 60)
    print("Loading models (this may take a few minutes)...")
    print("=" * 60)

    # Load models
    blip_success, git_success = manager.load_all_models()

    print(f"\nBLIP: {'✓ Loaded' if blip_success else '✗ Failed'}")
    print(f"GIT: {'✓ Loaded' if git_success else '✗ Failed'}")

    print("\n" + "=" * 60)
    print("Model Status:")
    print("=" * 60)
    status = manager.get_status()
    for key, value in status.items():
        if isinstance(value, dict):
            print(f"{key}:")
            for k, v in value.items():
                print(f"  {k}: {v}")
        else:
            print(f"{key}: {value}")

    print("\n" + "=" * 60)
    print("✓ Caption models test complete")
    print("=" * 60)

    print("\nTo test caption generation, provide a test image:")
    print("  from PIL import Image")
    print("  img = Image.open('your_image.jpg')")
    print("  captions = manager.generate_captions(img)")
    print("  print(captions)")