File size: 7,057 Bytes
ee1d4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
from PIL import Image
import sys # Import sys for flushing stdout

# Import modules from your project structure
# Ensure utils is imported early to set up logging
from .utils import get_logger, get_eval_transform, visualize_attention
from .model import ImageCaptioningModel
from .data_preprocessing import COCOVocabulary
from .config import INFERENCE_CONFIG, update_config_with_latest_model # config is imported here
from .evaluation import calculate_bleu_scores_detailed # evaluation is imported here

# Module-level logger for this file. It inherits handlers/level from the root
# logger, which is configured as a side effect of importing `utils` above —
# so the `.utils` import must run before this line.
logger = get_logger(__name__)

def run_inference_example(model_path, image_path, config=None):
    """
    Function to run inference on a single image and generate a caption.

    Args:
        model_path (str): Path to the saved model checkpoint (.pth file).
        image_path (str): Path to the image file for captioning.
        config (dict, optional): Configuration dictionary for inference parameters
                                 (e.g., beam_size, max_caption_length).
    Returns:
        str: The generated caption for the image.
    Raises:
        FileNotFoundError: If the model checkpoint or image file is not found.
        Exception: For other unexpected errors during inference.
    """
    # Normalize the optional config once so the inference parameters and their
    # defaults are defined in exactly one place (the original repeated the
    # `config.get(...) if config else <default>` expression three times).
    config = config or {}
    beam_size = config.get('beam_size', 5)
    max_caption_length = config.get('max_caption_length', 20)

    logger.info("Loading model for inference...")

    # Fail fast: validate BOTH input paths before the expensive torch.load /
    # model construction, so a bad image path doesn't waste a full model load.
    # Messages are unchanged so existing callers/log parsers are unaffected.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}. "
                                "Please train the model first or provide a valid path.")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}. Please check the image path.")

    # Load the complete checkpoint (model state, optimizer state, vocabulary, config).
    # map_location='cpu' ensures it loads to CPU even if trained on GPU.
    # NOTE(security): weights_only=False unpickles arbitrary objects (needed here
    # to restore the vocabulary object) — only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

    # Extract configuration and vocabulary from the checkpoint.
    model_config_from_checkpoint = checkpoint.get('config', {})
    vocabulary = checkpoint['vocabulary']

    # Initialize the model structure with parameters saved in the checkpoint.
    # Dropout is forced to 0.0 and encoder fine-tuning is disabled for inference.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
        attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
        decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
        dropout=0.0,  # Dropout should be off during inference
        fine_tune_encoder=False,  # Encoder should not be fine-tuned during inference
        max_caption_length=max_caption_length  # Use config's max length for inference
    )
    # Load the trained weights and switch to eval mode (affects batch norm, dropout).
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Determine the device to run inference on and move the model there.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    logger.info(f"Model loaded successfully on device: {device}")

    # Get the image transformation pipeline for evaluation/inference.
    transform = get_eval_transform()

    # Load and preprocess the image. Chain the original exception (`from e`) so
    # the underlying PIL/torch traceback is preserved for debugging.
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).to(device)
    except Exception as e:
        raise Exception(f"Error loading or transforming image {image_path}: {e}") from e

    logger.info(f"Generating caption for {image_path} using beam search (beam_size="
                f"{beam_size})...")

    # Generate the caption using the model's integrated method.
    generated_caption = model.generate_caption(
        image_tensor,
        vocabulary,
        device,
        beam_size=beam_size,
        max_length=max_caption_length
    )

    # Optional: Visualize attention weights.
    visualize_attention(model, image_path, vocabulary, device,
                        save_path=os.path.join('output', 'attention_visualization.png'))

    # These logs are placed AFTER the point where the logger is definitely active.
    logger.info("\n" + "="*50)
    logger.info("             GENERATED CAPTION")
    logger.info("="*50)
    logger.info(f"Image: {image_path}")
    logger.info(f"Caption: {generated_caption}")
    logger.info("="*50 + "\n")
    sys.stdout.flush()  # Explicitly flush the standard output buffer

    return generated_caption


if __name__ == '__main__':
    # Running `app.py` directly executes this single-image inference example.
    # Point INFERENCE_CONFIG at the newest trained checkpoint, if one exists.
    update_config_with_latest_model(INFERENCE_CONFIG)

    # --- User Input/Configuration for Inference ---
    # Paths come from INFERENCE_CONFIG in config.py; override the locals below
    # when you want to test a specific image/checkpoint pair by hand.
    example_image = INFERENCE_CONFIG['example_image_path']
    checkpoint_path = INFERENCE_CONFIG['model_path']
    # Optional ground-truth caption for a quick single-image BLEU comparison.
    reference_text = "Two riders on horses are performing a reining maneuver on a green grassy field surrounded by trees" # Example reference, replace or leave empty

    # Work on a copy so the shared INFERENCE_CONFIG dict stays untouched.
    inference_options = INFERENCE_CONFIG.copy()

    logger.info("--- Running Inference Example ---")
    try:
        caption = run_inference_example(checkpoint_path, example_image, config=inference_options)

        # Score the generated caption against the reference, when one is set.
        if reference_text:
            # calculate_bleu_scores_detailed is already imported from evaluation.
            scores = calculate_bleu_scores_detailed([reference_text], [caption])
            logger.info("\n--- Single Image Evaluation ---")
            logger.info(f"Reference: {reference_text}")
            logger.info(f"Generated: {caption}")
            logger.info(f"BLEU-4 Score: {scores['BLEU-4']:.4f}")
            logger.info("-------------------------------\n")
            sys.stdout.flush()  # Explicitly flush after single image evaluation too

    except FileNotFoundError as e:
        logger.error(f"Error: {e}")
        logger.error("Please ensure your model, vocabulary, and image paths are correct "
                     "and data is downloaded as per README.md.")
        sys.stdout.flush()  # Flush errors too
    except Exception as e:
        logger.critical(f"An unexpected error occurred during inference: {e}", exc_info=True)
        sys.stdout.flush()  # Flush critical errors too