import os
import torch
from PIL import Image
import sys  # used to flush stdout explicitly
# Import modules from your project structure
# Ensure utils is imported early to set up logging
from .utils import get_logger, get_eval_transform, visualize_attention
from .model import ImageCaptioningModel
from .data_preprocessing import COCOVocabulary
from .config import INFERENCE_CONFIG, update_config_with_latest_model
from .evaluation import calculate_bleu_scores_detailed
# Get the module-specific logger. This logger will inherit from the root logger
# which is configured when `utils` is imported.
logger = get_logger(__name__)
def run_inference_example(model_path, image_path, config=None):
"""
Function to run inference on a single image and generate a caption.
Args:
model_path (str): Path to the saved model checkpoint (.pth file).
image_path (str): Path to the image file for captioning.
config (dict, optional): Configuration dictionary for inference parameters
(e.g., beam_size, max_caption_length).
Returns:
str: The generated caption for the image.
Raises:
FileNotFoundError: If the model checkpoint or image file is not found.
Exception: For other unexpected errors during inference.
"""
logger.info("Loading model for inference...")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model checkpoint not found at {model_path}. "
"Please train the model first or provide a valid path.")
# Load the complete checkpoint (model state, optimizer state, vocabulary, config)
# map_location='cpu' ensures it loads to CPU even if trained on GPU
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
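    # NOTE: weights_only=False is required here because the checkpoint stores a
    # pickled COCOVocabulary object alongside the tensors; since unpickling can
    # execute arbitrary code, only load checkpoints from trusted sources.
    # For reference, a checkpoint with this layout would typically be written by
    # the training code with something like (a sketch, not the actual call):
    #
    #     torch.save({
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'vocabulary': vocabulary,
    #         'config': train_config,
    #     }, checkpoint_path)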
    # Extract configuration and vocabulary from the checkpoint
    model_config_from_checkpoint = checkpoint.get('config', {})
    vocabulary = checkpoint['vocabulary']
    # Resolve inference parameters once, falling back to defaults when no config is given
    beam_size = config.get('beam_size', 5) if config else 5
    max_caption_length = config.get('max_caption_length', 20) if config else 20

    # Initialize the model structure with the parameters saved in the checkpoint.
    # Dropout is disabled and the encoder is frozen for inference.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
        attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
        decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
        dropout=0.0,              # dropout must be off during inference
        fine_tune_encoder=False,  # the encoder is not fine-tuned during inference
        max_caption_length=max_caption_length
    )
    # Load the trained weights into the model
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set the model to evaluation mode (important for batch norm, dropout)
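    # In eval mode, dropout layers are bypassed and batch-norm layers use their
    # running statistics. Wrapping the generation call below in torch.no_grad()
    # would additionally skip gradient tracking and save memory (assuming
    # generate_caption does not already do this internally).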
    # Determine the device to run inference on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)  # Move the model to the GPU if one is available
    logger.info(f"Model loaded successfully on device: {device}")
    # Get the image transformation pipeline for evaluation/inference
    transform = get_eval_transform()
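    # get_eval_transform() is assumed to return a standard ImageNet-style
    # torchvision pipeline along the lines of the sketch below (the actual
    # definition lives in utils.py and may differ):
    #
    #     transforms.Compose([
    #         transforms.Resize((224, 224)),
    #         transforms.ToTensor(),
    #         transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                              std=[0.229, 0.224, 0.225]),
    #     ])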
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}. Please check the image path.")
    # Load and preprocess the image
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).to(device)
    except Exception as e:
        raise RuntimeError(f"Error loading or transforming image {image_path}: {e}") from e
logger.info(f"Generating caption for {image_path} using beam search (beam_size="
f"{config.get('beam_size', 5) if config else 5})...")
# Generate the caption using the model's integrated method
generated_caption = model.generate_caption(
image_tensor,
vocabulary,
device,
beam_size=config.get('beam_size', 5) if config else 5,
max_length=config.get('max_caption_length', 20) if config else 20
)
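    # Beam search keeps the beam_size highest-scoring partial captions at each
    # decoding step instead of committing to the single most likely token;
    # beam_size=1 is equivalent to greedy decoding.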
    # Optional: visualize the attention weights over the image
    visualize_attention(model, image_path, vocabulary, device,
                        save_path=os.path.join('output', 'attention_visualization.png'))
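    # This assumes visualize_attention creates the 'output' directory itself;
    # if it does not, call os.makedirs('output', exist_ok=True) beforehand
    # (check the implementation in utils.py).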
    # Log the final result; the logger is guaranteed to be configured by this
    # point because utils is imported at module load time.
    logger.info("\n" + "=" * 50)
    logger.info("GENERATED CAPTION".center(50))
    logger.info("=" * 50)
    logger.info(f"Image: {image_path}")
    logger.info(f"Caption: {generated_caption}")
    logger.info("=" * 50 + "\n")
    sys.stdout.flush()  # Explicitly flush the standard output buffer

    return generated_caption
if __name__ == '__main__':
    # Run the inference example when this file is executed directly. Note that
    # the relative imports above mean it must be run as a module, e.g. with
    # `python -m <package>.app`.
    # Update INFERENCE_CONFIG with the latest model path if one is available.
    update_config_with_latest_model(INFERENCE_CONFIG)

    # --- User Input/Configuration for Inference ---
    # These values are primarily controlled via INFERENCE_CONFIG in config.py;
    # override them here if you need to test a specific scenario immediately.
    my_image_path = INFERENCE_CONFIG['example_image_path']
    my_model_path = INFERENCE_CONFIG['model_path']

    # Optionally set a reference caption here for comparison
    my_reference_caption = ("Two riders on horses are performing a reining maneuver "
                            "on a green grassy field surrounded by trees")  # example reference; replace or leave empty

    # Use a copy of INFERENCE_CONFIG to avoid modifying the global config directly
    inference_params = INFERENCE_CONFIG.copy()
logger.info("--- Running Inference Example ---")
try:
generated_caption = run_inference_example(my_model_path, my_image_path, config=inference_params)
# You can add evaluation of this single generated caption against its reference here if desired
if my_reference_caption:
# calculate_bleu_scores_detailed is already imported from evaluation
bleu_scores = calculate_bleu_scores_detailed([my_reference_caption], [generated_caption])
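            # Note: BLEU on a single caption pair is noisy, and BLEU-4 can be
            # near zero without smoothing when few 4-grams overlap, so treat
            # this score as a sanity check rather than a benchmark.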
logger.info("\n--- Single Image Evaluation ---")
logger.info(f"Reference: {my_reference_caption}")
logger.info(f"Generated: {generated_caption}")
logger.info(f"BLEU-4 Score: {bleu_scores['BLEU-4']:.4f}")
logger.info("-------------------------------\n")
sys.stdout.flush() # Explicitly flush after single image evaluation too
    except FileNotFoundError as e:
        logger.error(f"Error: {e}")
        logger.error("Please ensure your model, vocabulary, and image paths are correct "
                     "and that the data has been downloaded as described in README.md.")
        sys.stdout.flush()  # Flush errors too
    except Exception as e:
        logger.critical(f"An unexpected error occurred during inference: {e}", exc_info=True)
        sys.stdout.flush()  # Flush critical errors too