import io
import os

from PIL import Image
from torch.utils.data import Dataset


def add_parser_arguments(parser):
    """Register task-specific command-line arguments on *parser*.

    This task defines none, so the hook is a no-op. It is presumably kept so
    a shared driver can call it uniformly for every task (verify against
    callers).
    """
    pass


def task_dataset():
    """Return the dataset class for this task (the class itself, not an instance)."""
    return TaskDataset


class TaskDataset(Dataset):
    """Template dataset whose samples are left to be filled in.

    ``__getitem__`` is not implemented and ``sample_list`` starts empty, so
    ``len()`` is 0 until something populates ``sample_list``.

    Args:
        args: Configuration object exposing ``trainset`` and ``valset``
            mappings (presumably an ``argparse.Namespace`` — confirm with the
            caller). Despite the ``None`` default it is required: the
            constructor dereferences it unconditionally.
        is_train: Selects ``args.trainset`` when true, ``args.valset``
            otherwise.
    """

    def __init__(self, args=None, is_train=True):
        super().__init__()
        self.args = args
        self.is_train = is_train
        self.root_dir = None
        # Populated elsewhere (presumably by subclasses / later setup).
        self.sample_list = []
        self.idxs = []
        self.im_loader = ImageLoader()

        # Only the first entry of the split mapping is used (dict insertion
        # order). An empty mapping raises IndexError here.
        split = self.args.trainset if is_train else self.args.valset
        self.root_dir = list(split.values())[0]

    def __len__(self):
        """Return the number of entries in ``sample_list``."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Not implemented in this template; always raises NotImplementedError."""
        raise NotImplementedError


class ImageLoader:
    """Small helper that loads images from disk with Pillow."""

    def __init__(self):
        pass

    def load(self, name):
        """Load the image stored at path *name* and return a PIL Image.

        The file is read fully into memory and closed before returning.
        ``Image.open(path)`` alone keeps the OS file handle open until the
        pixels are decoded or the object is garbage-collected. That leaks
        descriptors when many images are opened, e.g. in data-loader workers.
        Decoding is still lazy, but it now happens from an in-memory buffer.

        Raises:
            OSError: If the file cannot be read (e.g. FileNotFoundError), or
                ``PIL.UnidentifiedImageError`` (an OSError subclass) if the
                bytes are not a recognized image.
        """
        with open(name, "rb") as handle:
            data = handle.read()
        return Image.open(io.BytesIO(data))