Spaces:
Sleeping
Sleeping
| import os | |
| import torch.utils.data as data | |
| from os import listdir | |
| from os.path import join | |
| from data.util import * | |
| import torch.nn.functional as F | |
| class SICEDatasetFromFolderEval(data.Dataset): | |
| def __init__(self, data_dir, transform=None): | |
| super(SICEDatasetFromFolderEval, self).__init__() | |
| data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)] | |
| data_filenames.sort() | |
| self.data_filenames = data_filenames | |
| self.transform = transform | |
| def __getitem__(self, index): | |
| input = load_img(self.data_filenames[index]) | |
| _, file = os.path.split(self.data_filenames[index]) | |
| if self.transform: | |
| input = self.transform(input) | |
| factor = 8 | |
| h, w = input.shape[1], input.shape[2] | |
| H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor | |
| padh = H - h if h % factor != 0 else 0 | |
| padw = W - w if w % factor != 0 else 0 | |
| input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect').squeeze(0) | |
| return input, file, h, w | |
| def __len__(self): | |
| return len(self.data_filenames) | |
| class DatasetFromFolderEval(data.Dataset): | |
| def __init__(self, data_dir, transform=None): | |
| super(DatasetFromFolderEval, self).__init__() | |
| data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)] | |
| data_filenames.sort() | |
| self.data_filenames = data_filenames | |
| self.transform = transform | |
| def __getitem__(self, index): | |
| input = load_img(self.data_filenames[index]) | |
| _, file = os.path.split(self.data_filenames[index]) | |
| if self.transform: | |
| input = self.transform(input) | |
| return input, file | |
| def __len__(self): | |
| return len(self.data_filenames) |