"""Web-app inference module.

Loads the image-captioning model, vocabulary, and preprocessing transform
once at import time and exposes caption generation for the Flask app.
"""
import os
import torch
from PIL import Image

# Import core model components and utilities using ABSOLUTE IMPORTS.
# Since 'src' is added to sys.path, we refer to modules directly under 'src'.
from src.model import ImageCaptioningModel
from src.data_preprocessing import COCOVocabulary  # NOTE(review): not referenced here directly — presumably needed so the pickled vocabulary in the checkpoint can be deserialized; confirm before removing
from src.utils import get_logger, get_eval_transform
from src.config import INFERENCE_CONFIG, update_config_with_latest_model  # Import global config

logger = get_logger(__name__)

# --- Module-level singletons holding the loaded inference state ---
# Populated exactly once by _load_model_and_vocabulary() when this module
# is first imported; read by generate_caption_for_image().
_model = None       # trained ImageCaptioningModel, set to eval mode
_vocabulary = None  # vocabulary object unpickled from the checkpoint
_device = None      # torch.device: 'cuda' if available, else 'cpu'
_transform = None   # eval-time image preprocessing pipeline
def _load_model_and_vocabulary():
    """
    Load the captioning model, vocabulary, device, and transform exactly once.

    Populates the module-level singletons (_model, _vocabulary, _device,
    _transform). Safe to call repeatedly: if the model is already loaded,
    the call is a no-op.

    Raises:
        FileNotFoundError: If the checkpoint path from INFERENCE_CONFIG
            does not exist.
        Exception: Re-raised on any failure while loading or initializing,
            so the Flask app refuses to start without a working model.
    """
    global _model, _vocabulary, _device, _transform

    if _model is not None:
        logger.info("Model and vocabulary already loaded.")
        return

    logger.info("Initializing model and vocabulary for web inference...")

    # Point INFERENCE_CONFIG['model_path'] at the latest best checkpoint so
    # the web app always serves the most recently trained model.
    update_config_with_latest_model(INFERENCE_CONFIG)
    model_path = INFERENCE_CONFIG['model_path']

    if not os.path.exists(model_path):
        logger.error(f"Model checkpoint not found at {model_path}. "
                     "Please ensure the model is trained and saved.")
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    try:
        # Load the complete checkpoint (model state, vocabulary, config).
        # SECURITY NOTE: weights_only=False is required because the checkpoint
        # pickles the vocabulary object, but unpickling can execute arbitrary
        # code — only load checkpoints produced by this project's training run.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        # Rebuild the model with the hyperparameters saved in the checkpoint,
        # falling back to the training defaults when a key is absent.
        model_config_from_checkpoint = checkpoint.get('config', {})
        _vocabulary = checkpoint['vocabulary']
        _model = ImageCaptioningModel(
            vocab_size=_vocabulary.vocab_size,
            embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
            attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
            decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
            dropout=0.0,  # dropout disabled for inference
            fine_tune_encoder=False,  # encoder frozen during inference
            max_caption_length=INFERENCE_CONFIG.get('max_caption_length', 20)
        )

        # Load the trained weights and switch to evaluation mode
        # (important for batch norm and dropout layers).
        _model.load_state_dict(checkpoint['model_state_dict'])
        _model.eval()

        # Run on GPU when available, otherwise CPU.
        _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _model = _model.to(_device)
        logger.info(f"Model loaded successfully on device: {_device}")

        # Image preprocessing pipeline matching evaluation-time transforms.
        _transform = get_eval_transform()
        logger.info("Model and vocabulary are ready for inference.")
    except Exception as e:
        logger.critical(f"Failed to load model or vocabulary: {e}", exc_info=True)
        # Re-raise so the Flask app cannot start without a working model.
        raise
# Eagerly load the model the moment this module is imported, so the Flask
# app pays the startup cost exactly once and fails fast (import error) if
# the checkpoint is missing or corrupt.
_load_model_and_vocabulary()
def generate_caption_for_image(image_path: str) -> str:
    """
    Generates a caption for a given image path using the pre-loaded model.

    Args:
        image_path (str): The full path to the image file.

    Returns:
        str: The generated caption.

    Raises:
        RuntimeError: If the model was not initialized at import time.
        FileNotFoundError: If the image file does not exist.
        Exception: For errors during image loading or caption generation.
    """
    # Guard against the (should-not-happen) case where module import
    # succeeded but the singletons were never populated.
    if _model is None or _vocabulary is None or _transform is None or _device is None:
        logger.error("Model or vocabulary not loaded. Cannot generate caption.")
        raise RuntimeError("Image captioning model is not initialized.")

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}.")

    logger.info(f"Processing image: {image_path}")
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = _transform(image).to(_device)
    except Exception as e:
        # Chain the original error so the full cause survives in tracebacks.
        raise Exception(f"Error loading or transforming image {image_path}: {e}") from e

    # Inference only: disable autograd so no gradient graph is built
    # (model.eval() alone does not do this — this saves memory and time).
    with torch.no_grad():
        generated_caption = _model.generate_caption(
            image_tensor,
            _vocabulary,
            _device,
            beam_size=INFERENCE_CONFIG.get('beam_size', 5),
            max_length=INFERENCE_CONFIG.get('max_caption_length', 20)
        )

    logger.info(f"Generated caption: {generated_caption}")
    return generated_caption