Spaces:
Build error
Build error
| import os | |
| import torch | |
| from PIL import Image | |
| import sys # Import sys for flushing stdout | |
| # Import modules from your project structure | |
| # Ensure utils is imported early to set up logging | |
| from .utils import get_logger, get_eval_transform, visualize_attention | |
| from .model import ImageCaptioningModel | |
| from .data_preprocessing import COCOVocabulary | |
| from .config import INFERENCE_CONFIG, update_config_with_latest_model # config is imported here | |
| from .evaluation import calculate_bleu_scores_detailed # evaluation is imported here | |
| # Get the module-specific logger. This logger will inherit from the root logger | |
| # which is configured when `utils` is imported. | |
| logger = get_logger(__name__) | |
def run_inference_example(model_path, image_path, config=None):
    """
    Run inference on a single image and generate a caption.

    Args:
        model_path (str): Path to the saved model checkpoint (.pth file).
        image_path (str): Path to the image file for captioning.
        config (dict, optional): Inference parameters (e.g. ``beam_size``,
            ``max_caption_length``). Defaults are applied for missing keys.

    Returns:
        str: The generated caption for the image.

    Raises:
        FileNotFoundError: If the model checkpoint or image file is not found.
        RuntimeError: If the image cannot be loaded or transformed.
    """
    # Normalize the config once so every lookup below shares the same defaults,
    # instead of repeating `config.get(...) if config else ...` at each call site
    # (which previously duplicated the beam-size default in the log message and
    # the generate_caption call).
    cfg = config or {}
    beam_size = cfg.get('beam_size', 5)
    max_caption_length = cfg.get('max_caption_length', 20)

    logger.info("Loading model for inference...")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}. "
                                "Please train the model first or provide a valid path.")

    # Load the complete checkpoint (model state, optimizer state, vocabulary, config).
    # map_location='cpu' ensures it loads to CPU even if trained on GPU.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

    # Extract configuration and vocabulary from the checkpoint.
    model_config_from_checkpoint = checkpoint.get('config', {})
    vocabulary = checkpoint['vocabulary']

    # Rebuild the model with the architecture hyperparameters saved in the
    # checkpoint; dropout off and encoder frozen, as appropriate for inference.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
        attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
        decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
        dropout=0.0,  # dropout must be off during inference
        fine_tune_encoder=False,  # encoder is never fine-tuned at inference time
        max_caption_length=max_caption_length
    )

    # Load the trained weights and switch to evaluation mode (important for
    # batch norm / dropout layers).
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Move the model to GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    logger.info(f"Model loaded successfully on device: {device}")

    # Image transformation pipeline used for evaluation/inference.
    transform = get_eval_transform()

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}. Please check the image path.")

    # Load and preprocess the image.
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).to(device)
    except Exception as e:
        # Chain the original cause for easier debugging. RuntimeError is a
        # subclass of Exception, so existing broad handlers still catch it.
        raise RuntimeError(f"Error loading or transforming image {image_path}: {e}") from e

    logger.info(f"Generating caption for {image_path} using beam search (beam_size="
                f"{beam_size})...")

    # Generate the caption using the model's integrated beam-search method.
    generated_caption = model.generate_caption(
        image_tensor,
        vocabulary,
        device,
        beam_size=beam_size,
        max_length=max_caption_length
    )

    # Optional: visualize attention weights. Ensure the output directory exists
    # so the save does not fail on a fresh checkout.
    os.makedirs('output', exist_ok=True)
    visualize_attention(model, image_path, vocabulary, device,
                        save_path=os.path.join('output', 'attention_visualization.png'))

    # These logs are placed AFTER the point where the logger is definitely active.
    logger.info("\n" + "="*50)
    logger.info(" GENERATED CAPTION")
    logger.info("="*50)
    logger.info(f"Image: {image_path}")
    logger.info(f"Caption: {generated_caption}")
    logger.info("="*50 + "\n")
    sys.stdout.flush()  # explicitly flush so the caption appears even with buffered stdout

    return generated_caption
if __name__ == '__main__':
    # Running this module directly executes the single-image inference example.
    # First resolve the most recently trained checkpoint into INFERENCE_CONFIG.
    update_config_with_latest_model(INFERENCE_CONFIG)

    # Paths are driven by INFERENCE_CONFIG in config.py; override the two
    # variables below to experiment with a specific model/image immediately.
    my_model_path = INFERENCE_CONFIG['model_path']
    my_image_path = INFERENCE_CONFIG['example_image_path']

    # Optional ground-truth caption for a quick single-image BLEU comparison.
    my_reference_caption = "Two riders on horses are performing a reining maneuver on a green grassy field surrounded by trees"  # Example reference, replace or leave empty

    # Hand the function a copy so the shared INFERENCE_CONFIG stays untouched.
    inference_params = dict(INFERENCE_CONFIG)

    logger.info("--- Running Inference Example ---")
    try:
        generated_caption = run_inference_example(my_model_path, my_image_path, config=inference_params)
        # Evaluate the generated caption against its reference, if one is set.
        if my_reference_caption:
            # calculate_bleu_scores_detailed comes from the evaluation module.
            bleu_scores = calculate_bleu_scores_detailed([my_reference_caption], [generated_caption])
            logger.info("\n--- Single Image Evaluation ---")
            logger.info(f"Reference: {my_reference_caption}")
            logger.info(f"Generated: {generated_caption}")
            logger.info(f"BLEU-4 Score: {bleu_scores['BLEU-4']:.4f}")
            logger.info("-------------------------------\n")
            sys.stdout.flush()  # flush after the single-image evaluation too
    except FileNotFoundError as e:
        logger.error(f"Error: {e}")
        logger.error("Please ensure your model, vocabulary, and image paths are correct "
                     "and data is downloaded as per README.md.")
        sys.stdout.flush()  # flush errors too
    except Exception as e:
        logger.critical(f"An unexpected error occurred during inference: {e}", exc_info=True)
        sys.stdout.flush()  # flush critical errors too