Update model_loader.py
Browse files- model_loader.py +45 -65
model_loader.py
CHANGED
|
@@ -4,72 +4,51 @@ from torchvision import models
|
|
| 4 |
import pickle
|
| 5 |
from pathlib import Path
|
| 6 |
import sys
|
|
|
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
# (This allows pickle to reconstruct the object)
|
| 11 |
-
# ==========================================
|
| 12 |
-
class Vocabulary:
|
| 13 |
-
def __init__(self, freq_threshold):
|
| 14 |
-
self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
|
| 15 |
-
self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
|
| 16 |
-
self.freq_threshold = freq_threshold
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def __len__(self):
|
| 19 |
-
return len(self.
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
frequencies = {}
|
| 27 |
-
idx = 4
|
| 28 |
-
for sentence in sentence_list:
|
| 29 |
-
for word in self.tokenizer_eng(sentence):
|
| 30 |
-
if word not in frequencies:
|
| 31 |
-
frequencies[word] = 1
|
| 32 |
-
else:
|
| 33 |
-
frequencies[word] += 1
|
| 34 |
-
|
| 35 |
-
if frequencies[word] == self.freq_threshold:
|
| 36 |
-
self.stoi[word] = idx
|
| 37 |
-
self.itos[idx] = word
|
| 38 |
-
idx += 1
|
| 39 |
-
|
| 40 |
-
def numericalize(self, text):
|
| 41 |
-
tokenized_text = self.tokenizer_eng(text)
|
| 42 |
-
return [
|
| 43 |
-
self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
|
| 44 |
-
for token in tokenized_text
|
| 45 |
-
]
|
| 46 |
-
|
| 47 |
-
# Helper for inference
|
| 48 |
-
def decode(self, tokens):
|
| 49 |
-
return [self.itos[token] if token in self.itos else "<UNK>" for token in tokens]
|
| 50 |
-
|
| 51 |
-
@property
|
| 52 |
-
def start_token(self):
|
| 53 |
-
return "<SOS>"
|
| 54 |
-
|
| 55 |
-
@property
|
| 56 |
-
def end_token(self):
|
| 57 |
-
return "<EOS>"
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
return
|
| 62 |
|
| 63 |
-
# ==========================================
|
| 64 |
-
# 2. REDIRECT __main__.Vocabulary
|
| 65 |
-
# (Crucial step for pickle loading)
|
| 66 |
-
# ==========================================
|
| 67 |
import __main__
|
| 68 |
setattr(__main__, "Vocabulary", Vocabulary)
|
| 69 |
|
| 70 |
-
# ==========================================
|
| 71 |
-
# MODEL CLASSES
|
| 72 |
-
# ==========================================
|
| 73 |
|
| 74 |
class EncoderCNN(nn.Module):
|
| 75 |
def __init__(self, embed_size):
|
|
@@ -153,9 +132,6 @@ class ActionRecognitionModel(nn.Module):
|
|
| 153 |
def forward(self, x):
|
| 154 |
return self.backbone(x)
|
| 155 |
|
| 156 |
-
# ==========================================
|
| 157 |
-
# LOADER FUNCTIONS
|
| 158 |
-
# ==========================================
|
| 159 |
|
| 160 |
def load_caption_model(device, model_dir=None):
|
| 161 |
if model_dir is None:
|
|
@@ -168,10 +144,14 @@ def load_caption_model(device, model_dir=None):
|
|
| 168 |
config = pickle.load(f)
|
| 169 |
|
| 170 |
# Load vocabulary
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# Create model
|
| 176 |
model = ImageCaptioningModel(
|
| 177 |
embed_size=config['embed_size'],
|
|
|
|
| 4 |
import pickle
|
| 5 |
from pathlib import Path
|
| 6 |
import sys
|
| 7 |
+
import logging
|
| 8 |
|
| 9 |
+
# Configure logger
|
| 10 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
class Vocabulary:
    """Two-way mapping between words and integer indices.

    A new instance starts with the four special tokens registered in this
    order: ``<PAD>`` (0), ``<SOS>`` (1), ``<EOS>`` (2), ``<UNK>`` (3).
    Further words receive consecutive indices through :meth:`add_word`.

    Instances are restored from pickle files, so :meth:`__setstate__` also
    accepts the earlier attribute layout (``stoi`` / ``itos``).
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary containing only the special tokens.

        Args:
            freq_threshold: Stored on the instance as ``freq_threshold``;
                no method of this class consults it.
        """
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next index handed out by add_word()

        # Special tokens
        self.pad_token = "<PAD>"
        self.start_token = "<SOS>"
        self.end_token = "<EOS>"
        self.unk_token = "<UNK>"

        # Registration order fixes the indices 0..3.
        for token in (
            self.pad_token,
            self.start_token,
            self.end_token,
            self.unk_token,
        ):
            self.add_word(token)

    def __setstate__(self, state):
        """Restore pickled state, upgrading the legacy attribute layout.

        Unpickling bypasses ``__init__``, so an object saved when the
        mappings were named ``stoi`` / ``itos`` would otherwise come back
        without ``word2idx`` / ``idx2word`` and fail on ``len()`` or lookup.

        Args:
            state: The instance ``__dict__`` captured when pickling.
        """
        self.__dict__.update(state)
        attrs = self.__dict__

        # Alias (not copy) the legacy dicts so both names stay in sync.
        if "word2idx" not in attrs and "stoi" in attrs:
            attrs["word2idx"] = attrs["stoi"]
        if "idx2word" not in attrs and "itos" in attrs:
            attrs["idx2word"] = attrs["itos"]

        # Attributes the legacy layout did not store on the instance.
        attrs.setdefault("pad_token", "<PAD>")
        attrs.setdefault("start_token", "<SOS>")
        attrs.setdefault("end_token", "<EOS>")
        attrs.setdefault("unk_token", "<UNK>")
        if "idx" not in attrs and "word2idx" in attrs:
            # Indices are handed out consecutively from 0, so the next free
            # index equals the number of known words.
            attrs["idx"] = len(attrs["word2idx"])

    def add_word(self, word):
        """Add a word to the vocabulary; words already present are ignored."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        """Return the number of known words, special tokens included."""
        return len(self.word2idx)

    def __call__(self, word):
        """Convert a word to its index; unknown words map to ``<UNK>``."""
        try:
            return self.word2idx[word]
        except KeyError:
            return self.word2idx[self.unk_token]

    def decode(self, indices):
        """Convert indices back to words.

        Indices that are not in the vocabulary are dropped from the result
        rather than replaced with a placeholder.
        """
        return [self.idx2word[idx] for idx in indices if idx in self.idx2word]
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Redirect __main__.Vocabulary to the class defined in this module
# (needed for pickle loading of the saved vocabulary).
import __main__

__main__.Vocabulary = Vocabulary
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
class EncoderCNN(nn.Module):
|
| 54 |
def __init__(self, embed_size):
|
|
|
|
| 132 |
def forward(self, x):
|
| 133 |
return self.backbone(x)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
def load_caption_model(device, model_dir=None):
|
| 137 |
if model_dir is None:
|
|
|
|
| 144 |
config = pickle.load(f)
|
| 145 |
|
| 146 |
# Load vocabulary
|
| 147 |
+
try:
|
| 148 |
+
with open(model_dir / 'vocab.pkl', 'rb') as f:
|
| 149 |
+
vocab = pickle.load(f)
|
| 150 |
+
logger.info(f"Vocabulary loaded successfully. Size: {len(vocab)}")
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logger.error(f"Failed to load vocabulary: {e}")
|
| 153 |
+
raise e
|
| 154 |
+
|
| 155 |
# Create model
|
| 156 |
model = ImageCaptioningModel(
|
| 157 |
embed_size=config['embed_size'],
|