| | import os
|
| | import random
|
| | import imageio
|
| | import numpy as np
|
| | import torch.utils.data as data
|
| |
|
| | from data import common
|
| |
|
| | from utils import interact
|
| |
|
class Dataset(data.Dataset):
    """Basic blur/sharp image dataloader class.

    Image files are discovered by walking ``subset_root`` and classifying
    each leaf directory by substring keys found in its path
    (``blur_key`` / ``sharp_key``, minus the ``non_*_keys`` exclusions).
    Subclasses customise behaviour through the ``set_modes`` and
    ``set_keys`` hooks, which ``__init__`` calls before scanning.

    Args:
        args: configuration namespace. Read in ``__init__``: ``data_root``,
            ``data_train`` / ``data_val`` / ``data_test`` (dataset directory
            name for the matching mode) and ``demo_input_dir`` (demo mode).
            Read in ``__getitem__``: ``patch_size``, ``augment``,
            ``rgb_range``, ``n_scales``, ``gaussian_pyramid``.
        mode: one of ``'train'``, ``'val'``, ``'test'``, ``'demo'``.

    Raises:
        NotImplementedError: if ``mode`` is not supported.
        ValueError: if sharp images were found but their count differs from
            the number of blurry images.
    """

    def __init__(self, args, mode='train'):
        super().__init__()
        self.args = args
        self.mode = mode

        self.modes = ()
        self.set_modes()
        self._check_mode()

        self.set_keys()

        if self.mode == 'demo':
            # Demo inputs come from an arbitrary directory, not the dataset tree.
            self.subset_root = args.demo_input_dir
        else:
            if self.mode == 'train':
                dataset = args.data_train
            elif self.mode == 'val':
                dataset = args.data_val
            elif self.mode == 'test':
                dataset = args.data_test
            else:
                # Reached when a subclass adds a mode in set_modes() without
                # also overriding __init__ to handle it.
                raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))
            self.subset_root = os.path.join(args.data_root, dataset, self.mode)

        self.blur_list = []
        self.sharp_list = []

        self._scan()

    def set_modes(self):
        """Set ``self.modes``, the tuple of modes this dataset accepts.

        Hook for subclasses; called from ``__init__`` before ``_check_mode``.
        """
        self.modes = ('train', 'val', 'test', 'demo')

    def _check_mode(self):
        """Raise NotImplementedError if ``self.mode`` is not in ``self.modes``.

        Called from ``__init__`` right after ``set_modes``.
        """
        if self.mode not in self.modes:
            raise NotImplementedError('mode error: not for {}'.format(self.mode))

    def set_keys(self):
        """Set the path keys used by ``_scan`` to classify leaf directories.

        A directory belongs to the blur set when its path contains
        ``blur_key`` and none of ``non_blur_keys`` (likewise for sharp).
        Hook for subclasses; called from ``__init__`` before ``_scan``.
        """
        self.blur_key = 'blur'
        self.sharp_key = 'sharp'

        self.non_blur_keys = []
        self.non_sharp_keys = []

    def _scan(self, root=None):
        """Populate ``blur_list`` and ``sharp_list`` from files under ``root``.

        Only files sitting directly in leaf directories (directories without
        subdirectories) are collected. Both lists are sorted, so blur/sharp
        pairs are matched by sorted position.

        Side effect: all keys are normalised to end with ``os.sep``, so e.g.
        ``'blur'`` matches a ``blur/`` directory but not ``blur_gamma/``.
        The true key is removed from its own exclusion list.

        Args:
            root: directory to walk; defaults to ``self.subset_root``.

        Raises:
            ValueError: if sharp images were found but their count differs
                from the number of blurry images.
        """
        if root is None:
            root = self.subset_root

        def _with_sep(key):
            # os.path.join(key, '') appends a trailing separator exactly once
            # (idempotent), turning a key into a whole-directory-name match.
            return os.path.join(key, '')

        self.blur_key = _with_sep(self.blur_key)
        self.sharp_key = _with_sep(self.sharp_key)
        # Filter *after* normalising so that 'blur' and 'blur/' are both
        # recognised as the true key, and drop every occurrence (not just the
        # first) — otherwise the true key would exclude all of its own matches.
        self.non_blur_keys = [k for k in map(_with_sep, self.non_blur_keys) if k != self.blur_key]
        self.non_sharp_keys = [k for k in map(_with_sep, self.non_sharp_keys) if k != self.sharp_key]

        def _key_check(path, true_key, false_keys):
            """True if ``path`` contains ``true_key`` and no ``false_keys``."""
            path = os.path.join(path, '')
            return true_key in path and not any(k in path for k in false_keys)

        def _get_list_by_key(root, true_key, false_keys):
            """Sorted paths of files in matching leaf directories under ``root``."""
            data_list = []
            for sub, dirs, files in os.walk(root):
                if not dirs and _key_check(sub, true_key, false_keys):
                    data_list.extend(os.path.join(sub, f) for f in files)
            data_list.sort()
            return data_list

        self.blur_list = _get_list_by_key(root, self.blur_key, self.non_blur_keys)
        self.sharp_list = _get_list_by_key(root, self.sharp_key, self.non_sharp_keys)

        # Sharp images are optional (e.g. demo inputs), but when present they
        # must pair one-to-one with blur images. This is an explicit raise
        # rather than an assert so it still runs under `python -O`.
        if self.sharp_list and len(self.blur_list) != len(self.sharp_list):
            raise ValueError(
                'found {} blur images but {} sharp images under {}'.format(
                    len(self.blur_list), len(self.sharp_list), root))

    def __getitem__(self, idx):
        """Load, preprocess and return the ``idx``-th sample.

        Train mode: crop to ``args.patch_size``; if ``args.augment`` is set,
        also augment and add noise to the blurry image. Demo mode: pad the
        blurry image to a multiple of ``2**(args.n_scales - 1)``. Other
        modes: no spatial preprocessing.

        Returns:
            tuple ``(blur, sharp, pad_width, idx, relpath)``:
            ``blur`` / ``sharp`` are the outputs of ``common.np2tensor``
            (presumably per-scale pyramids when ``args.gaussian_pyramid`` is
            set — confirm against ``common.generate_pyramid``);
            ``sharp`` is ``False`` when no sharp images exist (presumably so
            the default collate can batch it; verify);
            ``pad_width`` is the padding from ``common.pad`` in demo mode,
            else 0; ``relpath`` is the blur file's path relative to
            ``subset_root``.
        """
        blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
        if len(self.sharp_list) > 0:
            sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
            imgs = [blur, sharp]
        else:
            imgs = [blur]

        pad_width = 0  # only demo mode pads
        if self.mode == 'train':
            imgs = common.crop(*imgs, ps=self.args.patch_size)
            if self.args.augment:
                imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
                imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
        elif self.mode == 'demo':
            # Pad so the spatial size is divisible across all n_scales levels.
            imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))
        # val / test: deliberately no cropping or padding.

        if self.args.gaussian_pyramid:
            imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)

        imgs = common.np2tensor(*imgs, rgb_range=self.args.rgb_range)
        relpath = os.path.relpath(self.blur_list[idx], self.subset_root)

        blur = imgs[0]
        sharp = imgs[1] if len(imgs) > 1 else False

        return blur, sharp, pad_width, idx, relpath

    def __len__(self):
        """Number of samples, i.e. the number of blurry images found."""
        return len(self.blur_list)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|