"""Checkpoint-debugging script for an image-captioning model.

Rebuilds the model architecture (frozen ResNet-50 encoder + LSTM decoder),
loads a saved checkpoint and pickled vocabulary from the script's directory,
prints both state dicts, diffs their keys and tensor shapes, and finally
attempts `load_state_dict` so any mismatch error is surfaced explicitly.
"""

import os
import pickle
import re
import sys
from collections import Counter

import torch
import torch.nn as nn
import torchvision.models as models

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
HIDDEN_DIM = 512
# Maximum caption length. Unused in this debug script; kept so the constants
# mirror the training configuration.
MAX_LEN = 25


class Vocabulary:
    """Word-level vocabulary with a minimum-frequency threshold.

    Index layout: 0="pad", 1="startofseq", 2="endofseq", 3="unk";
    corpus words are assigned ids starting at 4.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next id to assign to a new word

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and split it into word tokens (``\\w+`` runs)."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Add every token that occurs at least ``freq_threshold`` times.

        Args:
            sentence_list: iterable of raw caption strings.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Map *text* to a list of token ids, substituting "unk" for OOV words."""
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]


class Encoder(nn.Module):
    """ResNet-50 feature extractor with a trainable projection head.

    The backbone (ResNet-50 minus its final fc layer) runs under
    ``torch.no_grad``; only the projection ``fc`` and ``bn`` receive gradients.
    """

    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Drop the classification layer, keep everything through global pooling.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        # NOTE(review): the original (whitespace-mangled) source is ambiguous
        # about how much of forward() sat inside no_grad; the standard pattern
        # — backbone only, projection trainable — is reconstructed here.
        with torch.no_grad():
            features = self.backbone(x)
        features = features.reshape(features.size(0), -1)
        return self.bn(self.fc(features))


class Decoder(nn.Module):
    """Single-layer LSTM language model over caption token ids."""

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, states=None):
        """Return per-step vocabulary logits and the updated LSTM state.

        Args:
            x: LongTensor of token ids, shape (batch, seq).
            states: optional (h, c) LSTM state carried between calls.
        """
        emb = self.embedding(x)
        outputs, states = self.lstm(emb, states)
        return self.fc(outputs), states


class CaptionModel(nn.Module):
    """Container pairing the image Encoder with the caption Decoder."""

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.encoder = Encoder(embed_dim)
        self.decoder = Decoder(embed_dim, hidden_dim, vocab_size)


def main():
    """Load checkpoint and vocab, diff state-dict keys/shapes, try loading."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(script_dir, "best_checkpoint.pth")
    vocab_path = os.path.join(script_dir, "vocab.pkl")

    print("=" * 80)
    print("LOADING CHECKPOINT")
    print("=" * 80)
    # SECURITY: torch.load unpickles arbitrary objects — only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    print(f"\nCheckpoint keys: {list(checkpoint.keys())}")

    # Hoisted: index the checkpoint dict once instead of per key.
    checkpoint_state = checkpoint["model_state_dict"]
    checkpoint_keys = set(checkpoint_state.keys())
    print("\nCheckpoint model_state_dict keys:")
    for key in sorted(checkpoint_keys):
        shape = checkpoint_state[key].shape
        print(f" {key}: {shape}")

    # SECURITY: pickle.load executes arbitrary code — trusted files only.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)
    print(f"\nVocab size: {vocab_size}")

    model = CaptionModel(EMBED_DIM, HIDDEN_DIM, vocab_size).to(DEVICE)

    print("\n" + "=" * 80)
    print("MODEL STATE DICT KEYS")
    print("=" * 80)
    # Snapshot once: nn.Module.state_dict() rebuilds the whole mapping on
    # every call, and the original called it inside two loops.
    model_state = model.state_dict()
    model_keys = set(model_state.keys())
    for key in sorted(model_keys):
        shape = model_state[key].shape
        print(f" {key}: {shape}")

    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)
    print("\nKeys in checkpoint but NOT in model:")
    for key in sorted(checkpoint_keys - model_keys):
        print(f" {key}")
    print("\nKeys in model but NOT in checkpoint:")
    for key in sorted(model_keys - checkpoint_keys):
        print(f" {key}")
    print("\nKeys in both but with different shapes:")
    for key in sorted(checkpoint_keys & model_keys):
        cp_shape = checkpoint_state[key].shape
        model_shape = model_state[key].shape
        if cp_shape != model_shape:
            print(f" {key}")
            print(f" Checkpoint: {cp_shape}")
            print(f" Model: {model_shape}")

    print("\n" + "=" * 80)
    print("ATTEMPTING TO LOAD WEIGHTS")
    print("=" * 80)
    try:
        model.load_state_dict(checkpoint_state)
        print("SUCCESS: Weights loaded successfully!")
    except Exception as e:
        # Deliberately broad: this is a diagnostic script and any failure
        # (missing keys, shape mismatch, ...) should be printed, not raised.
        print(f"ERROR: {e}")


# Guarded entry point: the original ran all of the above at import time,
# which made the module impossible to import without the checkpoint files.
if __name__ == "__main__":
    main()