# Image-Captioning / debug_weights.py
# Hosting-page header (not code): VIKRAM989 - "Add application file" - 40243b5
import torch
import torch.nn as nn
import torchvision.models as models
import sys
import os
import pickle
import re
from collections import Counter
# Sizes handed to CaptionModel when the model is built further down.
EMBED_DIM = 512
HIDDEN_DIM = 512
# Not referenced anywhere in the visible part of this script.
MAX_LEN = 25
# Use the GPU when CUDA is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Vocabulary class
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``pad``, ``startofseq``,
    ``endofseq`` and ``unk``; corpus words receive ids from 4 upward.

    NOTE(review): the ``vocab.pkl`` loaded later in this script is presumably
    a pickled instance of this class, so the attribute names (``itos``,
    ``stoi``, ``index``, ``freq_threshold``) are deliberately left unchanged.
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary holding only the four special tokens.

        Args:
            freq_threshold: minimum number of occurrences a word needs in
                ``build_vocabulary`` to be added.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {token: idx for idx, token in self.itos.items()}
        # Next free id for a corpus word.
        self.index = 4

    def __len__(self):
        """Return the number of id -> token entries."""
        return len(self.itos)

    def tokenizer(self, text):
        """Lower-case ``text`` and split it into runs of word characters."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        Words that already have an id are skipped. Without that guard a
        corpus word equal to a special token (e.g. "pad") would overwrite the
        reserved id in ``stoi`` and leave two ``itos`` entries for the same
        token, and calling this method twice would re-add every word.

        Args:
            sentence_list: iterable of raw caption strings.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))

        for word, freq in frequencies.items():
            if freq < self.freq_threshold or word in self.stoi:
                continue
            self.stoi[word] = self.index
            self.itos[self.index] = word
            self.index += 1

    def numericalize(self, text):
        """Return the ids of the tokens in ``text``.

        Tokens without an id are mapped to the id of ``"unk"``.
        """
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
class Encoder(nn.Module):
    """Image encoder: ResNet-50 backbone followed by a linear projection.

    Submodules are named ``backbone``, ``fc`` and ``bn``; those names form the
    state-dict keys, so they must stay as they are.
    """

    def __init__(self, embed_dim):
        """Build the encoder.

        Args:
            embed_dim: output size of the projection and of the batch norm.
        """
        super().__init__()
        pretrained = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Keep every child module of the ResNet except the last one.
        children = list(pretrained.children())
        self.backbone = nn.Sequential(*children[:-1])
        # The projection takes the input width of ResNet's own ``fc`` layer.
        self.fc = nn.Linear(pretrained.fc.in_features, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        """Return the batch-normalised projection of the backbone features."""
        # Only the backbone pass runs without gradient tracking; the
        # projection and batch norm below are computed normally.
        with torch.no_grad():
            feature_maps = self.backbone(x)
        batch_size = feature_maps.size(0)
        # Collapse everything after the batch dimension into one axis.
        flat = feature_maps.reshape(batch_size, -1)
        projected = self.fc(flat)
        return self.bn(projected)
class Decoder(nn.Module):
    """Caption decoder: embedding -> LSTM -> linear layer over the vocabulary.

    Submodules are named ``embedding``, ``lstm`` and ``fc``; those names form
    the state-dict keys.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        """Build the decoder.

        Args:
            embed_dim: embedding size, also the LSTM input size.
            hidden_dim: LSTM hidden size.
            vocab_size: number of embedding rows and of output logits.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, states=None):
        """Run the decoder over ``x``.

        Args:
            x: integer token ids to embed (batch-first, as the LSTM is
                created with ``batch_first=True``).
            states: optional LSTM state passed straight to ``self.lstm``.

        Returns:
            A ``(logits, states)`` pair: the linear layer applied to every
            LSTM output, and the state returned by the LSTM.
        """
        embedded = self.embedding(x)
        lstm_out, next_states = self.lstm(embedded, states)
        return self.fc(lstm_out), next_states
class CaptionModel(nn.Module):
    """Container pairing an ``Encoder`` with a ``Decoder``.

    The submodules are stored as ``encoder`` and ``decoder``, which prefixes
    every state-dict key with ``encoder.`` or ``decoder.``. No ``forward`` is
    defined in this script; the class is only used here to build and inspect
    the state dict.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        """Build both submodules.

        Args:
            embed_dim: size passed to the encoder and the decoder.
            hidden_dim: hidden size passed to the decoder.
            vocab_size: vocabulary size passed to the decoder.
        """
        super().__init__()
        self.encoder = Encoder(embed_dim=embed_dim)
        self.decoder = Decoder(
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            vocab_size=vocab_size,
        )
# Main debug
# Compare the parameter names and shapes stored in best_checkpoint.pth with
# those of a freshly built CaptionModel, then try to load the weights.
script_dir = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(script_dir, "best_checkpoint.pth")
VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")

print("=" * 80)
print("LOADING CHECKPOINT")
print("=" * 80)

checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
print(f"\nCheckpoint keys: {list(checkpoint.keys())}")

print("\nCheckpoint model_state_dict keys:")
# Look the stored state dict up once instead of on every loop iteration.
checkpoint_state = checkpoint["model_state_dict"]
checkpoint_keys = set(checkpoint_state.keys())
for key in sorted(checkpoint_keys):
    shape = checkpoint_state[key].shape
    print(f" {key}: {shape}")

# Load vocab
# NOTE: pickle.load can execute arbitrary code contained in the file; only
# run this against a vocab.pkl from a trusted source.
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)
vocab_size = len(vocab)
print(f"\nVocab size: {vocab_size}")

# Create model
model = CaptionModel(
    EMBED_DIM,
    HIDDEN_DIM,
    vocab_size
).to(DEVICE)

print("\n" + "=" * 80)
print("MODEL STATE DICT KEYS")
print("=" * 80)

# state_dict() builds a new dict on every call, so take a single snapshot and
# reuse it in the listings and comparisons below.
model_state = model.state_dict()
model_keys = set(model_state.keys())
for key in sorted(model_keys):
    shape = model_state[key].shape
    print(f" {key}: {shape}")

# Check differences
print("\n" + "=" * 80)
print("COMPARISON")
print("=" * 80)

print("\nKeys in checkpoint but NOT in model:")
for key in sorted(checkpoint_keys - model_keys):
    print(f" {key}")

print("\nKeys in model but NOT in checkpoint:")
for key in sorted(model_keys - checkpoint_keys):
    print(f" {key}")

print("\nKeys in both but with different shapes:")
for key in sorted(checkpoint_keys & model_keys):
    cp_shape = checkpoint_state[key].shape
    model_shape = model_state[key].shape
    if cp_shape != model_shape:
        print(f" {key}")
        print(f" Checkpoint: {cp_shape}")
        print(f" Model: {model_shape}")

print("\n" + "=" * 80)
print("ATTEMPTING TO LOAD WEIGHTS")
print("=" * 80)

# Deliberately broad: this is a diagnostic script, so any failure while
# loading is reported rather than raised.
try:
    model.load_state_dict(checkpoint_state)
except Exception as e:
    print(f"ERROR: {e}")
else:
    print("SUCCESS: Weights loaded successfully!")