|
|
import os |
|
|
import glob |
|
|
import random |
|
|
import pickle |
|
|
|
|
|
from data import common |
|
|
|
|
|
import numpy as np |
|
|
import imageio |
|
|
import torch |
|
|
import torch.utils.data as data |
|
|
|
|
|
class SRData(data.Dataset):
    """Base super-resolution dataset yielding (LR, edge, HR, filename) items.

    Scans ``<args.dir_data>/<name>`` for high-resolution (HR) images, their
    edge maps (EDGE), and bicubic low-resolution (LR) images at each scale
    in ``args.scale``.  Behavior depends on ``args.ext``:

    * contains ``'img'`` (or ``benchmark`` is True): images are decoded from
      the original files on every access;
    * contains ``'sep'``: each decoded image is cached once as a pickled
      ``.pt`` binary under ``<apath>/bin`` and loaded from there afterwards.

    Args:
        args: parsed options; this class reads ``dir_data``, ``ext``,
            ``scale``, ``model``, ``batch_size``, ``test_every``,
            ``data_train``, ``patch_size``, ``n_colors``, ``rgb_range``
            and ``no_augment``.
        name: dataset directory name under ``args.dir_data``.
        train: True for the training split, False for the test split.
        benchmark: True when used as a benchmark set (always reads images).
    """

    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        # VDSR consumes pre-upsampled input; its LR folder carries an 'L' suffix.
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_edge, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            # Keep raw image paths; decode on every access.
            self.images_hr, self.images_edge, self.images_lr = \
                list_hr, list_edge, list_lr
        elif args.ext.find('sep') >= 0:
            # Mirror the HR/EDGE/LR directory trees under bin/.
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            os.makedirs(
                self.dir_edge.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(s)
                    ),
                    exist_ok=True
                )

            self.images_hr, self.images_edge, self.images_lr = \
                [], [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)

            for e in list_edge:
                g = e.replace(self.apath, path_bin)
                g = g.replace(self.ext[0], '.pt')
                self.images_edge.append(g)
                self._check_and_load(args.ext, e, g, verbose=True)

            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

        if train:
            # Repeat the image list so one "epoch" yields roughly
            # batch_size * test_every patches.
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            if n_images == 0:
                self.repeat = 0
            else:
                self.repeat = max(n_patches // n_images, 1)

    def _scan(self):
        """Return sorted HR and EDGE path lists plus per-scale LR path lists."""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_edge = sorted(
            glob.glob(os.path.join(self.dir_edge, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                # Portable join; the original embedded '/' in the format
                # string ('X{}/{}{}'), which bypasses os.path separators.
                names_lr[si].append(os.path.join(
                    self.dir_lr,
                    'X{}'.format(s),
                    '{}{}'.format(filename, self.ext[1])
                ))

        return names_hr, names_edge, names_lr

    def _set_filesystem(self, dir_data):
        """Define the dataset root, HR/EDGE/LR directories, and extensions."""
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_edge = os.path.join(self.apath, 'EDGE')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        # (HR extension, LR extension)
        self.ext = ('.jpg', '.jpg')

    def _check_and_load(self, ext, img, f, verbose=True):
        """Decode image ``img`` and cache it as pickled binary file ``f``.

        The cache is (re)built when ``f`` is missing or ``ext`` contains
        ``'reset'``.  NOTE(review): the cache is later read with ``pickle``,
        which is unsafe on untrusted files — only use locally generated bins.
        """
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        """Return (lr, edge, hr, filename) tensors for sample ``idx``."""
        lr, edge, hr, filename = self._load_file(idx)
        lr, edge, hr = self.get_patch(lr, edge, hr)
        lr, edge, hr = common.set_channel(
            lr, edge, hr, n_channels=self.args.n_colors
        )
        lr_tensor, edge_tensor, hr_tensor = common.np2Tensor(
            lr, edge, hr, rgb_range=self.args.rgb_range
        )

        return lr_tensor, edge_tensor, hr_tensor, filename

    def __len__(self):
        """Dataset length; training data is repeated ``self.repeat`` times."""
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        """Map a (possibly repeated) sample index back to an image index."""
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        """Load LR/EDGE/HR arrays for ``idx`` from disk or the binary cache."""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_edge = self.images_edge[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        # Match __init__: any ext containing 'img' stores raw image paths,
        # so test with find() instead of strict equality.  (Previously an
        # ext like 'img-reset' fell into the pickle branch and crashed.)
        if self.args.ext.find('img') >= 0 or self.benchmark:
            hr = imageio.imread(f_hr)
            edge = imageio.imread(f_edge)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_edge, 'rb') as _f:
                edge = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        else:
            # Previously fell through and raised UnboundLocalError on return.
            raise ValueError(
                'Unsupported file extension mode: {}'.format(self.args.ext)
            )

        return lr, edge, hr, filename

    def get_patch(self, lr, edge, hr):
        """Crop aligned patches (train) or trim HR/EDGE to LR * scale (eval)."""
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, edge, hr = common.get_patch(
                lr, edge, hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi=(len(self.scale) > 1),
                input_large=self.input_large
            )
            if not self.args.no_augment:
                lr, edge, hr = common.augment(lr, edge, hr)
        else:
            # Trim HR/EDGE so their size is exactly scale times the LR size.
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]
            edge = edge[0:ih * scale, 0:iw * scale]

        return lr, edge, hr

    def set_scale(self, idx_scale):
        """Select the active scale; large-input models pick one at random."""
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
|
|
|
|
|
|