Spaces:
Sleeping
Sleeping
FIX: Complete the Vocabulary class definition
Browse files- vocabulary.py +33 -1
vocabulary.py
CHANGED
|
@@ -1,7 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
class Vocabulary:
    """Bidirectional token/index lookup seeded with four special tokens."""

    def __init__(self, freq_threshold):
        """Set up the special-token tables and remember the threshold.

        Args:
            freq_threshold: Stored as ``self.freq_threshold``; this class
                version does not read it anywhere else.
        """
        special_tokens = ("<PAD>", "<START>", "<END>", "<UNK>")
        # Index -> token, numbered 0..3 in the order listed above.
        self.itos = dict(enumerate(special_tokens))
        # Token -> index, the exact inverse of ``itos``.
        self.stoi = {token: index for index, token in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of entries in the index-to-token table."""
        return len(self.itos)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nltk
|
| 2 |
+
from collections import Counter
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
class Vocabulary:
    """Token <-> index mapping built from caption text.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<START>``,
    ``<END>`` and ``<UNK>``. Words added by :meth:`build_vocabulary` receive
    the indices that follow.

    Attributes are plain dicts so callers can read them directly:
    ``itos`` maps index -> token and ``stoi`` maps token -> index.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that holds only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs
                before :meth:`build_vocabulary` assigns it an index.
        """
        self.itos = {0: "<PAD>", 1: "<START>", 2: "<END>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def get_all_captions(dataset, num_captions=5):
        """Collect every caption of every item into one flat list.

        Args:
            dataset: Iterable of mappings that provide the keys
                ``caption_0`` ... ``caption_{num_captions - 1}``.
            num_captions: How many ``caption_<i>`` keys to read per item.
                Defaults to 5, which matches the previously hard-coded
                ``caption_0`` .. ``caption_4`` lookups.

        Returns:
            A list of captions in dataset order, with each item's captions
            in ascending key order.

        Raises:
            KeyError: If an item lacks one of the expected caption keys.
        """
        caption_keys = [f"caption_{i}" for i in range(num_captions)]
        all_captions = []
        print("Gathering all captions from the training set...")
        for item in tqdm(dataset):
            all_captions.extend(item[key] for key in caption_keys)
        return all_captions

    def build_vocabulary(self, sentence_list):
        """Count word frequencies and index the sufficiently frequent words.

        Each sentence is lower-cased and split with ``nltk.word_tokenize``.
        Every token seen at least ``self.freq_threshold`` times is appended
        to ``stoi``/``itos``.

        Args:
            sentence_list: Iterable of caption strings.
        """
        frequencies = Counter()
        print("Tokenizing and counting word frequencies...")
        for sentence in tqdm(sentence_list):
            frequencies.update(nltk.word_tokenize(sentence.lower()))

        print("Building word-to-index mapping...")
        self._index_frequent_words(tqdm(frequencies.items()))

    def _index_frequent_words(self, word_counts):
        """Assign the next free indices to words that meet the threshold.

        Args:
            word_counts: Iterable of ``(word, count)`` pairs.
        """
        # Continue after the highest index in use instead of restarting at
        # a hard-coded 4, so a repeated build never overwrites existing
        # entries or lets ``stoi`` and ``itos`` drift apart.
        idx = max(self.itos, default=-1) + 1
        for word, count in word_counts:
            # Skip rare words and words that already own an index.
            if count < self.freq_threshold or word in self.stoi:
                continue
            self.stoi[word] = idx
            self.itos[idx] = word
            idx += 1
|