Create app.py

app.py (added)
# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image  # Gradio supplies PIL images to predict()
import os
import nltk
from collections import Counter  # Needed for Vocabulary unpickling
from torch.serialization import safe_globals  # For secure checkpoint loading
import gradio as gr

# --- 1. Define classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, and DecoderRNN here.
# This is CRUCIAL: unpickling the checkpoint requires matching definitions.

class Vocabulary:
    # --- Paste your final Vocabulary class definition here ---
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.idx = 4

    def build_vocabulary(self, sentence_list):  # Must be present for unpickling
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
        filtered_freq = {word: freq for word, freq in frequencies.items() if freq >= self.freq_threshold}
        for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

    def numericalize(self, text):
        tokens = nltk.tokenize.word_tokenize(text.lower())
        return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]

    def __len__(self):
        return self.idx

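# A minimal sanity check for the class above (hypothetical usage, not part of
# the app flow; assumes NLTK's punkt data is installed):
#   vocab = Vocabulary(freq_threshold=1)
#   vocab.build_vocabulary(["a dog runs", "a dog sleeps"])
#   vocab.numericalize("a dog jumps")  # known words map to indices >= 4;
#                                      # "jumps" falls back to <unk> (index 3)
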
class EncoderCNN(nn.Module):
    # --- Paste your final EncoderCNN class definition here ---
    def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
        super(EncoderCNN, self).__init__()
        try:  # Handle potential torchvision version differences
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except TypeError:
            resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        # Fine-tune status doesn't matter for eval, but the architecture must match.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.squeeze(3).squeeze(2)
        features = self.fc(features)
        features = self.bn(features)
        return features

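# Shape walkthrough for a single 224x224 RGB image (a sketch of what the
# encoder above does, assuming the standard ResNet-101 backbone):
#   images:       (1, 3, 224, 224)
#   self.resnet:  (1, 2048, 1, 1)   after the final average-pool layer
#   squeeze:      (1, 2048)
#   fc + bn:      (1, embed_size)   the decoder's initial feature vector
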
class DecoderRNN(nn.Module):
    # --- Paste your final DecoderRNN class definition here ---
    # --- including forward_step and init_hidden_state ---
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(dropout_p)
        lstm_dropout = dropout_p if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout_p)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.num_layers = num_layers

    def init_hidden_state(self, features):
        h0 = self.init_h(features).unsqueeze(0)
        c0 = self.init_c(features).unsqueeze(0)
        if self.num_layers > 1:
            h0 = h0.repeat(self.num_layers, 1, 1)
            c0 = c0.repeat(self.num_layers, 1, 1)
        return (h0, c0)

    def forward_step(self, embedded_input, hidden_state):
        lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
        outputs = self.linear(lstm_out.squeeze(1))
        return outputs, hidden_state
# --- End Class Definitions ---

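# One decoding step in isolation (a hedged sketch under assumed sizes, not app
# code; `decoder` and `features` stand in for instances built as above):
#   state = decoder.init_hidden_state(features)             # features: (1, embed_size)
#   embedded = decoder.embed(torch.tensor([[1]]))           # (1, 1, embed_size); 1 = <start>
#   logits, state = decoder.forward_step(embedded, state)   # logits: (1, vocab_size)
#   next_idx = logits.argmax(1)                             # greedy choice
# The predict() function below repeats this loop up to MAX_LEN times.
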
# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Falls back to CPU, the typical Spaces hardware
MAX_LEN = 25

# --- Global variables for the loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None

# --- Model Loading Function ---
def load_model_and_vocab():
    global encoder_global, decoder_global, vocab_global, transform_global
    if encoder_global is not None:  # Already loaded
        print("Model already loaded.")
        return

    print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")

    try:
        with safe_globals([Vocabulary, Counter]):  # Allowlist custom classes
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except Exception as e2:
            raise RuntimeError(f"Failed to load checkpoint: {e2}")

    # Load vocabulary and hyperparameters
    vocab_global = checkpoint['vocab']
    embed_size = checkpoint.get('embed_size', 256)
    hidden_size = checkpoint.get('hidden_size', 512)
    num_layers = checkpoint.get('num_layers', 1)
    dropout_prob = checkpoint.get('dropout_prob', 0.5)
    fine_tune_encoder = checkpoint.get('fine_tune_encoder', True)  # Match the saved config
    vocab_size = len(vocab_global)
    print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")

    # Initialize models
    encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob, fine_tune=fine_tune_encoder).to(DEVICE)
    decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout_p=dropout_prob).to(DEVICE)

    encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
    decoder_global.load_state_dict(checkpoint['decoder_state_dict'])

    # Set to evaluation mode
    encoder_global.eval()
    decoder_global.eval()
    print("Models initialized, weights loaded, and set to eval mode.")

    # Define the image transformation (same as validation/inference)
    transform_global = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    print("Transforms defined.")

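# For reference, the checkpoint read above is assumed to have been written at
# training time roughly like this (a hedged sketch; the key names come from
# the loading code, the values are hypothetical):
#   torch.save({
#       'encoder_state_dict': encoder.state_dict(),
#       'decoder_state_dict': decoder.state_dict(),
#       'vocab': vocab,                  # pickled Vocabulary instance
#       'embed_size': 256, 'hidden_size': 512,
#       'num_layers': 1, 'dropout_prob': 0.5,
#       'fine_tune_encoder': True,
#   }, 'best_model_improved.pth')
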
# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
    words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
    words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
    return " ".join(words)

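# Example (hypothetical indices): tokens_to_sentence([1, 4, 5, 2], vocab)
# drops the <start>/<end> markers and returns something like "a dog".
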
# --- Inference Function for Gradio ---
def predict(input_image):
    """Generates a caption for a PIL image passed in by Gradio."""
    if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
        print("Error: Model not loaded.")
        # Optionally try loading here, but it's better to load upfront:
        # load_model_and_vocab()
        # if encoder_global is None:  # Check again
        return "Error: Model components not loaded. Check logs."

    # 1. Preprocess the image
    try:
        image_tensor = transform_global(input_image)
        image_tensor = image_tensor.unsqueeze(0).to(DEVICE)  # Add batch dim
    except Exception as e:
        print(f"Error transforming image: {e}")
        return f"Error processing image: {e}"

    # 2. Generate a caption (greedy search)
    generated_indices = []
    with torch.no_grad():
        try:
            features = encoder_global(image_tensor)
            hidden_state = decoder_global.init_hidden_state(features)
            start_token_idx = vocab_global.word2idx["<start>"]
            inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)

            for _ in range(MAX_LEN):
                embedded = decoder_global.embed(inputs)
                outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
                predicted_idx = outputs.argmax(1)
                predicted_word_idx = predicted_idx.item()

                if predicted_word_idx == vocab_global.word2idx["<end>"]:
                    break  # Stop once <end> is predicted

                generated_indices.append(predicted_word_idx)
                inputs = predicted_idx.unsqueeze(1)  # Feed the prediction back in as the next input

        except Exception as e:
            print(f"Error during caption generation: {e}")
            return f"Error during generation: {e}"

    # 3. Convert indices to a sentence
    caption = tokens_to_sentence(generated_indices, vocab_global)
    return caption

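# Quick local test of predict() outside Gradio (hypothetical; assumes an image
# file exists at the path shown):
#   from PIL import Image
#   print(predict(Image.open("images/example1.jpg").convert("RGB")))
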
# --- Load the model when the script starts ---
# Ensure the NLTK data needed by the tokenizer in Vocabulary is available.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)

load_model_and_vocab()  # Load the model into the global variables

# --- Create the Gradio interface ---
title = "Image Captioning Demo"
description = "Upload an image and this model (ResNet101 encoder + LSTM decoder) will generate a caption. Trained on COCO."

# Optional: example images (paths relative to app.py)
example_list = [["images/example1.jpg"], ["images/example2.jpg"]] if os.path.exists("images") else None

iface = gr.Interface(
    fn=predict,                                         # The function to call for inference
    inputs=gr.Image(type="pil", label="Upload Image"),  # Gradio passes a PIL image to fn
    outputs=gr.Textbox(label="Generated Caption"),
    title=title,
    description=description,
    examples=example_list,      # Optional
    allow_flagging="never"      # Optional: disable flagging
)

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()  # share=True is not needed on Spaces; it's handled automatically
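
# If running in a custom Docker container rather than a standard Gradio Space,
# you may need to bind explicitly (an assumption about your deployment, not a
# Spaces requirement):
#   iface.launch(server_name="0.0.0.0", server_port=7860)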