"""Single-image caption inference entry point for the image-captioning project."""
import os
import torch
from PIL import Image
import sys # Import sys for flushing stdout
# Import modules from your project structure
# Ensure utils is imported early to set up logging
from .utils import get_logger, get_eval_transform, visualize_attention
from .model import ImageCaptioningModel
from .data_preprocessing import COCOVocabulary
from .config import INFERENCE_CONFIG, update_config_with_latest_model # config is imported here
from .evaluation import calculate_bleu_scores_detailed # evaluation is imported here
# Get the module-specific logger. It inherits handlers/level from the root
# logger, which is configured as a side effect of importing `utils` above.
logger = get_logger(__name__)
def run_inference_example(model_path, image_path, config=None):
    """Run inference on a single image and generate a caption.

    Args:
        model_path (str): Path to the saved model checkpoint (.pth file).
        image_path (str): Path to the image file for captioning.
        config (dict, optional): Inference parameters such as ``beam_size``
            and ``max_caption_length``; built-in defaults are used when omitted.

    Returns:
        str: The generated caption for the image.

    Raises:
        FileNotFoundError: If the model checkpoint or image file is not found.
        RuntimeError: If the image cannot be loaded or transformed.
    """
    # Resolve inference parameters once, instead of repeating
    # "config.get(...) if config else default" at every use site.
    cfg = config or {}
    beam_size = cfg.get('beam_size', 5)
    max_caption_length = cfg.get('max_caption_length', 20)

    logger.info("Loading model for inference...")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}. "
                                "Please train the model first or provide a valid path.")

    # Load the complete checkpoint (model state, optimizer state, vocabulary, config).
    # map_location='cpu' ensures it loads to CPU even if trained on GPU.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects (required here
    # for the stored vocabulary) -- only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

    # Extract the architecture hyperparameters and vocabulary saved with the weights.
    model_config_from_checkpoint = checkpoint.get('config', {})
    vocabulary = checkpoint['vocabulary']

    # Rebuild the model with the checkpoint's architecture. Dropout is disabled
    # and the encoder frozen because this path is inference-only.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
        attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
        decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
        dropout=0.0,  # Dropout should be off during inference
        fine_tune_encoder=False,  # Encoder should not be fine-tuned during inference
        max_caption_length=max_caption_length
    )

    # Load the trained weights and switch to eval mode
    # (important for batch norm and dropout behavior).
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Run on GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    logger.info(f"Model loaded successfully on device: {device}")

    # Image transformation pipeline shared with evaluation.
    transform = get_eval_transform()

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}. Please check the image path.")

    # Load and preprocess the image.
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).to(device)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Error loading or transforming image {image_path}: {e}") from e

    logger.info(f"Generating caption for {image_path} using beam search (beam_size="
                f"{beam_size})...")

    # Generate the caption using the model's integrated beam-search method.
    generated_caption = model.generate_caption(
        image_tensor,
        vocabulary,
        device,
        beam_size=beam_size,
        max_length=max_caption_length
    )

    # Optional: visualize attention weights. Create the output directory first
    # so saving does not fail on a fresh checkout.
    os.makedirs('output', exist_ok=True)
    visualize_attention(model, image_path, vocabulary, device,
                        save_path=os.path.join('output', 'attention_visualization.png'))

    # These logs run after the logger is definitely active.
    logger.info("\n" + "="*50)
    logger.info(" GENERATED CAPTION")
    logger.info("="*50)
    logger.info(f"Image: {image_path}")
    logger.info(f"Caption: {generated_caption}")
    logger.info("="*50 + "\n")
    sys.stdout.flush()  # Explicitly flush the standard output buffer

    return generated_caption
if __name__ == '__main__':
    # Running this module directly executes the single-image inference example.
    # Point INFERENCE_CONFIG at the most recently trained checkpoint, if any.
    update_config_with_latest_model(INFERENCE_CONFIG)

    # --- Inference configuration ---
    # These values come from INFERENCE_CONFIG in config.py; override them here
    # only when a specific scenario needs to be tested immediately.
    checkpoint_path = INFERENCE_CONFIG['model_path']
    example_image_path = INFERENCE_CONFIG['example_image_path']
    # Supplying a known reference caption enables the BLEU comparison below.
    reference_caption = "Two riders on horses are performing a reining maneuver on a green grassy field surrounded by trees" # Example reference, replace or leave empty

    # Work on a copy so the global INFERENCE_CONFIG is never mutated.
    inference_params = dict(INFERENCE_CONFIG)

    logger.info("--- Running Inference Example ---")
    try:
        caption = run_inference_example(checkpoint_path, example_image_path, config=inference_params)
        if reference_caption:
            # Score this single generated caption against its reference
            # (calculate_bleu_scores_detailed is imported from evaluation above).
            scores = calculate_bleu_scores_detailed([reference_caption], [caption])
            logger.info("\n--- Single Image Evaluation ---")
            logger.info(f"Reference: {reference_caption}")
            logger.info(f"Generated: {caption}")
            logger.info(f"BLEU-4 Score: {scores['BLEU-4']:.4f}")
            logger.info("-------------------------------\n")
            sys.stdout.flush()  # Explicitly flush after single image evaluation too
    except FileNotFoundError as e:
        logger.error(f"Error: {e}")
        logger.error("Please ensure your model, vocabulary, and image paths are correct "
                     "and data is downloaded as per README.md.")
        sys.stdout.flush()  # Flush errors too
    except Exception as e:
        logger.critical(f"An unexpected error occurred during inference: {e}", exc_info=True)
        sys.stdout.flush()  # Flush critical errors too