# File size: 1,413 Bytes
# e1aaaac
import json
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
class COCOFlickrDataset(Dataset):
    """Image/caption dataset driven by a JSON annotations file.

    The annotations file must be a JSON object with an ``"annotations"``
    key holding a list of records; each record is read for its
    ``"image_id"`` and ``"caption"`` entries.

    Args:
        image_dir_path: Directory containing the ``.jpg`` images.
        annotations_path: Path to the JSON annotations file.
        transform: Optional callable applied to the opened image. When
            ``None`` the image is returned untransformed.
        is_flickr: Selects the file naming scheme, see ``get_img_path``.
        prefix: String placed in front of the zero-padded image id for
            the non-Flickr naming scheme. ``None`` means no prefix.
    """

    def __init__(
        self,
        image_dir_path,
        annotations_path,
        transform=None,
        is_flickr=False,
        prefix=None,
    ):
        self.image_dir_path = image_dir_path
        # Binary mode lets ``json.load`` detect the UTF-8/16/32 encoding
        # itself instead of relying on the platform default; the context
        # manager guarantees the handle is closed.
        with open(annotations_path, "rb") as annotations_file:
            self.annotations = json.load(annotations_file)["annotations"]
        self.is_flickr = is_flickr
        self.transform = transform
        self.prefix = prefix

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotations)

    def get_img_path(self, idx):
        """Return the image file path for the annotation at ``idx``.

        With ``is_flickr`` set the file is ``<dir>/<image_id>.jpg``;
        otherwise it is ``<dir>/<prefix><image_id as 12 digits>.jpg``,
        which requires ``image_id`` to be an integer.
        """
        image_id = self.annotations[idx]["image_id"]
        if self.is_flickr:
            return f"{self.image_dir_path}/{image_id}.jpg"
        # A missing prefix must not leak the literal text "None" into
        # the file name.
        prefix = "" if self.prefix is None else self.prefix
        return f"{self.image_dir_path}/{prefix}{image_id:012d}.jpg"

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the annotation at ``idx``.

        The image is passed through ``transform`` when one was given.
        """
        image = Image.open(self.get_img_path(idx))
        caption = self.annotations[idx]["caption"]
        # ``transform`` defaults to None, so only call it when provided.
        if self.transform is not None:
            image = self.transform(image)
        return image, caption
class ImageNetDataset(ImageFolder):
    """ImageNet1k dataset implemented as a thin wrapper over ``ImageFolder``."""

    def __init__(self, root, **kwargs):
        """Forward ``root`` and any keyword options to ``ImageFolder``."""
        super().__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair the parent class yields for ``idx``."""
        image, class_index = super().__getitem__(idx)
        # The target is returned as given by the parent class; no mapping
        # to a label is applied here.
        return image, class_index
|