File size: 3,307 Bytes
b1a427a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os 
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import torchvision.transforms as transforms

# Load the spaCy "en_core_web_sm" pipeline once, at import time; its tokenizer
# is what Vocabulary.tokenizer_eng uses to split text into tokens.
spacy_eng=spacy.load("en_core_web_sm")
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Ids from 4 upward are handed out by
    :meth:`build_vocabulary`.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Occurrence count a word has to reach while
                :meth:`build_vocabulary` runs before it receives its own id.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        # The reverse table is derived from ``itos`` so both always agree.
        self.stoi = {token: index for index, token in self.itos.items()}

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    def tokenizer_eng(self, text):
        """Split ``text`` with the module-level spaCy tokenizer, lower-cased."""
        lowered = []
        for token in spacy_eng.tokenizer(text):
            lowered.append(token.text.lower())
        return lowered

    def build_vocabulary(self, sentence_list):
        """Register every word that reaches ``freq_threshold`` occurrences.

        Words receive consecutive ids starting at 4, in the order in which
        their running count hits the threshold.

        Args:
            sentence_list: Iterable of sentences (strings) to count words in.
        """
        counts = {}
        next_id = 4  # ids 0-3 belong to the special tokens

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                seen = counts.get(word, 0) + 1
                counts[word] = seen

                # Only the exact moment the count reaches the threshold
                # registers the word, so each word is added at most once.
                if seen != self.freq_threshold:
                    continue

                self.stoi[word] = next_id
                self.itos[next_id] = word
                next_id += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids.

        Tokens that are not in the vocabulary map to the ``<UNK>`` id.
        """
        token_to_id = self.stoi
        unknown_id = token_to_id["<UNK>"]
        return [
            token_to_id.get(token, unknown_id)
            for token in self.tokenizer_eng(text)
        ]
        
class FlickrDataset(Dataset):
    """Image/caption pairs backed by a captions CSV and an image folder.

    The CSV is read with pandas and must provide an ``image`` column (file
    names joined onto ``root_dir``) and a ``caption`` column. A
    :class:`Vocabulary` is built from all captions at construction time.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Read the captions file and build the vocabulary.

        Args:
            root_dir: Directory the image file names are relative to.
            captions_file: Path of the CSV with ``image``/``caption`` columns.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Passed through to :class:`Vocabulary`.
        """
        self.root_dir = root_dir
        self.transform = transform

        frame = pd.read_csv(captions_file)
        self.df = frame
        self.imgs = frame["image"]
        self.captions = frame["caption"]

        vocabulary = Vocabulary(freq_threshold)
        vocabulary.build_vocabulary(self.captions.tolist())
        self.vocab = vocabulary

    def __len__(self):
        """Return the number of rows in the captions table."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for the row at ``index``.

        The image is opened from disk, converted to RGB and passed through
        ``transform`` when one was given. The caption is numericalized and
        wrapped in ``<SOS>`` / ``<EOS>`` ids before being turned into a tensor.
        """
        caption_text = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        image = Image.open(image_path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        token_to_id = self.vocab.stoi
        token_ids = [
            token_to_id["<SOS>"],
            *self.vocab.numericalize(caption_text),
            token_to_id["<EOS>"],
        ]

        return image, torch.tensor(token_ids)
    
class MyCollate:
    """Collate callable that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        """Remember the id used to fill captions shorter than the longest one."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Merge a list of ``(image, caption)`` samples into two tensors.

        Returns:
            A tuple ``(imgs, targets)``: the images joined along a new leading
            batch dimension, and the captions padded with ``pad_idx``. Because
            ``batch_first=False``, the sequence dimension of ``targets`` comes
            first and the batch dimension second.
        """
        images = []
        captions = []
        for sample in batch:
            images.append(sample[0])
            captions.append(sample[1])

        # Stacking on dim 0 adds the batch dimension in front of each image.
        stacked_images = torch.stack(images, dim=0)
        padded_captions = pad_sequence(
            captions,
            batch_first=False,
            padding_value=self.pad_idx,
        )

        return stacked_images, padded_captions
    
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
):
    """Build a ``FlickrDataset`` and a ``DataLoader`` that iterates over it.

    Args:
        root_folder: Directory containing the images.
        annotation_file: Captions CSV handed to ``FlickrDataset``.
        transform: Callable applied to every loaded image.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        A tuple ``(loader, dataset)``. The dataset is returned as well so
        callers can reach its vocabulary.
    """
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)

    # Captions are padded with the vocabulary's <PAD> id when batches are built.
    collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])

    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": shuffle,
        "pin_memory": pin_memory,
        "collate_fn": collate,
    }
    return DataLoader(dataset=dataset, **loader_options), dataset