|
|
import os |
|
|
from data import srdata |
|
|
|
|
|
class DIV2K(srdata.SRData):
    """DIV2K dataset limited to a configurable window of image indices.

    The window is taken from ``args.data_range``, a string of one or two
    ``"begin-end"`` ranges separated by ``/`` (e.g. ``"1-800/801-810"``).
    Indices are 1-based and inclusive on both ends.
    """

    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        """Pick the index window for this split, then defer to the base class.

        Args:
            args: Options object. ``args.data_range`` is always read;
                ``args.test_only`` is read only when ``train`` is falsy.
            name: Dataset name forwarded to the base class.
            train: When truthy, the first range in ``args.data_range`` is used.
            benchmark: Forwarded unchanged to the base class.

        Raises:
            IndexError: If a second range is needed but ``args.data_range``
                contains only one.
            ValueError: If the chosen range is not exactly two integers.
        """
        ranges = [chunk.split('-') for chunk in args.data_range.split('/')]

        # The first range serves training, and also evaluation when running
        # test-only with a single range given. Otherwise evaluation uses the
        # second range.
        use_first = train or (args.test_only and len(ranges) == 1)
        bounds = ranges[0] if use_first else ranges[1]

        # begin/end must exist before the base initialiser runs.
        # presumably it calls _scan(), which reads them; verify against SRData
        self.begin, self.end = [int(token) for token in bounds]
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        """Return the base class's file-name lists cut down to the window.

        Returns:
            A ``(names_hr, names_edge, names_lr)`` tuple. ``names_hr`` and
            ``names_edge`` are sliced directly; ``names_lr`` is a list whose
            every element is sliced with the same window.
        """
        all_hr, all_edge, all_lr = super(DIV2K, self)._scan()

        # Convert the 1-based inclusive [begin, end] into a 0-based slice.
        window = slice(self.begin - 1, self.end)

        names_hr = all_hr[window]
        names_edge = all_edge[window]
        names_lr = [per_group[window] for per_group in all_lr]

        return names_hr, names_edge, names_lr

    def _set_filesystem(self, dir_data):
        """Point the HR, edge and LR directories at the DIV2K sub-folders.

        Args:
            dir_data: Forwarded to the base class's ``_set_filesystem``.
        """
        super(DIV2K, self)._set_filesystem(dir_data)

        # self.apath is not assigned in this class.
        # presumably set by the base call above; verify against SRData
        root = self.apath
        self.dir_hr = os.path.join(root, 'DIV2K_train_HR')
        self.dir_edge = os.path.join(root, 'DIV2K_train_EDGE_disturbed')
        self.dir_lr = os.path.join(root, 'DIV2K_train_LR_bicubic')

        # When input_large is set, the LR directory name gets an 'L' suffix.
        if self.input_large:
            self.dir_lr = self.dir_lr + 'L'
|
|
|
|
|
|