| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from utils import data_utils | |
class InferenceDataset(Dataset):
    """Dataset that loads images from disk one at a time for inference.

    The list of image paths comes either from ``root`` (via
    ``data_utils.make_dataset``, result sorted) or from an explicit
    ``paths_list`` (via ``data_utils.make_dataset_from_paths_list``, order
    left as returned by that helper).

    Args:
        root: Location handed to ``data_utils.make_dataset`` when
            ``paths_list`` is None (presumably an image directory; confirm
            against ``utils.data_utils``).
        paths_list: Optional explicit source of paths; when given, ``root``
            is ignored.
        opts: Optional options object. Only ``opts.label_nc`` is read:
            ``0`` loads images as RGB, any other value loads them as
            single-channel ``'L'``. When ``opts`` is None or has no
            ``label_nc`` attribute, RGB is used.
        transform: Optional callable applied to the loaded PIL image.
        return_path: When True, ``__getitem__`` returns ``(image, path)``
            instead of just the image.
    """

    def __init__(self, root=None, paths_list=None, opts=None, transform=None, return_path=False):
        if paths_list is None:
            self.paths = sorted(data_utils.make_dataset(root))
        else:
            self.paths = data_utils.make_dataset_from_paths_list(paths_list)
        self.transform = transform
        self.opts = opts
        self.return_path = return_path

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.paths)

    def _color_mode(self):
        """Return the PIL mode images are converted to: 'RGB' or 'L'."""
        # `opts` defaults to None in __init__; fall back to label_nc == 0
        # (RGB) instead of failing with AttributeError on the first item.
        label_nc = getattr(self.opts, 'label_nc', 0)
        return 'RGB' if label_nc == 0 else 'L'

    def __getitem__(self, index):
        """Load, convert and (optionally) transform the image at ``index``.

        Returns:
            The image, or ``(image, path)`` when ``return_path`` is True.
        """
        from_path = self.paths[index]
        # convert() yields an independent image, so the underlying file can
        # be closed deterministically as soon as the pixels have been read.
        with Image.open(from_path) as source_im:
            from_im = source_im.convert(self._color_mode())
        if self.transform is not None:
            from_im = self.transform(from_im)
        if self.return_path:
            return from_im, from_path
        return from_im