xmutly's picture
Upload 294 files
e1aaaac verified
"""
Adapted from https://github.com/pytorch/vision/blob/main/torchvision/datasets/flickr.py
Thanks to the authors of torchvision
"""
from collections import defaultdict
import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple
from PIL import Image
from torchvision.datasets import VisionDataset
class Flickr(VisionDataset):
    """Image/caption dataset backed by a CSV-like annotation file.

    The annotation file is read once at construction time. Its first line is
    skipped (treated as a header); every following non-empty line must have
    the form ``<image name>.jpg,<caption>``. Captions that share an image
    name are grouped, so one dataset item is one image together with the
    list of all its captions.

    Args:
        root: Forwarded to ``VisionDataset``; image names from the annotation
            file are joined onto ``self.root`` when an item is loaded.
        ann_file: Path to the annotation file. A leading ``~`` is expanded.
        transform: Optional callable applied to the loaded RGB image.
        target_transform: Optional callable applied to the list of captions.

    Raises:
        ValueError: If a non-empty annotation line contains no ``.jpg,``
            separator.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        captions_by_image: Dict[str, List[str]] = defaultdict(list)
        # Open the expanded path (not the raw argument) so "~/..." works.
        with open(self.ann_file) as fd:
            fd.readline()  # skip the header line
            # The header is line 1, so data lines are numbered from 2.
            for line_no, raw_line in enumerate(fd, start=2):
                line = raw_line.strip()
                if not line:
                    continue
                # Captions may contain commas (and even ".jpg,"), so cut only
                # at the first ".jpg,", which terminates the image name.
                stem, separator, caption = line.partition(".jpg,")
                if not separator:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected "
                        f"'<image>.jpg,<caption>', got {line!r}"
                    )
                captions_by_image[stem + ".jpg"].append(caption)

        # List of (image name, captions) pairs, in order of first appearance.
        self.data = list(captions_by_image.items())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_name, captions = self.data[index]

        # Image: ``convert`` returns an independent, fully loaded copy, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.root, img_name)) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = captions
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of distinct images listed in the annotation file."""
        return len(self.data)