| import numpy as np
|
| import random
|
| import torch
|
| from pathlib import Path
|
| from torch.utils import data as data
|
|
|
| from basicsr.data.transforms import augment, paired_random_crop
|
| from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| from basicsr.utils.flow_util import dequantize_flow
|
|
|
|
|
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file, e.g.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    white space. Blank lines are ignored.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow. A missing
                key or None disables flow loading.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or
                'official'. Clips of that partition are excluded from keys.
            io_backend (dict): IO backend type and other kwargs. The dict is
                copied, so the caller's config is never modified.

            num_frame (int): Window size for input frames. Must be odd.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale factor between gt and lq (passed to
                paired_random_crop), which will be added automatically.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])
        self.lq_root = Path(opt['dataroot_lq'])
        # dataroot_flow is documented as optional: tolerate a missing key.
        flow_root = opt.get('dataroot_flow')
        self.flow_root = Path(flow_root) if flow_root is not None else None
        assert opt['num_frame'] % 2 == 1, (
            f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        # One key per frame of every clip listed in the meta info file. The
        # per-clip frame counts are kept so __getitem__ knows clip boundaries.
        with open(opt['meta_info_file'], 'r') as fin:
            clips = self._parse_meta_info(fin)
        self.clip_frame_nums = dict(clips)
        self.keys = [
            f'{folder}/{i:08d}'
            for folder, frame_num in clips
            for i in range(frame_num)
        ]

        # Remove the validation clips.
        if opt['val_partition'] == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif opt['val_partition'] == 'official':
            val_partition = {f'{v:03d}' for v in range(240, 270)}
        else:
            raise ValueError(
                f'Wrong validation partition {opt["val_partition"]}. '
                f"Supported ones are ['official', 'REDS4'].")
        self.keys = [
            v for v in self.keys if v.split('/')[0] not in val_partition
        ]

        # File client (io backend); created lazily on the first __getitem__.
        self.file_client = None
        # Copy so that the db_paths/client_keys added below do not leak into
        # the caller's config dict.
        self.io_backend_opt = dict(opt['io_backend'])
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [
                    self.lq_root, self.gt_root, self.flow_root
                ]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # Temporal augmentation configs.
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    @staticmethod
    def _parse_meta_info(lines):
        """Parse meta info lines into ``[(clip_name, frame_num), ...]``.

        Only the first two whitespace-separated fields are used; blank lines
        are skipped. File order is preserved.
        """
        clips = []
        for line in lines:
            fields = line.split()
            if not fields:
                continue
            clips.append((fields[0], int(fields[1])))
        return clips

    @staticmethod
    def _sample_center(center_frame_idx, half_span, num_clip_frames):
        """Return a center index whose ``+/- half_span`` window fits the clip.

        A valid ``center_frame_idx`` is returned unchanged. Otherwise a center
        is drawn uniformly from the valid range, which is the same distribution
        as repeatedly drawing from the whole clip until one fits.

        Raises:
            ValueError: If no center can fit the window (the previous
                rejection loop would never terminate in that case).
        """
        low = half_span
        high = num_clip_frames - 1 - half_span
        if low > high:
            raise ValueError(
                f'Temporal window of half span {half_span} does not fit in '
                f'a clip of {num_clip_frames} frames.')
        if low <= center_frame_idx <= high:
            return center_frame_idx
        return random.randint(low, high)

    @staticmethod
    def _neighbor_indices(center_frame_idx, num_half_frames, interval):
        """Return frame indices centered on ``center_frame_idx``, ``interval`` apart."""
        return list(
            range(center_frame_idx - num_half_frames * interval,
                  center_frame_idx + num_half_frames * interval + 1,
                  interval))

    def _file_key(self, root, clip_name, name):
        """Return the FileClient key for frame/flow ``name`` of ``clip_name``.

        The LMDB backend is addressed by ``'<clip>/<name>'``; other backends
        by the path ``root/<clip>/<name>.png``.
        """
        if self.is_lmdb:
            return f'{clip_name}/{name}'
        return root / clip_name / f'{name}.png'

    def _read_flow(self, clip_name, name):
        """Load one quantized flow image and return the dequantized flow."""
        flow_path = self._file_key(self.flow_root, clip_name, name)
        img_bytes = self.file_client.get(flow_path, 'flow')
        cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)
        # dx and dy are stored stacked along the first axis of one image.
        dx, dy = np.split(cat_flow, 2, axis=0)
        return dequantize_flow(dx, dy, max_val=20, denorm=False)

    def __getitem__(self, index):
        if self.file_client is None:
            # Pop 'type' from a copy so self.io_backend_opt stays complete.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'),
                                          **backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')

        # Choose a temporal interval, then make sure the whole window
        # [center - half * interval, center + half * interval] lies in the clip.
        interval = random.choice(self.interval_list)
        # 100 is the REDS clip length that used to be hard-coded here.
        num_clip_frames = self.clip_frame_nums.get(clip_name, 100)
        center_frame_idx = self._sample_center(
            int(frame_name), self.num_half_frames * interval, num_clip_frames)
        frame_name = f'{center_frame_idx:08d}'
        # Keep the returned key in sync with the frame that is actually loaded.
        key = f'{clip_name}/{frame_name}'
        neighbor_list = self._neighbor_indices(center_frame_idx,
                                               self.num_half_frames, interval)

        # Random reverse.
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # GT frame (the center frame).
        img_gt_path = self._file_key(self.gt_root, clip_name, frame_name)
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # Neighboring LQ frames.
        img_lqs = []
        for neighbor in neighbor_list:
            img_lq_path = self._file_key(self.lq_root, clip_name,
                                         f'{neighbor:08d}')
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(img_bytes, float32=True))

        # Flows: _p{half}.._p1 (previous side), then _n1.._n{half} (next side).
        if self.flow_root is not None:
            # NOTE(review): flow files are chosen from the center frame only,
            # independent of `interval` and of the random reverse above;
            # confirm they still line up with the LQ frames when
            # interval != 1 or the neighbor list was reversed.
            flow_names = [
                f'{frame_name}_p{i}'
                for i in range(self.num_half_frames, 0, -1)
            ]
            flow_names += [
                f'{frame_name}_n{i}'
                for i in range(1, self.num_half_frames + 1)
            ]
            img_flows = [self._read_flow(clip_name, n) for n in flow_names]
            # Appended to the LQ list so they go through the same random crop
            # call; they are split off again right after cropping.
            img_lqs.extend(img_flows)

        # Randomly crop.
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = (img_lqs[:self.num_frame],
                                  img_lqs[self.num_frame:])

        # Augmentation - flip, rotate. GT is appended last so it shares the
        # same transform as the LQ frames.
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_flip'],
                                             self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_flip'],
                                  self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # A zero flow for the center frame keeps one flow per LQ frame.
            img_flows.insert(self.num_half_frames,
                             torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
|
|
|