# Debug script: inspect a saved checkpoint's state_dict against the current
# CaptionModel architecture. (Hugging Face "Spaces / Sleeping" scrape header removed.)
import os
import pickle
import re
import sys
from collections import Counter

import torch
import torch.nn as nn
import torchvision.models as models

# Device and model hyperparameters — must match the training configuration
# that produced best_checkpoint.pth, or the strict load below will fail.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
HIDDEN_DIM = 512
MAX_LEN = 25
class Vocabulary:
    """Word-level vocabulary for captioning.

    Special ids are fixed: 0="pad", 1="startofseq", 2="endofseq", 3="unk".
    Words occurring fewer than ``freq_threshold`` times in the training
    sentences are mapped to "unk".
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # id -> token and token -> id maps, pre-seeded with the special tokens.
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next id to assign to a newly admitted word

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and split it into alphanumeric word tokens."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Assign ids to every word seen at least ``freq_threshold`` times.

        Ids are assigned in first-occurrence order (Counter preserves
        insertion order), so rebuilding from the same corpus is deterministic.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert *text* into a list of token ids; OOV words become "unk"."""
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
class Encoder(nn.Module):
    """ResNet-50 image encoder projecting images to a fixed-size embedding.

    The pretrained convolutional backbone is run under ``torch.no_grad`` in
    ``forward``, so gradients only reach the projection head (``fc`` + ``bn``).
    """

    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Drop the final classification layer; keep conv stages + global avgpool.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        """Encode a batch of images (B, 3, H, W) into (B, embed_dim) features."""
        # no_grad blocks backprop into the backbone; its weights stay pretrained.
        with torch.no_grad():
            features = self.backbone(x)
        features = features.reshape(features.size(0), -1)  # flatten pooled maps
        return self.bn(self.fc(features))
class Decoder(nn.Module):
    """Single-layer LSTM language model over caption token ids."""

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, states=None):
        """Map token ids (B, T) to per-step vocab logits (B, T, vocab_size).

        ``states`` is the optional LSTM (h, c) pair, passed back out so the
        caller can decode step by step; ``None`` means zero initial state.
        """
        emb = self.embedding(x)
        outputs, states = self.lstm(emb, states)
        return self.fc(outputs), states
class CaptionModel(nn.Module):
    """Container pairing the image Encoder with the caption Decoder.

    NOTE(review): no combined ``forward`` is defined here — presumably the
    training/inference code drives encoder and decoder separately; confirm
    against the trainer.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.encoder = Encoder(embed_dim)
        self.decoder = Decoder(embed_dim, hidden_dim, vocab_size)
def main():
    """Diff a saved checkpoint's state_dict against the current CaptionModel.

    Prints the key/shape inventory of both sides, the symmetric differences,
    any shape mismatches on shared keys, and finally attempts a strict
    ``load_state_dict`` so the exact failure (if any) is reported.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(script_dir, "best_checkpoint.pth")
    vocab_path = os.path.join(script_dir, "vocab.pkl")

    print("=" * 80)
    print("LOADING CHECKPOINT")
    print("=" * 80)
    # NOTE(review): torch.load unpickles arbitrary objects — only run on a
    # trusted checkpoint file.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    print(f"\nCheckpoint keys: {list(checkpoint.keys())}")

    print("\nCheckpoint model_state_dict keys:")
    checkpoint_state = checkpoint["model_state_dict"]
    checkpoint_keys = set(checkpoint_state.keys())
    for key in sorted(checkpoint_keys):
        print(f" {key}: {checkpoint_state[key].shape}")

    # NOTE(review): pickle.load also executes code from the file; vocab.pkl
    # must come from a trusted source.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)
    print(f"\nVocab size: {vocab_size}")

    model = CaptionModel(EMBED_DIM, HIDDEN_DIM, vocab_size).to(DEVICE)
    # Snapshot once instead of calling model.state_dict() inside the loops.
    model_state = model.state_dict()
    model_keys = set(model_state.keys())

    print("\n" + "=" * 80)
    print("MODEL STATE DICT KEYS")
    print("=" * 80)
    for key in sorted(model_keys):
        print(f" {key}: {model_state[key].shape}")

    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)
    print("\nKeys in checkpoint but NOT in model:")
    for key in sorted(checkpoint_keys - model_keys):
        print(f" {key}")
    print("\nKeys in model but NOT in checkpoint:")
    for key in sorted(model_keys - checkpoint_keys):
        print(f" {key}")
    print("\nKeys in both but with different shapes:")
    for key in sorted(checkpoint_keys & model_keys):
        cp_shape = checkpoint_state[key].shape
        model_shape = model_state[key].shape
        if cp_shape != model_shape:
            print(f" {key}")
            print(f" Checkpoint: {cp_shape}")
            print(f" Model: {model_shape}")

    print("\n" + "=" * 80)
    print("ATTEMPTING TO LOAD WEIGHTS")
    print("=" * 80)
    try:
        model.load_state_dict(checkpoint_state)
        print("SUCCESS: Weights loaded successfully!")
    except Exception as e:  # broad on purpose: this is a diagnostic report
        print(f"ERROR: {e}")


if __name__ == "__main__":
    main()