# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image
import os
import nltk
from collections import Counter  # Needed for Vocabulary unpickling
from torch.serialization import safe_globals  # For secure loading
import gradio as gr  # Import Gradio
# --- 1. Define Classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, DecoderRNN here.
# This is CRUCIAL for loading the model correctly.
class Vocabulary:
    # --- Paste your final Vocabulary class definition here ---
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.idx = 4

    def build_vocabulary(self, sentence_list):  # Needs to be present for unpickling
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
        filtered_freq = {word: freq for word, freq in frequencies.items() if freq >= self.freq_threshold}
        for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

    def numericalize(self, text):
        tokens = nltk.tokenize.word_tokenize(text.lower())
        return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]

    def __len__(self):
        return self.idx
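# Usage sketch: unknown words map to <unk> (index 3), so on a freshly
# constructed vocabulary every word is unknown:
#   vocab = Vocabulary()
#   vocab.numericalize("hello world")  # -> [3, 3] until build_vocabulary runs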
class EncoderCNN(nn.Module):
    # --- Paste your final EncoderCNN class definition here ---
    def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
        super(EncoderCNN, self).__init__()
        try:  # Handle potential torchvision version differences
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except TypeError:
            resnet = models.resnet101(pretrained=True)  # Older torchvision API
        for param in resnet.parameters():
            param.requires_grad = False
        # Fine-tune status doesn't matter for eval, but the architecture must match.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)         # (B, 2048, 1, 1)
        features = features.squeeze(3).squeeze(2)  # (B, 2048)
        features = self.fc(features)               # (B, embed_size)
        features = self.bn(features)
        return features
class DecoderRNN(nn.Module):
    # --- Paste your final DecoderRNN class definition here ---
    # --- including forward_step and init_hidden_state ---
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(dropout_p)
        lstm_dropout = dropout_p if num_layers > 1 else 0  # nn.LSTM warns if dropout is set with a single layer
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout_p)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.num_layers = num_layers

    def init_hidden_state(self, features):
        # Project image features into the initial LSTM state: (1, B, hidden_size)
        h0 = self.init_h(features).unsqueeze(0)
        c0 = self.init_c(features).unsqueeze(0)
        if self.num_layers > 1:
            h0 = h0.repeat(self.num_layers, 1, 1)
            c0 = c0.repeat(self.num_layers, 1, 1)
        return (h0, c0)

    def forward_step(self, embedded_input, hidden_state):
        # One decoding step: (B, 1, embed_size) -> logits of shape (B, vocab_size)
        lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
        outputs = self.linear(lstm_out.squeeze(1))
        return outputs, hidden_state
# --- End Class Definitions ---
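# For reference, the loading code below assumes a checkpoint that was saved
# during training roughly like this (key names inferred from the torch.load /
# checkpoint.get calls further down; the actual training script may differ):
#
#   torch.save({
#       'encoder_state_dict': encoder.state_dict(),
#       'decoder_state_dict': decoder.state_dict(),
#       'vocab': vocab,              # pickled Vocabulary instance
#       'embed_size': 256, 'hidden_size': 512, 'num_layers': 1,
#       'dropout_prob': 0.5, 'fine_tune_encoder': True,
#   }, CHECKPOINT_PATH)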
# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Falls back to CPU on typical (free) Spaces hardware
MAX_LEN = 25  # Maximum caption length in tokens

# --- Global variables for loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None
# --- Model Loading Function ---
def load_model_and_vocab():
    global encoder_global, decoder_global, vocab_global, transform_global
    if encoder_global is not None:  # Already loaded
        print("Model already loaded.")
        return
    print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")
    try:
        with safe_globals([Vocabulary, Counter]):  # Allowlist custom classes
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    except Exception as e:
        # Fallback: weights_only=False executes arbitrary pickled code, so only
        # use it on checkpoints you trust (here, our own).
        print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except Exception as e2:
            raise RuntimeError(f"Failed to load checkpoint: {e2}")

    # Load vocabulary and hyperparameters
    vocab_global = checkpoint['vocab']
    embed_size = checkpoint.get('embed_size', 256)
    hidden_size = checkpoint.get('hidden_size', 512)
    num_layers = checkpoint.get('num_layers', 1)
    dropout_prob = checkpoint.get('dropout_prob', 0.5)
    fine_tune_encoder = checkpoint.get('fine_tune_encoder', True)  # Match saved config
    vocab_size = len(vocab_global)
    print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")

    # Initialize models
    encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob, fine_tune=fine_tune_encoder).to(DEVICE)
    decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout_p=dropout_prob).to(DEVICE)
    encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
    decoder_global.load_state_dict(checkpoint['decoder_state_dict'])

    # Set to evaluation mode
    encoder_global.eval()
    decoder_global.eval()
    print("Models initialized, weights loaded, and set to eval mode.")

    # Define image transformation (same as validation/inference)
    transform_global = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics,
                             std=[0.229, 0.224, 0.225])   # matching the pretrained ResNet
    ])
    print("Transforms defined.")
# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
    words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
    words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
    return " ".join(words)
# --- Inference Function for Gradio ---
def predict(input_image):
    """Generates a caption for a PIL image input from Gradio."""
    if input_image is None:  # Gradio passes None when no image is set
        return "Please upload an image."
    if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
        print("Error: Model not loaded.")
        # Optionally try loading here, but it's better to load upfront:
        # load_model_and_vocab()
        # if encoder_global is None:  # Check again
        return "Error: Model components not loaded. Check logs."

    # 1. Preprocess image
    try:
        image_tensor = transform_global(input_image)
        image_tensor = image_tensor.unsqueeze(0).to(DEVICE)  # Add batch dim
    except Exception as e:
        print(f"Error transforming image: {e}")
        return f"Error processing image: {e}"

    # 2. Generate caption (greedy search)
    generated_indices = []
    with torch.no_grad():
        try:
            features = encoder_global(image_tensor)
            hidden_state = decoder_global.init_hidden_state(features)
            start_token_idx = vocab_global.word2idx["<start>"]
            inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)
            for _ in range(MAX_LEN):
                embedded = decoder_global.embed(inputs)
                outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
                predicted_idx = outputs.argmax(1)
                predicted_word_idx = predicted_idx.item()
                if predicted_word_idx == vocab_global.word2idx["<end>"]:
                    break  # Stop if <end> is predicted
                generated_indices.append(predicted_word_idx)
                inputs = predicted_idx.unsqueeze(1)  # Shape (1, 1) for the next step
        except Exception as e:
            print(f"Error during caption generation: {e}")
            return f"Error during generation: {e}"

    # 3. Convert to sentence
    caption = tokens_to_sentence(generated_indices, vocab_global)
    return caption
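# Quick local smoke test (hypothetical example path, not part of the Space):
#   img = Image.open("images/example1.jpg").convert("RGB")
#   print(predict(img))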
# --- Load Model when script starts ---
# Ensure NLTK data is available if needed by the tokenizer within Vocabulary
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)
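# Newer NLTK releases (3.9+) moved word_tokenize's data into the separate
# 'punkt_tab' resource; fetching it as well is a defensive assumption here.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)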
load_model_and_vocab()  # Load model into global variables
# --- Create Gradio Interface ---
title = "Image Captioning Demo"
description = "Upload an image and this model (ResNet101 Encoder + LSTM Decoder) will generate a caption. Trained on COCO."

# Optional: Define example images (paths relative to app.py)
example_list = [["images/example1.jpg"], ["images/example2.jpg"]] if os.path.exists("images") else None

iface = gr.Interface(
    fn=predict,                                         # The function to call for inference
    inputs=gr.Image(type="pil", label="Upload Image"),  # Input: image upload, passed to fn as a PIL image
    outputs=gr.Textbox(label="Generated Caption"),      # Output: textbox
    title=title,
    description=description,
    examples=example_list,    # Optional: provide examples
    allow_flagging="never"    # Optional: disable flagging (Gradio 5 renames this to flagging_mode)
)
# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()  # share=True is not needed for Spaces; it's handled automatically