| | import glob |
| | import torch |
| | from os import path as osp |
| | from torch.utils import data as data |
| |
|
| | from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq |
| | from basicsr.utils import get_root_logger, scandir |
| | from basicsr.utils.registry import DATASET_REGISTRY |
| |
|
| |
|
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing dataset with following structures:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. If not provided, all the folders in the dataroot will
                be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        # Flat per-frame records; position i in every list describes sample i.
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        # Per-folder storage: loaded frames when cache_data is set, otherwise
        # the sorted lists of frame paths.
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                # The folder name is the first space-separated field. Strip
                # each line first so a line holding only the folder name does
                # not keep its trailing newline, and skip blank lines.
                subfolders = [entry.split(' ')[0] for entry in map(str.strip, fin) if entry]
            subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
            subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() not in ('vid4', 'reds4', 'redsofficial'):
            # Report the offending name itself (not its type).
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

        # zip() below would silently drop unmatched folders, so check first.
        assert len(subfolders_lq) == len(subfolders_gt), (
            f'Different number of subfolders in lq ({len(subfolders_lq)})'
            f' and gt roots ({len(subfolders_gt)})')

        for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
            # get frame list for lq and gt
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted(scandir(subfolder_lq, full_path=True))
            img_paths_gt = sorted(scandir(subfolder_gt, full_path=True))

            max_idx = len(img_paths_lq)
            assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                  f' and gt folders ({len(img_paths_gt)})')

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['gt_path'].extend(img_paths_gt)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            # 'idx' is stored as '<frame index>/<frame count>' and parsed
            # back in __getitem__.
            self.data_info['idx'].extend(f'{i}/{max_idx}' for i in range(max_idx))
            # Flag the first and last num_frame // 2 frames of each clip.
            border_l = [0] * max_idx
            for i in range(self.opt['num_frame'] // 2):
                border_l[i] = 1
                border_l[max_idx - i - 1] = 1
            self.data_info['border'].extend(border_l)

            # cache data or save the frame list
            if self.cache_data:
                logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq
                self.imgs_gt[subfolder_name] = img_paths_gt

    def __getitem__(self, index):
        """Return the input window and the ground truth for one frame.

        Args:
            index (int): Position in the flat per-frame list.

        Returns:
            dict: 'lq' (frames selected by ``generate_frame_indices``),
            'gt' (the frame at the sample's own index), 'folder', 'idx'
            (the '<index>/<count>' string), 'border' (0 or 1) and 'lq_path'.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            # Frames are already loaded; pick the window along dim 0.
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            # A one-element sequence was read; drop the leading dimension.
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,
            'gt': img_gt,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border,
            'lq_path': lq_path
        }

    def __len__(self):
        """Return the total number of frames over all folders."""
        return len(self.data_info['gt_path'])
| |
|
| |
|
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets. Must be
                falsy; caching is not implemented for this dataset.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. Required: the folder list is always read from it.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.

    Raises:
        NotImplementedError: If ``opt['cache_data']`` is truthy.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # Frame numbers of the input window; used below to build the
        # 'im{i}.png' file names (the gt frame is always 'im4.png').
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            # The clip folder is the first space-separated field. Strip each
            # line first so a line holding only the folder name does not keep
            # its trailing newline, and skip blank lines.
            subfolders = [entry.split(' ')[0] for entry in map(str.strip, fin) if entry]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        """Read the input frames and the ground-truth frame of one clip.

        Args:
            index (int): Clip position in the meta-info order.

        Returns:
            dict: 'lq' (the frames read from the clip's lq paths), 'gt',
            'folder' (always 'vimeo90k'), 'idx' ('<clip index>/<clip count>'),
            'border' (always 0) and 'lq_path' (path of the middle input frame).
        """
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        # A one-element sequence was read; drop the leading dimension.
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,
            'gt': img_gt,
            'folder': self.data_info['folder'][index],
            'idx': self.data_info['idx'][index],
            'border': self.data_info['border'][index],
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        """Return the number of clips listed in the meta-info file."""
        return len(self.data_info['gt_path'])
| |
|
| |
|
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for test dataset.
            Most of the keys are the same as for VideoTestDataset.
            It has the following extra keys:

            use_duf_downsampling (bool): Whether to use duf downsampling to
                generate low-resolution frames from the gt frames.
            scale: Scale forwarded to ``duf_downsample`` and to the
                mod-crop reading of frames.
    """

    def __getitem__(self, index):
        """Return the input window and ground truth for one frame.

        Args:
            index (int): Position in the flat per-frame list.

        Returns:
            dict: Keys 'lq', 'gt', 'folder', 'idx', 'border' and 'lq_path'.
        """
        info = self.data_info
        folder = info['folder'][index]
        frame_str, total_str = info['idx'][index].split('/')
        frame, total = int(frame_str), int(total_str)
        border = info['border'][index]
        lq_path = info['lq_path'][index]

        window = generate_frame_indices(frame, total, self.opt['num_frame'], padding=self.opt['padding'])

        if not self.cache_data:
            # Frames live on disk: self.imgs_gt / self.imgs_lq hold paths.
            if self.opt['use_duf_downsampling']:
                # read gt frames and downsample them to get the lq input
                gt_window_paths = [self.imgs_gt[folder][k] for k in window]
                hr_frames = read_img_seq(gt_window_paths, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(hr_frames, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = read_img_seq([self.imgs_lq[folder][k] for k in window])
            img_gt = read_img_seq([self.imgs_gt[folder][frame]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)
        else:
            # Frames are cached: select the window along dim 0.
            downsample = self.opt['use_duf_downsampling']
            source = self.imgs_gt[folder] if downsample else self.imgs_lq[folder]
            imgs_lq = source.index_select(0, torch.LongTensor(window))
            if downsample:
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            img_gt = self.imgs_gt[folder][frame]

        sample = {
            'lq': imgs_lq,
            'gt': img_gt,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border,
            'lq_path': lq_path
        }
        return sample
| |
|
| |
|
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures.

    One sample is a whole folder: the low-resolution frames of the folder go
    in and the corresponding high-resolution frames are the target.

    Args:
        Same as VideoTestDataset.
        Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Collapse the per-frame folder list into sorted unique folder names.
        unique_folders = list(set(self.data_info['folder']))
        unique_folders.sort()
        self.folders = unique_folders

    def __getitem__(self, index):
        """Return the cached lq and gt sequences of the folder at ``index``.

        Raises:
            NotImplementedError: If the dataset was built without cache_data.
        """
        folder = self.folders[index]

        if not self.cache_data:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': self.imgs_lq[folder],
            'gt': self.imgs_gt[folder],
            'folder': folder,
        }

    def __len__(self):
        """Return the number of distinct folders."""
        return len(self.folders)
| |
|