| import os | |
| from data import srdata | |
class DIV2K(srdata.SRData):
    """DIV2K dataset limited to an index range taken from ``args.data_range``.

    ``args.data_range`` is a string of ``begin-end`` pairs separated by
    ``/`` (for example ``'1-800/801-810'``).  The first pair is used when
    ``train`` is true; otherwise the second pair is used, except that a
    lone pair is reused when ``args.test_only`` is set.  Bounds are
    1-based and inclusive with respect to the lists returned by the
    parent's ``_scan``.
    """

    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        """Parse the active index range, then defer to the parent constructor.

        Args:
            args: Option object; ``data_range`` is always read, and
                ``test_only`` is read only when ``train`` is false.
            name: Dataset name forwarded to the parent class.
            train: Selects the first range when true.
            benchmark: Forwarded unchanged to the parent class.

        Raises:
            IndexError: If the second range is required but
                ``args.data_range`` holds only one.
            ValueError: If the chosen range is not exactly two integers.
        """
        ranges = [chunk.split('-') for chunk in args.data_range.split('/')]

        # Short-circuiting keeps ``args.test_only`` untouched while training.
        use_first = train or (args.test_only and len(ranges) == 1)
        chosen = ranges[0] if use_first else ranges[1]

        # The bounds must exist before the parent constructor runs, since
        # this class's ``_scan`` reads them (presumably the parent invokes
        # ``_scan`` during construction -- TODO confirm in srdata.SRData).
        self.begin, self.end = [int(bound) for bound in chosen]

        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        """Return the parent's scan results cut down to ``begin..end``.

        Returns:
            A ``(names_hr, names_lr)`` pair: ``names_hr`` sliced directly,
            and ``names_lr`` rebuilt as a list with every member sliced the
            same way (presumably one member per scale -- verify against the
            parent class).
        """
        all_hr, all_lr = super(DIV2K, self)._scan()

        # 1-based inclusive bounds -> 0-based half-open slice.
        window = slice(self.begin - 1, self.end)
        picked_hr = all_hr[window]
        picked_lr = [group[window] for group in all_lr]

        return picked_hr, picked_lr

    def _set_filesystem(self, dir_data):
        """Set ``dir_hr`` and ``dir_lr`` to the DIV2K folders under ``self.apath``.

        ``self.apath`` is presumably established by the parent's
        ``_set_filesystem`` call made first -- confirm in srdata.SRData.
        When ``self.input_large`` is truthy, ``'L'`` is appended to the
        low-resolution directory path.

        Args:
            dir_data: Passed through unchanged to the parent implementation.
        """
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large:
            self.dir_lr = self.dir_lr + 'L'