Kush26 committed on
Commit d1799c9 · verified · 1 Parent(s): 8d45ab9

Create app.py

Files changed (1)
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image  # Gradio hands predict() the uploaded image as a PIL.Image
import os
import nltk
from collections import Counter               # needed so the pickled Vocabulary can be restored
from torch.serialization import safe_globals  # for secure checkpoint loading
import gradio as gr

# --- 1. Define classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, and DecoderRNN here.
# This is CRUCIAL: the pickled checkpoint can only be restored against
# matching class definitions.

class Vocabulary:
    # --- Paste your final Vocabulary class definition here ---
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.idx = 4

    def build_vocabulary(self, sentence_list):  # must be present for unpickling
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
        filtered_freq = {word: freq for word, freq in frequencies.items() if freq >= self.freq_threshold}
        for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

    def numericalize(self, text):
        tokens = nltk.tokenize.word_tokenize(text.lower())
        return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]

    def __len__(self):
        return self.idx
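
# A minimal usage sketch (hypothetical sentences and indices) of how the
# vocabulary maps text to ids; words below the frequency threshold, or unseen
# at build time, fall through to <unk> (index 3):
#
#   vocab = Vocabulary(freq_threshold=1)
#   vocab.build_vocabulary(["a dog runs", "a dog sits"])
#   vocab.numericalize("a dog flies")   # -> [4, 5, 3]; "flies" maps to <unk>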

class EncoderCNN(nn.Module):
    # --- Paste your final EncoderCNN class definition here ---
    def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
        super(EncoderCNN, self).__init__()
        try:  # handle torchvision version differences in the pretrained-weights API
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except TypeError:
            resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        # Fine-tune status doesn't matter for eval, but the architecture must match.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.squeeze(3).squeeze(2)
        features = self.fc(features)
        features = self.bn(features)
        return features
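
# Shape flow through the encoder, for reference (batch size B):
# (B, 3, 224, 224) --resnet--> (B, 2048, 1, 1) --squeeze--> (B, 2048)
# --fc--> (B, embed_size) --bn--> (B, embed_size)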

class DecoderRNN(nn.Module):
    # --- Paste your final DecoderRNN class definition here,
    # --- including forward_step and init_hidden_state ---
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(dropout_p)
        lstm_dropout = dropout_p if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout_p)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.num_layers = num_layers

    def init_hidden_state(self, features):
        h0 = self.init_h(features).unsqueeze(0)
        c0 = self.init_c(features).unsqueeze(0)
        if self.num_layers > 1:
            h0 = h0.repeat(self.num_layers, 1, 1)
            c0 = c0.repeat(self.num_layers, 1, 1)
        return (h0, c0)

    def forward_step(self, embedded_input, hidden_state):
        lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
        outputs = self.linear(lstm_out.squeeze(1))
        return outputs, hidden_state
# --- End Class Definitions ---
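
# Decoding contract used by predict() below: init_hidden_state maps the image
# embedding to the LSTM's initial (h, c); forward_step then consumes one
# embedded token of shape (B, 1, embed_size) per call and returns vocabulary
# logits of shape (B, vocab_size) plus the updated hidden state.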

# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # falls back to CPU on typical Spaces hardware
MAX_LEN = 25  # maximum caption length in tokens

# --- Global variables for the loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None

# --- Model Loading Function ---
def load_model_and_vocab():
    global encoder_global, decoder_global, vocab_global, transform_global
    if encoder_global is not None:  # already loaded
        print("Model already loaded.")
        return

    print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")

    try:
        with safe_globals([Vocabulary, Counter]):  # allowlist custom classes for safe unpickling
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
        try:
            # Fallback runs the full unpickler; only acceptable because the
            # checkpoint ships with this Space and is trusted.
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except Exception as e2:
            raise RuntimeError(f"Failed to load checkpoint: {e2}")

    # Load vocabulary and hyperparameters (defaults match the training script)
    vocab_global = checkpoint['vocab']
    embed_size = checkpoint.get('embed_size', 256)
    hidden_size = checkpoint.get('hidden_size', 512)
    num_layers = checkpoint.get('num_layers', 1)
    dropout_prob = checkpoint.get('dropout_prob', 0.5)
    fine_tune_encoder = checkpoint.get('fine_tune_encoder', True)  # match the saved config
    vocab_size = len(vocab_global)
    print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")
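
    # For reference, a checkpoint with the keys read above would presumably be
    # written at training time along these lines (a sketch, not the actual
    # training code):
    #
    #   torch.save({
    #       'vocab': vocab,
    #       'embed_size': embed_size, 'hidden_size': hidden_size,
    #       'num_layers': num_layers, 'dropout_prob': dropout_prob,
    #       'fine_tune_encoder': fine_tune_encoder,
    #       'encoder_state_dict': encoder.state_dict(),
    #       'decoder_state_dict': decoder.state_dict(),
    #   }, 'best_model_improved.pth')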

    # Initialize models
    encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob, fine_tune=fine_tune_encoder).to(DEVICE)
    decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout_p=dropout_prob).to(DEVICE)

    encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
    decoder_global.load_state_dict(checkpoint['decoder_state_dict'])

    # Set to evaluation mode
    encoder_global.eval()
    decoder_global.eval()
    print("Models initialized, weights loaded, and set to eval mode.")

    # Define the image transformation (same as validation/inference)
    transform_global = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
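    # The Normalize mean/std above are the standard ImageNet statistics, which
    # is what the pretrained ResNet101 backbone expects.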
    print("Transforms defined.")

# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
    words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
    words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
    return " ".join(words)
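
# For example, with the hypothetical indices from the vocabulary sketch above,
# tokens_to_sentence([1, 4, 5, 2], vocab) would yield "a dog": the <start> and
# <end> tokens are filtered out and the rest are joined with spaces.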

# --- Inference Function for Gradio ---
def predict(input_image):
    """Generates a caption for a PIL image supplied by Gradio."""
    if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
        print("Error: Model not loaded.")
        # Could try load_model_and_vocab() here, but loading upfront at startup is preferable.
        return "Error: Model components not loaded. Check logs."

    # 1. Preprocess the image
    try:
        image_tensor = transform_global(input_image)
        image_tensor = image_tensor.unsqueeze(0).to(DEVICE)  # add batch dimension
    except Exception as e:
        print(f"Error transforming image: {e}")
        return f"Error processing image: {e}"

    # 2. Generate a caption (greedy search)
    generated_indices = []
    with torch.no_grad():
        try:
            features = encoder_global(image_tensor)
            hidden_state = decoder_global.init_hidden_state(features)
            start_token_idx = vocab_global.word2idx["<start>"]
            inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)

            for _ in range(MAX_LEN):
                embedded = decoder_global.embed(inputs)
                outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
                predicted_idx = outputs.argmax(1)
                predicted_word_idx = predicted_idx.item()

                if predicted_word_idx == vocab_global.word2idx["<end>"]:
                    break  # stop once <end> is predicted

                generated_indices.append(predicted_word_idx)
                inputs = predicted_idx.unsqueeze(1)  # feed the prediction back in as the next input

        except Exception as e:
            print(f"Error during caption generation: {e}")
            return f"Error during generation: {e}"

    # 3. Convert token indices to a sentence
    caption = tokens_to_sentence(generated_indices, vocab_global)
    return caption
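
# Aside: the loop above decodes greedily (argmax at every step). Under the same
# forward_step contract, a sampling variant could swap the argmax for, e.g.
# (illustrative only, not used by the app):
#
#   probs = torch.softmax(outputs / temperature, dim=1)
#   predicted_idx = torch.multinomial(probs, num_samples=1).squeeze(1)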

# --- Load the model when the script starts ---
# Ensure NLTK data is available for the tokenizer used by the Vocabulary class.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)
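
# Newer NLTK releases (roughly 3.8.2+) split the tokenizer data into a separate
# 'punkt_tab' resource; fetching it too is a defensive assumption for such versions.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)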

load_model_and_vocab()  # load the model into the global variables

# --- Create the Gradio interface ---
title = "Image Captioning Demo"
description = "Upload an image and this model (ResNet101 encoder + LSTM decoder) will generate a caption. Trained on COCO."

# Optional: example images (paths relative to app.py); only list files that actually exist
example_paths = ["images/example1.jpg", "images/example2.jpg"]
example_list = [[p] for p in example_paths if os.path.exists(p)] or None

iface = gr.Interface(
    fn=predict,                                         # the function to call for inference
    inputs=gr.Image(type="pil", label="Upload Image"),  # hands predict() a PIL image
    outputs=gr.Textbox(label="Generated Caption"),
    title=title,
    description=description,
    examples=example_list,      # optional example images
    allow_flagging="never"      # optional: disable flagging
)
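
# Note: recent Gradio releases deprecate allow_flagging (renamed to
# flagging_mode); the older keyword is kept here to match the original code.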

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()  # share=True is not needed on Spaces; sharing is handled automatically
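
# For completeness: the imports above suggest a requirements.txt along these
# lines for the Space (an assumption; the original commit does not include one):
#
#   torch
#   torchvision
#   nltk
#   gradio
#   Pillow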