File size: 5,159 Bytes
ee1d4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import torch
from PIL import Image

# Import core model components and utilities using ABSOLUTE IMPORTS
# Since 'src' is added to sys.path, we refer to modules directly under 'src'.
from src.model import ImageCaptioningModel
from src.data_preprocessing import COCOVocabulary
from src.utils import get_logger, get_eval_transform
from src.config import INFERENCE_CONFIG, update_config_with_latest_model # Import global config

# Module-level logger, configured via the project's shared logging helper.
logger = get_logger(__name__)

# --- Global variables to store the loaded model and vocabulary ---
# These will be loaded once when this module is first imported, then shared
# by every request handled by the web app.
_model = None       # the ImageCaptioningModel instance, set on first load
_vocabulary = None  # vocabulary object restored from the checkpoint
_device = None      # torch.device chosen at load time (cuda if available)
_transform = None   # eval-time image preprocessing pipeline

def _load_model_and_vocabulary():
    """
    Load the captioning model, vocabulary, device, and image transform into
    the module-level globals.

    Idempotent: returns immediately on subsequent calls once the model is
    loaded. Intended to run exactly once at application startup.

    Raises:
        FileNotFoundError: If the model checkpoint file does not exist.
        Exception: Re-raised from any failure while loading the checkpoint,
            rebuilding the model, or restoring its weights.
    """
    global _model, _vocabulary, _device, _transform

    if _model is not None:
        logger.info("Model and vocabulary already loaded.")
        return

    logger.info("Initializing model and vocabulary for web inference...")

    # Point INFERENCE_CONFIG at the latest best checkpoint so the web app
    # always serves the most recently trained model.
    update_config_with_latest_model(INFERENCE_CONFIG)
    model_path = INFERENCE_CONFIG['model_path']

    if not os.path.exists(model_path):
        logger.error(f"Model checkpoint not found at {model_path}. "
                     "Please ensure the model is trained and saved.")
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (required here to restore the vocabulary and config objects), so
        # only checkpoints produced by this project's own training should
        # ever be loaded through this path.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        # The checkpoint bundles the trained weights, the training config,
        # and the vocabulary built during preprocessing.
        model_config_from_checkpoint = checkpoint.get('config', {})
        _vocabulary = checkpoint['vocabulary']

        # Rebuild the model architecture with the hyperparameters it was
        # trained with, falling back to the training defaults of 256.
        _model = ImageCaptioningModel(
            vocab_size=_vocabulary.vocab_size,
            embed_dim=model_config_from_checkpoint.get('embed_dim', 256),
            attention_dim=model_config_from_checkpoint.get('attention_dim', 256),
            decoder_dim=model_config_from_checkpoint.get('decoder_dim', 256),
            dropout=0.0,  # dropout is disabled during inference
            fine_tune_encoder=False,  # encoder weights stay frozen at inference
            max_caption_length=INFERENCE_CONFIG.get('max_caption_length', 20)
        )
        # Restore the trained weights, then switch to evaluation mode
        # (affects batch-norm and dropout layers).
        _model.load_state_dict(checkpoint['model_state_dict'])
        _model.eval()

        # Prefer GPU when available; generate_caption_for_image() moves its
        # input tensors to this same device.
        _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _model = _model.to(_device)
        logger.info(f"Model loaded successfully on device: {_device}")

        # Image preprocessing pipeline matching the one used at evaluation time.
        _transform = get_eval_transform()

        logger.info("Model and vocabulary are ready for inference.")

    except Exception as e:
        logger.critical(f"Failed to load model or vocabulary: {e}", exc_info=True)
        # Re-raise so the Flask app fails fast instead of starting without a model.
        raise

# Eagerly load the model the first time this module is imported, so the
# Flask app pays the startup cost exactly once and fails fast if loading
# breaks. NOTE(review): this is an import-time side effect — merely
# importing this module triggers a checkpoint load.
_load_model_and_vocabulary()


def generate_caption_for_image(image_path: str) -> str:
    """
    Generates a caption for a given image path using the pre-loaded model.

    Args:
        image_path (str): The full path to the image file.

    Returns:
        str: The generated caption.

    Raises:
        RuntimeError: If the model/vocabulary globals were never initialized.
        FileNotFoundError: If the image file does not exist.
        Exception: For errors during image loading or transformation,
            chained from the underlying error.
    """
    # Guard against being called before _load_model_and_vocabulary() succeeded.
    if _model is None or _vocabulary is None or _transform is None or _device is None:
        logger.error("Model or vocabulary not loaded. Cannot generate caption.")
        raise RuntimeError("Image captioning model is not initialized.")

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}.")

    logger.info(f"Processing image: {image_path}")

    try:
        # Force RGB so grayscale/RGBA uploads match the 3-channel encoder input.
        image = Image.open(image_path).convert('RGB')
        image_tensor = _transform(image).to(_device)
    except Exception as e:
        # Chain the original error so the root cause survives in tracebacks.
        raise Exception(f"Error loading or transforming image {image_path}: {e}") from e

    # Inference only: no_grad skips building the autograd graph, cutting
    # memory use without changing the generated output.
    with torch.no_grad():
        generated_caption = _model.generate_caption(
            image_tensor,
            _vocabulary,
            _device,
            beam_size=INFERENCE_CONFIG.get('beam_size', 5),
            max_length=INFERENCE_CONFIG.get('max_caption_length', 20)
        )
    logger.info(f"Generated caption: {generated_caption}")
    return generated_caption