Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from PIL import Image | |
| import pickle | |
| import imageio | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from datasets import register | |
class HRImgLoader(Dataset):
    """Dataset of images listed in a JSON split file.

    Every item is an ``(image, file_name)`` pair.  ``image`` is a float
    tensor; ``file_name`` is a path string identifying the item.

    Cache modes:
        ``'none'``: only image paths are stored; each access opens the
            image, converts it to RGB and applies ``transforms.ToTensor()``.
            ``file_name`` is the image path.
        ``'bin'``: the array returned by ``imageio.imread`` is pickled once
            into ``_bin_<basename(root_path)>`` next to ``root_path``; each
            access unpickles it, moves the last axis first and divides by
            255.  ``file_name`` is the path of the ``.pkl`` cache file.
        ``'in_memory'``: every image is decoded in ``__init__`` (RGB +
            ``transforms.ToTensor()``) and kept as a tensor.  ``file_name``
            is the image path.

    Args:
        root_path: directory that the names in the split file are
            relative to.
        split_file: path to a JSON file mapping split keys to lists of
            file names.
        split_key: key selecting the list of file names inside
            ``split_file``.
        first_k: if not ``None``, keep only the first ``first_k`` names.
        cache: one of ``'none'``, ``'bin'`` or ``'in_memory'``.

    Raises:
        ValueError: if ``cache`` is not a supported mode.
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, split_file, split_key, first_k=None, cache='none'):
        # Fail loudly on a typo instead of silently building an empty dataset.
        if cache not in self._CACHE_MODES:
            raise ValueError(
                'unknown cache mode {!r}; expected one of {}'.format(
                    cache, self._CACHE_MODES))
        self.cache = cache

        with open(split_file, 'r') as f:
            filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        # ``files`` holds what __getitem__ loads from (paths, or tensors in
        # 'in_memory' mode); ``file_names`` holds the string returned
        # alongside each image.
        self.files = []
        self.file_names = []

        bin_root = None
        if cache == 'bin':
            # The cache directory does not depend on the file, so resolve
            # and create it once rather than per image.
            bin_root = os.path.join(os.path.dirname(root_path),
                                    '_bin_' + os.path.basename(root_path))
            if not os.path.exists(bin_root):
                os.makedirs(bin_root, exist_ok=True)
                print('mkdir', bin_root)

        for filename in filenames:
            file = os.path.join(root_path, filename)

            if cache == 'none':
                self.files.append(file)
                self.file_names.append(file)

            elif cache == 'bin':
                # splitext strips only the final extension, so names such as
                # 'a.b.png' and 'a.c.png' no longer collide on 'a.pkl'.
                bin_file = os.path.join(
                    bin_root, os.path.splitext(filename)[0] + '.pkl')
                if not os.path.exists(bin_file):
                    # Names in the split may contain sub-directories.
                    os.makedirs(os.path.dirname(bin_file), exist_ok=True)
                    # NOTE(review): the array is stored exactly as imageio
                    # returns it (no RGB conversion), unlike the other two
                    # modes; __getitem__ needs it to have three axes.
                    with open(bin_file, 'wb') as f:
                        pickle.dump(imageio.imread(file), f)
                    print('dump', bin_file)
                self.files.append(bin_file)
                self.file_names.append(bin_file)

            else:  # 'in_memory'
                self.files.append(self._load_rgb_tensor(file))
                # Keep the path: the tensor itself is not a usable name.
                self.file_names.append(file)

    @staticmethod
    def _load_rgb_tensor(path):
        """Open the image at ``path``, convert to RGB and apply ToTensor."""
        # The context manager releases the file handle deterministically.
        with Image.open(path) as img:
            return transforms.ToTensor()(img.convert('RGB'))

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the ``(image, file_name)`` pair stored at ``idx``."""
        x = self.files[idx]
        file_name = self.file_names[idx]

        if self.cache == 'none':
            return self._load_rgb_tensor(x), file_name

        if self.cache == 'bin':
            # SECURITY: pickle.load executes arbitrary code from the file.
            # Only point this loader at cache directories it created itself.
            with open(x, 'rb') as f:
                x = pickle.load(f)
            # Last axis first, made contiguous so torch.from_numpy accepts it.
            x = np.ascontiguousarray(x.transpose(2, 0, 1))
            x = torch.from_numpy(x).float() / 255
            return x, file_name

        if self.cache == 'in_memory':
            return x, file_name

        raise ValueError('unknown cache mode {!r}'.format(self.cache))