Ishu8904 committed on
Commit
822afc8
·
1 Parent(s): 60f1f9f

FIX: Complete the Vocabulary class definition

Browse files
Files changed (1) hide show
  1. vocabulary.py +33 -1
vocabulary.py CHANGED
@@ -1,7 +1,39 @@
 
 
 
 
1
  class Vocabulary:
2
  def __init__(self, freq_threshold):
3
  self.itos = {0: "<PAD>", 1: "<START>", 2: "<END>", 3: "<UNK>"}
4
  self.stoi = {"<PAD>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3}
5
  self.freq_threshold = freq_threshold
 
6
  def __len__(self):
7
- return len(self.itos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ from collections import Counter
3
+ from tqdm import tqdm
4
+
5
  class Vocabulary:
6
  def __init__(self, freq_threshold):
7
  self.itos = {0: "<PAD>", 1: "<START>", 2: "<END>", 3: "<UNK>"}
8
  self.stoi = {"<PAD>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3}
9
  self.freq_threshold = freq_threshold
10
+
11
  def __len__(self):
12
+ return len(self.itos)
13
+
14
+ @staticmethod
15
+ def get_all_captions(dataset):
16
+ all_captions = []
17
+ print("Gathering all captions from the training set...")
18
+ for item in tqdm(dataset):
19
+ all_captions.append(item['caption_0'])
20
+ all_captions.append(item['caption_1'])
21
+ all_captions.append(item['caption_2'])
22
+ all_captions.append(item['caption_3'])
23
+ all_captions.append(item['caption_4'])
24
+ return all_captions
25
+
26
+ def build_vocabulary(self, sentence_list):
27
+ frequencies = Counter()
28
+ idx = 4
29
+ print("Tokenizing and counting word frequencies...")
30
+ for sentence in tqdm(sentence_list):
31
+ for word in nltk.word_tokenize(sentence.lower()):
32
+ frequencies[word] += 1
33
+
34
+ print("Building word-to-index mapping...")
35
+ for word, count in tqdm(frequencies.items()):
36
+ if count >= self.freq_threshold:
37
+ self.stoi[word] = idx
38
+ self.itos[idx] = word
39
+ idx += 1