AhsanAftab committed on
Commit
4d0c7ca
·
verified ·
1 Parent(s): 30dbc69

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +45 -65
model_loader.py CHANGED
@@ -4,72 +4,51 @@ from torchvision import models
4
  import pickle
5
  from pathlib import Path
6
  import sys
 
7
 
8
- # ==========================================
9
- # 1. DEFINE THE VOCABULARY CLASS
10
- # (This allows pickle to reconstruct the object)
11
- # ==========================================
12
- class Vocabulary:
13
- def __init__(self, freq_threshold):
14
- self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
15
- self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
16
- self.freq_threshold = freq_threshold
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def __len__(self):
19
- return len(self.itos)
20
-
21
- @staticmethod
22
- def tokenizer_eng(text):
23
- return text.lower().split()
24
-
25
- def build_vocabulary(self, sentence_list):
26
- frequencies = {}
27
- idx = 4
28
- for sentence in sentence_list:
29
- for word in self.tokenizer_eng(sentence):
30
- if word not in frequencies:
31
- frequencies[word] = 1
32
- else:
33
- frequencies[word] += 1
34
-
35
- if frequencies[word] == self.freq_threshold:
36
- self.stoi[word] = idx
37
- self.itos[idx] = word
38
- idx += 1
39
-
40
- def numericalize(self, text):
41
- tokenized_text = self.tokenizer_eng(text)
42
- return [
43
- self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
44
- for token in tokenized_text
45
- ]
46
-
47
- # Helper for inference
48
- def decode(self, tokens):
49
- return [self.itos[token] if token in self.itos else "<UNK>" for token in tokens]
50
-
51
- @property
52
- def start_token(self):
53
- return "<SOS>"
54
-
55
- @property
56
- def end_token(self):
57
- return "<EOS>"
58
 
59
- @property
60
- def pad_token(self):
61
- return "<PAD>"
62
 
63
- # ==========================================
64
- # 2. REDIRECT __main__.Vocabulary
65
- # (Crucial step for pickle loading)
66
- # ==========================================
67
  import __main__
68
  setattr(__main__, "Vocabulary", Vocabulary)
69
 
70
- # ==========================================
71
- # MODEL CLASSES
72
- # ==========================================
73
 
74
  class EncoderCNN(nn.Module):
75
  def __init__(self, embed_size):
@@ -153,9 +132,6 @@ class ActionRecognitionModel(nn.Module):
153
  def forward(self, x):
154
  return self.backbone(x)
155
 
156
- # ==========================================
157
- # LOADER FUNCTIONS
158
- # ==========================================
159
 
160
  def load_caption_model(device, model_dir=None):
161
  if model_dir is None:
@@ -168,10 +144,14 @@ def load_caption_model(device, model_dir=None):
168
  config = pickle.load(f)
169
 
170
  # Load vocabulary
171
- # The 'setattr' fix above allows this line to work
172
- with open(model_dir / 'vocab.pkl', 'rb') as f:
173
- vocab = pickle.load(f)
174
-
 
 
 
 
175
  # Create model
176
  model = ImageCaptioningModel(
177
  embed_size=config['embed_size'],
 
4
  import pickle
5
  from pathlib import Path
6
  import sys
7
+ import logging
8
 
9
+ # Configure logger
10
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
11
 
12
+ class Vocabulary:
13
+ def __init__(self, freq_threshold=5):
14
+ self.freq_threshold = freq_threshold
15
+ self.word2idx = {}
16
+ self.idx2word = {}
17
+ self.idx = 0
18
+
19
+ # Special tokens
20
+ self.pad_token = "<PAD>"
21
+ self.start_token = "<SOS>"
22
+ self.end_token = "<EOS>"
23
+ self.unk_token = "<UNK>"
24
+
25
+ # Add special tokens
26
+ for token in [self.pad_token, self.start_token, self.end_token, self.unk_token]:
27
+ self.add_word(token)
28
+
29
+ def add_word(self, word):
30
+ """Add a word to the vocabulary"""
31
+ if word not in self.word2idx:
32
+ self.word2idx[word] = self.idx
33
+ self.idx2word[self.idx] = word
34
+ self.idx += 1
35
+
36
  def __len__(self):
37
+ return len(self.word2idx)
38
+
39
+ def __call__(self, word):
40
+ """Convert word to index"""
41
+ if word not in self.word2idx:
42
+ return self.word2idx[self.unk_token]
43
+ return self.word2idx[word]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def decode(self, indices):
46
+ """Convert indices back to words"""
47
+ return [self.idx2word[idx] for idx in indices if idx in self.idx2word]
48
 
 
 
 
 
49
  import __main__
50
  setattr(__main__, "Vocabulary", Vocabulary)
51
 
 
 
 
52
 
53
  class EncoderCNN(nn.Module):
54
  def __init__(self, embed_size):
 
132
  def forward(self, x):
133
  return self.backbone(x)
134
 
 
 
 
135
 
136
  def load_caption_model(device, model_dir=None):
137
  if model_dir is None:
 
144
  config = pickle.load(f)
145
 
146
  # Load vocabulary
147
+ try:
148
+ with open(model_dir / 'vocab.pkl', 'rb') as f:
149
+ vocab = pickle.load(f)
150
+ logger.info(f"Vocabulary loaded successfully. Size: {len(vocab)}")
151
+ except Exception as e:
152
+ logger.error(f"Failed to load vocabulary: {e}")
153
+ raise e
154
+
155
  # Create model
156
  model = ImageCaptioningModel(
157
  embed_size=config['embed_size'],