Image_Captioning / loader.py
karan99300's picture
Upload 5 files
b1a427a
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import torchvision.transforms as transforms
# Shared spaCy English pipeline, loaded by name once at import time.
# Vocabulary.tokenizer_eng only uses its `.tokenizer` to split captions.
spacy_eng=spacy.load("en_core_web_sm")
class Vocabulary:
    """Two-way mapping between caption tokens and integer ids.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other word receives the next free id
    at the moment its count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that holds only the four special tokens.

        Args:
            freq_threshold: Number of occurrences a word needs before
                ``build_vocabulary`` registers it. Counts start at 1, so
                a threshold below 1 is never reached and no word is added.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of registered tokens, special tokens included."""
        return len(self.itos)

    def tokenizer_eng(self, text):
        """Split ``text`` with the spaCy tokenizer and lower-case each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Register every word that occurs ``freq_threshold`` times.

        Ids are handed out in the order in which words reach the
        threshold while scanning ``sentence_list``. Counts are local to
        one call.

        Args:
            sentence_list: Iterable of caption strings.
        """
        frequencies = {}
        # Continue after the ids already in use instead of restarting at 4,
        # so a repeated call cannot overwrite earlier itos entries.
        idx = len(self.itos)
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `==` (not `>=`) fires exactly once per word per call;
                # the membership test keeps ids from earlier calls stable.
                if (
                    frequencies[word] == self.freq_threshold
                    and word not in self.stoi
                ):
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids.

        Tokens missing from the vocabulary map to the ``<UNK>`` id.
        """
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer_eng(text)]
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image folder.

    The CSV is expected to provide an ``image`` column (file name joined
    onto ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is built
    from all captions when the dataset is constructed.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Load the captions table and build the vocabulary.

        Args:
            root_dir: Directory that contains the image files.
            captions_file: Path of the CSV file read with ``pd.read_csv``.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Passed to ``Vocabulary``; occurrences a word
                needs before it gets its own id.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform
        self.imgs = self.df['image']
        self.captions = self.df['caption']
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows (image/caption pairs) in the CSV."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        The image is converted to RGB and passed through ``transform``
        when one was given. ``caption_ids`` is a 1-D tensor holding the
        numericalized caption wrapped in ``<SOS>`` ... ``<EOS>``.
        """
        caption = self.captions[index]
        img_id = self.imgs[index]
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as convert() has produced a copy.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        return img, torch.tensor(numericalized_caption)
class MyCollate:
    """Collate function that batches images and pads captions.

    Each batch item is indexed as ``item[0]`` (image tensor) and
    ``item[1]`` (1-D caption tensor).
    """

    def __init__(self, pad_idx):
        """Remember the id used to pad captions to a common length."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Return ``(images, captions)`` for one batch.

        Images are joined along a new leading batch dimension. Captions
        are padded with ``pad_idx`` and laid out sequence-first
        (``batch_first=False``), i.e. one column per sample.
        """
        image_list = []
        caption_list = []
        for sample in batch:
            image_list.append(sample[0])
            caption_list.append(sample[1])
        # Stacking on dim 0 equals unsqueeze(0) on each image followed by cat.
        stacked_images = torch.stack(image_list, dim=0)
        padded_captions = pad_sequence(
            caption_list,
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return stacked_images, padded_captions
def get_loader(root_folder,annotation_file,transform,batch_size=32,shuffle=True,pin_memory=True,num_workers=8,freq_threshold=5):
    """Build a FlickrDataset and a DataLoader over it.

    Args:
        root_folder: Directory that contains the image files.
        annotation_file: Path of the captions CSV.
        transform: Callable applied to every loaded image (may be None).
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.
        pin_memory: Forwarded to ``DataLoader``.
        num_workers: Number of loader worker processes.
        freq_threshold: Occurrences a word needs before the dataset's
            vocabulary gives it its own id. Defaults to 5, which is the
            dataset's own default.

    Returns:
        ``(loader, dataset)``. The dataset is returned as well so that
        callers can reach ``dataset.vocab``.
    """
    dataset = FlickrDataset(
        root_folder,
        annotation_file,
        transform=transform,
        freq_threshold=freq_threshold,
    )
    # Captions in a batch are padded with the vocabulary's <PAD> id.
    pad_idx = dataset.vocab.stoi["<PAD>"]
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )
    return loader, dataset