# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image
import os
import nltk
import argparse
from collections import Counter # Needed for Vocabulary unpickling
from torch.serialization import safe_globals # For secure loading
import gradio as gr # Import Gradio

# --- 1. Define Classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, DecoderRNN here.
# This is CRUCIAL for loading the model correctly.

class Vocabulary:
    # --- Paste your final Vocabulary class definition here ---
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.idx = 4
    def build_vocabulary(self, sentence_list): # Needs to be present for unpickling
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
        filtered_freq = {word: freq for word, freq in frequencies.items() if freq >= self.freq_threshold}
        for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1
    def numericalize(self, text):
        tokens = nltk.tokenize.word_tokenize(text.lower())
        return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]
    def __len__(self):
        return self.idx
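
# Illustrative sketch (not part of the runtime path) of how the vocabulary
# maps text to indices; the exact ids depend on the trained vocabulary:
#   v = Vocabulary(freq_threshold=2)
#   v.build_vocabulary(["a dog runs", "a dog sleeps"])
#   v.numericalize("a dog flies")  # -> [4, 5, 3]  ("flies" is unseen, so <unk>)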

class EncoderCNN(nn.Module):
    # --- Paste your final EncoderCNN class definition here ---
    def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
        super(EncoderCNN, self).__init__()
        try: # Handle potential torchvision version differences
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except TypeError:
            resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        # Fine-tune status doesn't matter for eval, but architecture must match
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(dropout_p) # Unused in forward; kept so the class mirrors the training-time definition
    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.squeeze(3).squeeze(2)
        features = self.fc(features)
        features = self.bn(features)
        return features
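
# Shape sketch for EncoderCNN.forward with batch size B:
#   images (B, 3, 224, 224) -> resnet (B, 2048, 1, 1) -> squeeze -> (B, 2048)
#   -> fc -> (B, embed_size) -> bn -> (B, embed_size)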

class DecoderRNN(nn.Module):
    # --- Paste your final DecoderRNN class definition here ---
    # --- including forward_step and init_hidden_state ---
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(dropout_p)
        lstm_dropout = dropout_p if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout_p)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.num_layers = num_layers
    def init_hidden_state(self, features):
        h0 = self.init_h(features).unsqueeze(0)
        c0 = self.init_c(features).unsqueeze(0)
        if self.num_layers > 1:
            h0 = h0.repeat(self.num_layers, 1, 1)
            c0 = c0.repeat(self.num_layers, 1, 1)
        return (h0, c0)
    def forward_step(self, embedded_input, hidden_state):
        lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
        outputs = self.linear(lstm_out.squeeze(1))
        return outputs, hidden_state
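
# Shape sketch for one greedy decoding step with batch size B:
#   embedded_input (B, 1, embed_size) -> lstm -> (B, 1, hidden_size)
#   -> squeeze(1) -> (B, hidden_size) -> linear -> (B, vocab_size)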
# --- End Class Definitions ---


# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # GPU if available; falls back to CPU on typical Spaces hardware
MAX_LEN = 25 # Maximum number of tokens to generate per caption

# --- Global variables for loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None

# --- Model Loading Function ---
def load_model_and_vocab():
    global encoder_global, decoder_global, vocab_global, transform_global
    if encoder_global is not None: # Already loaded
        print("Model already loaded.")
        return

    print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")

    try:
        # PyTorch >= 2.6 defaults torch.load to weights_only=True, so the
        # pickled Vocabulary/Counter objects must be explicitly allowlisted.
        with safe_globals([Vocabulary, Counter]):
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
        try:
            # weights_only=False is acceptable only because we trust our own checkpoint file.
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except Exception as e2:
            raise RuntimeError(f"Failed to load checkpoint: {e2}") from e2

    # Load vocabulary and hyperparameters
    vocab_global = checkpoint['vocab']
    embed_size = checkpoint.get('embed_size', 256)
    hidden_size = checkpoint.get('hidden_size', 512)
    num_layers = checkpoint.get('num_layers', 1)
    dropout_prob = checkpoint.get('dropout_prob', 0.5)
    fine_tune_encoder = checkpoint.get('fine_tune_encoder', True) # Match saved config
    vocab_size = len(vocab_global)
    print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")

    # Initialize models
    encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob, fine_tune=fine_tune_encoder).to(DEVICE)
    decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout_p=dropout_prob).to(DEVICE)

    encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
    decoder_global.load_state_dict(checkpoint['decoder_state_dict'])

    # Set to evaluation mode (disables dropout; BatchNorm uses running statistics)
    encoder_global.eval()
    decoder_global.eval()
    print("Models initialized, weights loaded, and set to eval mode.")

    # Define image transformation (same as validation/inference)
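    # Note: the mean/std values below are the standard ImageNet statistics,
    # which the pretrained ResNet backbone expects.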
    transform_global = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    print("Transforms defined.")

# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
    words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
    words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
    return " ".join(words)

# --- Inference Function for Gradio ---
def predict(input_image):
    """Generates caption for a PIL image input from Gradio."""
    if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
        print("Error: Model not loaded.")
        # We could try loading here as a fallback, but loading once at startup is better.
        return "Error: Model components not loaded. Check logs."

    # 1. Preprocess Image
    try:
        image_tensor = transform_global(input_image)
        image_tensor = image_tensor.unsqueeze(0).to(DEVICE) # Add batch dim
    except Exception as e:
        print(f"Error transforming image: {e}")
        return f"Error processing image: {e}"

    # 2. Generate Caption (Greedy Search)
    generated_indices = []
    with torch.no_grad():
        try:
            features = encoder_global(image_tensor)
            hidden_state = decoder_global.init_hidden_state(features)
            start_token_idx = vocab_global.word2idx["<start>"]
            inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)

            for _ in range(MAX_LEN):
                embedded = decoder_global.embed(inputs)
                outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
                predicted_idx = outputs.argmax(1)
                predicted_word_idx = predicted_idx.item()

                if predicted_word_idx == vocab_global.word2idx["<end>"]:
                    break # Stop if <end> is predicted

                generated_indices.append(predicted_word_idx)
                inputs = predicted_idx.unsqueeze(1) # Prepare for next step

        except Exception as e:
             print(f"Error during caption generation: {e}")
             return f"Error during generation: {e}"

    # 3. Convert to Sentence
    caption = tokens_to_sentence(generated_indices, vocab_global)
    return caption
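
# Local sanity check (hypothetical example path; uncomment to exercise the
# pipeline without the Gradio UI):
#   print(predict(Image.open("images/example1.jpg").convert("RGB")))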

# --- Load Model when script starts ---
# Ensure NLTK data is available if needed by tokenizer within Vocab class
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)
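
# Newer NLTK releases (3.8.2+) ship the tokenizer tables as 'punkt_tab';
# fetch that too so word_tokenize works on either version.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)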

load_model_and_vocab() # Load model into global variables

# --- Create Gradio Interface ---
title = "Image Captioning Demo"
description = "Upload an image and this model (ResNet101 Encoder + LSTM Decoder) will generate a caption. Trained on COCO."

# Optional: Define example images (paths relative to the app.py file)
example_list = [[p] for p in ("images/example1.jpg", "images/example2.jpg") if os.path.exists(p)] or None # Only list examples that actually exist


iface = gr.Interface(
    fn=predict,                 # The function to call for inference
    inputs=gr.Image(type="pil", label="Upload Image"), # Input: Image upload, provide PIL image to fn
    outputs=gr.Textbox(label="Generated Caption"),    # Output: Textbox
    title=title,
    description=description,
    examples=example_list,      # Optional: Provide examples
    allow_flagging="never"      # Optional: Disable flagging
)

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch() # share=True is not needed on Spaces; public hosting is handled automatically