# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image
import os
import nltk
import argparse
from collections import Counter # Needed for Vocabulary unpickling
from torch.serialization import safe_globals # For secure loading
import gradio as gr # Import Gradio
# --- 1. Define Classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, DecoderRNN here.
# This is CRUCIAL for loading the model correctly.
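# pickle reconstructs the saved Vocabulary object by looking up a class with the same
# name and attributes, so the definitions below must mirror the training code exactly.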
class Vocabulary:
# --- Paste your final Vocabulary class definition here ---
def __init__(self, freq_threshold=5):
self.freq_threshold = freq_threshold
self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
self.idx = 4
def build_vocabulary(self, sentence_list): # Needs to be present for unpickling
frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
filtered_freq = {word: freq for word, freq in frequencies.items() if freq >= self.freq_threshold}
for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1
def numericalize(self, text):
tokens = nltk.tokenize.word_tokenize(text.lower())
return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]
    def __len__(self):
        return self.idx
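# Illustrative only (not executed): once build_vocabulary() has seen enough captions,
# numericalize() maps text to indices, and unseen words fall back to <unk>:
#
#   demo_vocab = Vocabulary(freq_threshold=1)                 # hypothetical toy vocab
#   demo_vocab.build_vocabulary(["a dog runs", "a dog sits"])
#   demo_vocab.numericalize("a dog jumps")                    # "jumps" -> index 3 (<unk>)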
class EncoderCNN(nn.Module):
# --- Paste your final EncoderCNN class definition here ---
def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
super(EncoderCNN, self).__init__()
try: # Handle potential torchvision version differences
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
except TypeError:
resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
# Fine-tune status doesn't matter for eval, but architecture must match
self.resnet = nn.Sequential(*list(resnet.children())[:-1])
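        # [:-1] drops ResNet-101's final fc layer; what remains ends in global average
        # pooling, producing a (batch, 2048, 1, 1) feature map.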
self.fc = nn.Linear(resnet.fc.in_features, embed_size)
self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
self.dropout = nn.Dropout(dropout_p)
def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
features = features.squeeze(3).squeeze(2)
features = self.fc(features)
features = self.bn(features)
return features
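# Shape flow through the encoder: (B, 3, 224, 224) images -> ResNet features
# (B, 2048, 1, 1) -> squeeze to (B, 2048) -> fc/bn to (B, embed_size).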
class DecoderRNN(nn.Module):
# --- Paste your final DecoderRNN class definition here ---
# --- including forward_step and init_hidden_state ---
def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_size)
self.embed_dropout = nn.Dropout(dropout_p)
lstm_dropout = dropout_p if num_layers > 1 else 0
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
self.dropout = nn.Dropout(dropout_p)
self.linear = nn.Linear(hidden_size, vocab_size)
self.init_h = nn.Linear(embed_size, hidden_size)
self.init_c = nn.Linear(embed_size, hidden_size)
self.num_layers = num_layers
def init_hidden_state(self, features):
h0 = self.init_h(features).unsqueeze(0)
c0 = self.init_c(features).unsqueeze(0)
if self.num_layers > 1:
h0 = h0.repeat(self.num_layers, 1, 1)
c0 = c0.repeat(self.num_layers, 1, 1)
return (h0, c0)
def forward_step(self, embedded_input, hidden_state):
lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
outputs = self.linear(lstm_out.squeeze(1))
return outputs, hidden_state
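# At inference the decoder is driven step by step: init_hidden_state() projects the image
# embedding into the LSTM's initial (h0, c0) of shape (num_layers, B, hidden_size), and
# forward_step() consumes one embedded token of shape (B, 1, embed_size), returning
# vocabulary logits of shape (B, vocab_size) plus the updated hidden state.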
# --- End Class Definitions ---
# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available; typical Spaces hardware falls back to CPU
MAX_LEN = 25
# --- Global variables for loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None
# --- Model Loading Function ---
def load_model_and_vocab():
global encoder_global, decoder_global, vocab_global, transform_global
if encoder_global is not None: # Already loaded
print("Model already loaded.")
return
print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
if not os.path.exists(CHECKPOINT_PATH):
raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")
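    # PyTorch 2.6+ defaults torch.load to weights_only=True, which refuses to unpickle
    # arbitrary objects; safe_globals() allowlists Vocabulary/Counter so the saved vocab
    # can still be restored. The weights_only=False fallback below trades that safety for
    # compatibility and should only be used with a trusted checkpoint.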
try:
with safe_globals([Vocabulary, Counter]): # Allowlist custom classes
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
except Exception as e:
print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
try:
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
except Exception as e2:
raise RuntimeError(f"Failed to load checkpoint: {e2}")
# Load vocabulary and hyperparameters
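    # The checkpoint is expected to be a dict holding 'vocab', 'encoder_state_dict',
    # 'decoder_state_dict', and (optionally) the training hyperparameters; the .get()
    # defaults below are fallbacks for checkpoints that did not store them.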
vocab_global = checkpoint['vocab']
embed_size = checkpoint.get('embed_size', 256)
hidden_size = checkpoint.get('hidden_size', 512)
num_layers = checkpoint.get('num_layers', 1)
dropout_prob = checkpoint.get('dropout_prob', 0.5)
fine_tune_encoder = checkpoint.get('fine_tune_encoder', True) # Match saved config
vocab_size = len(vocab_global)
print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")
# Initialize models
encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob, fine_tune=fine_tune_encoder).to(DEVICE)
decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout_p=dropout_prob).to(DEVICE)
encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
decoder_global.load_state_dict(checkpoint['decoder_state_dict'])
# Set to evaluation mode
encoder_global.eval()
decoder_global.eval()
print("Models initialized, weights loaded, and set to eval mode.")
# Define image transformation (same as validation/inference)
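    # Resize to 224x224 and normalize with the ImageNet mean/std that the pretrained
    # ResNet-101 encoder expects.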
transform_global = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print("Transforms defined.")
# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
return " ".join(words)
# --- Inference Function for Gradio ---
def predict(input_image):
"""Generates caption for a PIL image input from Gradio."""
if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
print("Error: Model not loaded.")
        # Optionally call load_model_and_vocab() here, but loading upfront at startup is preferred.
        return "Error: Model components not loaded. Check logs."
# 1. Preprocess Image
try:
image_tensor = transform_global(input_image)
image_tensor = image_tensor.unsqueeze(0).to(DEVICE) # Add batch dim
except Exception as e:
print(f"Error transforming image: {e}")
return f"Error processing image: {e}"
# 2. Generate Caption (Greedy Search)
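    # Greedy decoding: start from <start>, take the argmax token at each step, feed it
    # back in, and stop at <end> or after MAX_LEN steps.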
generated_indices = []
with torch.no_grad():
try:
features = encoder_global(image_tensor)
hidden_state = decoder_global.init_hidden_state(features)
start_token_idx = vocab_global.word2idx["<start>"]
inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)
for _ in range(MAX_LEN):
embedded = decoder_global.embed(inputs)
outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
predicted_idx = outputs.argmax(1)
predicted_word_idx = predicted_idx.item()
if predicted_word_idx == vocab_global.word2idx["<end>"]:
break # Stop if <end> is predicted
generated_indices.append(predicted_word_idx)
inputs = predicted_idx.unsqueeze(1) # Prepare for next step
except Exception as e:
print(f"Error during caption generation: {e}")
return f"Error during generation: {e}"
# 3. Convert to Sentence
caption = tokens_to_sentence(generated_indices, vocab_global)
return caption
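# Quick local sanity check (illustrative; assumes some test.jpg sits next to app.py):
#
#   from PIL import Image
#   print(predict(Image.open("test.jpg").convert("RGB")))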
# --- Load Model when script starts ---
# Ensure NLTK data is available if needed by tokenizer within Vocab class
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
print("NLTK 'punkt' tokenizer data not found. Downloading...")
nltk.download('punkt', quiet=True)
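# Note: recent NLTK releases may also require the separate 'punkt_tab' resource; if
# tokenization fails with a LookupError, nltk.download('punkt_tab') is the usual fix.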
load_model_and_vocab() # Load model into global variables
# --- Create Gradio Interface ---
title = "Image Captioning Demo"
description = "Upload an image and this model (ResNet101 Encoder + LSTM Decoder) will generate a caption. Trained on COCO."
# Optional: Define example images (paths relative to the app.py file)
example_list = [["images/example1.jpg"], ["images/example2.jpg"]] if os.path.exists("images") else None
iface = gr.Interface(
fn=predict, # The function to call for inference
inputs=gr.Image(type="pil", label="Upload Image"), # Input: Image upload, provide PIL image to fn
outputs=gr.Textbox(label="Generated Caption"), # Output: Textbox
title=title,
description=description,
examples=example_list, # Optional: Provide examples
allow_flagging="never" # Optional: Disable flagging
)
# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()  # share=True is not needed on Spaces; sharing is handled automatically