|
|
""" |
|
|
Adapted from https://github.com/pytorch/vision/blob/main/torchvision/datasets/flickr.py |
|
|
Thanks to the authors of torchvision |
|
|
""" |
|
|
from collections import defaultdict |
|
|
import glob |
|
|
import os |
|
|
from collections import defaultdict |
|
|
from html.parser import HTMLParser |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
|
|
|
from PIL import Image |
|
|
from torchvision.datasets import VisionDataset |
|
|
|
|
|
class Flickr(VisionDataset):
    """Image/caption dataset backed by a Flickr-style annotation file.

    The annotation file is read once, at construction time. Its first line is
    skipped (treated as a header); every following non-empty line is expected
    to have the form ``<image name>.jpg,<caption>``. Captions that share an
    image name are grouped, so each dataset item is one image together with
    the list of all captions found for it, in file order.

    Args:
        root (str): Directory that image file names are joined onto.
        ann_file (str): Path to the annotation file; a leading ``~`` is expanded.
        transform (callable, optional): Applied to the loaded RGB PIL image.
        target_transform (callable, optional): Applied to the list of captions.

    Raises:
        ValueError: If a non-empty annotation line does not contain ``.jpg,``.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        captions_by_image = defaultdict(list)
        # Open the *expanded* path so that "~/..." annotation files work.
        with open(self.ann_file) as fd:
            fd.readline()  # skip the header line
            for line_no, line in enumerate(fd, start=2):
                line = line.strip()
                if not line:
                    continue
                # Split only at the first ".jpg," so that captions which
                # themselves contain ".jpg," are kept intact.
                stem, sep, caption = line.partition(".jpg,")
                if not sep:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected '<image>.jpg,<caption>', "
                        f"got {line!r}"
                    )
                captions_by_image[stem + ".jpg"].append(caption)

        # List of (image file name, [captions]) pairs, in first-seen order.
        self.data: List[Tuple[str, List[str]]] = list(captions_by_image.items())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_name, captions = self.data[index]

        # Use the context manager so the underlying file is closed as soon as
        # the pixel data has been converted.
        with Image.open(os.path.join(self.root, img_name)) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        target = captions
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of distinct images listed in the annotation file."""
        return len(self.data)