Upload 14 files
Browse files- basicsr/data/__init__.py +101 -101
- basicsr/data/data_sampler.py +48 -0
- basicsr/data/data_util.py +315 -0
- basicsr/data/degradations.py +764 -0
- basicsr/data/ffhq_dataset.py +80 -0
- basicsr/data/paired_image_dataset.py +106 -0
- basicsr/data/prefetch_dataloader.py +122 -0
- basicsr/data/realesrgan_dataset.py +193 -0
- basicsr/data/realesrgan_paired_dataset.py +106 -0
- basicsr/data/reds_dataset.py +352 -0
- basicsr/data/single_image_dataset.py +68 -0
- basicsr/data/transforms.py +179 -0
- basicsr/data/video_test_dataset.py +283 -0
- basicsr/data/vimeo90k_dataset.py +199 -0
basicsr/data/__init__.py
CHANGED
|
@@ -1,101 +1,101 @@
|
|
| 1 |
-
import importlib
|
| 2 |
-
import numpy as np
|
| 3 |
-
import random
|
| 4 |
-
import torch
|
| 5 |
-
import torch.utils.data
|
| 6 |
-
from copy import deepcopy
|
| 7 |
-
from functools import partial
|
| 8 |
-
from os import path as osp
|
| 9 |
-
|
| 10 |
-
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
| 11 |
-
from basicsr.utils import get_root_logger, scandir
|
| 12 |
-
from basicsr.utils.dist_util import get_dist_info
|
| 13 |
-
from basicsr.utils.registry import DATASET_REGISTRY
|
| 14 |
-
|
| 15 |
-
__all__ = ['build_dataset', 'build_dataloader']
|
| 16 |
-
|
| 17 |
-
# automatically scan and import dataset modules for registry
|
| 18 |
-
# scan all the files under the data folder with '_dataset' in file names
|
| 19 |
-
data_folder = osp.dirname(osp.abspath(__file__))
|
| 20 |
-
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
| 21 |
-
# import all the dataset modules
|
| 22 |
-
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def build_dataset(dataset_opt):
|
| 26 |
-
"""Build dataset from options.
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
dataset_opt (dict): Configuration for dataset. It must contain:
|
| 30 |
-
name (str): Dataset name.
|
| 31 |
-
type (str): Dataset type.
|
| 32 |
-
"""
|
| 33 |
-
dataset_opt = deepcopy(dataset_opt)
|
| 34 |
-
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
|
| 35 |
-
logger = get_root_logger()
|
| 36 |
-
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
|
| 37 |
-
return dataset
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
|
| 41 |
-
"""Build dataloader.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
dataset (torch.utils.data.Dataset): Dataset.
|
| 45 |
-
dataset_opt (dict): Dataset options. It contains the following keys:
|
| 46 |
-
phase (str): 'train' or 'val'.
|
| 47 |
-
num_worker_per_gpu (int): Number of workers for each GPU.
|
| 48 |
-
batch_size_per_gpu (int): Training batch size for each GPU.
|
| 49 |
-
num_gpu (int): Number of GPUs. Used only in the train phase.
|
| 50 |
-
Default: 1.
|
| 51 |
-
dist (bool): Whether in distributed training. Used only in the train
|
| 52 |
-
phase. Default: False.
|
| 53 |
-
sampler (torch.utils.data.sampler): Data sampler. Default: None.
|
| 54 |
-
seed (int | None): Seed. Default: None
|
| 55 |
-
"""
|
| 56 |
-
phase = dataset_opt['phase']
|
| 57 |
-
rank, _ = get_dist_info()
|
| 58 |
-
if phase == 'train':
|
| 59 |
-
if dist: # distributed training
|
| 60 |
-
batch_size = dataset_opt['batch_size_per_gpu']
|
| 61 |
-
num_workers = dataset_opt['num_worker_per_gpu']
|
| 62 |
-
else: # non-distributed training
|
| 63 |
-
multiplier = 1 if num_gpu == 0 else num_gpu
|
| 64 |
-
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
|
| 65 |
-
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
|
| 66 |
-
dataloader_args = dict(
|
| 67 |
-
dataset=dataset,
|
| 68 |
-
batch_size=batch_size,
|
| 69 |
-
shuffle=False,
|
| 70 |
-
num_workers=num_workers,
|
| 71 |
-
sampler=sampler,
|
| 72 |
-
drop_last=True)
|
| 73 |
-
if sampler is None:
|
| 74 |
-
dataloader_args['shuffle'] = True
|
| 75 |
-
dataloader_args['worker_init_fn'] = partial(
|
| 76 |
-
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
|
| 77 |
-
elif phase in ['val', 'test']: # validation
|
| 78 |
-
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
| 79 |
-
else:
|
| 80 |
-
raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
|
| 81 |
-
|
| 82 |
-
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
|
| 83 |
-
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
|
| 84 |
-
|
| 85 |
-
prefetch_mode = dataset_opt.get('prefetch_mode')
|
| 86 |
-
if prefetch_mode == 'cpu': # CPUPrefetcher
|
| 87 |
-
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
|
| 88 |
-
logger = get_root_logger()
|
| 89 |
-
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
|
| 90 |
-
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
| 91 |
-
else:
|
| 92 |
-
# prefetch_mode=None: Normal dataloader
|
| 93 |
-
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
|
| 94 |
-
return torch.utils.data.DataLoader(**dataloader_args)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def worker_init_fn(worker_id, num_workers, rank, seed):
|
| 98 |
-
# Set the worker seed to num_workers * rank + worker_id + seed
|
| 99 |
-
worker_seed = num_workers * rank + worker_id + seed
|
| 100 |
-
np.random.seed(worker_seed)
|
| 101 |
-
random.seed(worker_seed)
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from functools import partial
|
| 8 |
+
from os import path as osp
|
| 9 |
+
|
| 10 |
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
| 11 |
+
from basicsr.utils import get_root_logger, scandir
|
| 12 |
+
from basicsr.utils.dist_util import get_dist_info
|
| 13 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 14 |
+
|
| 15 |
+
__all__ = ['build_dataset', 'build_dataloader']
|
| 16 |
+
|
| 17 |
+
# automatically scan and import dataset modules for registry
|
| 18 |
+
# scan all the files under the data folder with '_dataset' in file names
|
| 19 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
| 20 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
| 21 |
+
# import all the dataset modules
|
| 22 |
+
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_dataset(dataset_opt):
    """Build a dataset instance from an options dict.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type, used to look up the registered class.
    """
    # work on a copy so the caller's options dict is never mutated
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build a dataloader for the given dataset.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        if dist:
            # distributed training: per-GPU sizes are used as-is
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:
            # non-distributed training: scale by the number of GPUs
            multiplier = num_gpu if num_gpu != 0 else 1
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
        }
        # shuffle only when no sampler drives the iteration order
        if sampler is None:
            dataloader_args['shuffle'] = True
        if seed is not None:
            init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
        else:
            init_fn = None
        dataloader_args['worker_init_fn'] = init_fn
    elif phase in ('val', 'test'):  # validation / testing: one sample at a time
        dataloader_args = {'dataset': dataset, 'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    # prefetch_mode=None: normal dataloader
    # prefetch_mode='cuda': plain dataloader consumed by CUDAPrefetcher
    return torch.utils.data.DataLoader(**dataloader_args)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
| 98 |
+
# Set the worker seed to num_workers * rank + worker_id + seed
|
| 99 |
+
worker_seed = num_workers * rank + worker_id + seed
|
| 100 |
+
np.random.seed(worker_seed)
|
| 101 |
+
random.seed(worker_seed)
|
basicsr/data/data_sampler.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.sampler import Sampler
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Enlarging the dataset (``ratio`` > 1) lets iteration-based training avoid
    the cost of restarting the dataloader after every epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            training; usually the world size.
        rank (int | None): Rank of the current process within ``num_replicas``.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # per-rank sample count over the (possibly enlarged) index space
        self.num_samples = math.ceil(len(dataset) * ratio / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self):
        # deterministic shuffle keyed on the current epoch
        gen = torch.Generator()
        gen.manual_seed(self.epoch)
        perm = torch.randperm(self.total_size, generator=gen).tolist()

        # wrap enlarged indices back into the real dataset range
        real_size = len(self.dataset)
        wrapped = [idx % real_size for idx in perm]

        # this rank's strided slice of the shuffled index list
        selected = wrapped[self.rank:self.total_size:self.num_replicas]
        assert len(selected) == self.num_samples

        return iter(selected)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
|
basicsr/data/data_util.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from os import path as osp
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from basicsr.data.transforms import mod_crop
|
| 8 |
+
from basicsr.utils import img2tensor, scandir
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a folder (or an explicit path list).

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Whether to mod-crop each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to also return image names.
            Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Image name list (only when ``return_imgname`` is True).
    """
    img_paths = path if isinstance(path, list) else sorted(scandir(path, full_path=True))
    # BGR uint8 -> float32 in [0, 1]
    imgs = [cv2.imread(p).astype(np.float32) / 255. for p in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = torch.stack(img2tensor(imgs, bgr2rgb=True, float32=True), dim=0)

    if not return_imgname:
        return imgs
    imgnames = [osp.splitext(osp.basename(p))[0] for p in img_paths]
    return imgs, imgnames
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading ``num_frames`` frames centred at
    ``crt_idx`` from a sequence of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Number of frames to read; must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last = max_frame_num - 1  # switch to zero-based indexing
    half = num_frames // 2

    indices = []
    for i in range(crt_idx - half, crt_idx + half + 1):
        if 0 <= i <= last:
            indices.append(i)
        elif i < 0:
            # pad before the first frame
            if padding == 'replicate':
                indices.append(0)
            elif padding == 'reflection':
                indices.append(-i)
            elif padding == 'reflection_circle':
                indices.append(crt_idx + half - i)
            else:  # circle
                indices.append(num_frames + i)
        else:
            # pad past the last frame
            if padding == 'replicate':
                indices.append(last)
            elif padding == 'reflection':
                indices.append(2 * last - i)
            elif padding == 'reflection_circle':
                indices.append((crt_idx - half) - (i - last))
            else:  # circle
                indices.append(i - num_frames)
    return indices
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files; see
    https://lmdb.readthedocs.io/en/release/ for details. The meta_info.txt
    records one image per line: name (with extension), shape, and compression
    level, separated by white space, e.g. `baboon.png (120,125,3) 1`.

    The image name without extension is used as the lmdb key, and the same
    key refers to the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _lmdb_keys(folder):
        # lmdb keys are the image names without extension, one per meta line
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            return [line.split('.')[0] for line in fin]

    # the two meta_info files must describe the same key set
    input_lmdb_keys = _lmdb_keys(input_folder)
    gt_lmdb_keys = _lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    return [{f'{input_key}_path': k, f'{gt_key}_path': k} for k in sorted(input_lmdb_keys)]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and image
    shape (usually for gt), separated by a white space, e.g.::

        0001_s001.png (480,480,3)
        0001_s002.png (480,480,3)

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename (without extension),
            usually applied to files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    # first whitespace-separated token on each line is the gt image name
    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        stem, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(stem)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from two folders.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename (without extension),
            usually applied to files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_name in gt_paths:
        stem, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(stem)}{ext}'
        # every gt image must have a matching (templated) input image
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def paths_from_folder(folder):
    """Generate full paths for every file in a folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    return [osp.join(folder, name) for name in scandir(folder)]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def paths_from_lmdb(folder):
    """Generate image paths (lmdb keys) from an lmdb folder.

    The keys are read from the folder's ``meta_info.txt``: each line starts
    with an image name, and the name without extension is the lmdb key.

    Args:
        folder (str): Folder path; must end with '.lmdb'.

    Returns:
        list[str]: Returned path (lmdb key) list.

    Raises:
        ValueError: If ``folder`` is not in lmdb format.
    """
    if not folder.endswith('.lmdb'):
        # fixed garbled message (was: f'Folder {folder}folder should in lmdb format.')
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate the Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The (kernel_size, kernel_size) Gaussian kernel.
    """
    # `scipy.ndimage.filters` is a long-deprecated alias namespace that has
    # been removed in recent SciPy; import gaussian_filter directly instead.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set the centre element to one: a discrete dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smoothing the dirac yields the gaussian filter itself
    return gaussian_filter(kernel, sigma)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with the Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w),
            or (t, c, h, w) for a single clip.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        # a single clip (t, c, h, w): add a batch dim, remove it before return
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # fold batch/time/channel into one axis so conv2d blurs each map independently
    x = x.view(-1, 1, h, w)
    # pad by half the kernel plus an extra `scale * 2` border that is cropped below
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    # sigma scales with the downsampling factor (0.4 * scale), as in DUF
    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    # strided convolution performs blur + downsample in a single step
    x = F.conv2d(x, gaussian_filter, stride=scale)
    # crop the 2-pixel border introduced by the extra padding above
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
|
basicsr/data/degradations.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
from scipy import special
|
| 7 |
+
from scipy.stats import multivariate_normal
|
| 8 |
+
from torchvision.transforms.functional import rgb_to_grayscale
|
| 9 |
+
|
| 10 |
+
# -------------------------------------------------------------------- #
|
| 11 |
+
# --------------------------- blur kernels --------------------------- #
|
| 12 |
+
# -------------------------------------------------------------------- #
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# --------------------------- util functions --------------------------- #
|
| 16 |
+
def sigma_matrix2(sig_x, sig_y, theta):
    """Build the covariance (sigma) matrix of a rotated 2D Gaussian.

    Computes U @ D @ U^T, where D is the diagonal matrix of squared sigmas
    and U is the rotation matrix for angle `theta`.

    Args:
        sig_x (float): standard deviation along the x axis.
        sig_y (float): standard deviation along the y axis.
        theta (float): rotation angle in radians.

    Returns:
        ndarray: 2x2 rotated sigma matrix.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    diag = np.array([[sig_x**2, 0], [0, sig_y**2]])
    rot = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.dot(rot, np.dot(diag, rot.T))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def mesh_grid(kernel_size):
    """Generate a coordinate mesh grid centered at zero.

    Args:
        kernel_size (int): size of the (square) kernel.

    Returns:
        xy (ndarray): shape (kernel_size, kernel_size, 2), each entry is the
            (x, y) coordinate pair of that grid position.
        xx (ndarray): shape (kernel_size, kernel_size), x coordinates.
        yy (ndarray): shape (kernel_size, kernel_size), y coordinates.
    """
    axis_coords = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(axis_coords, axis_coords)
    # pair up the two coordinate planes along a trailing axis
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def pdf2(sigma_matrix, grid):
    """Evaluate an un-normalized bivariate Gaussian PDF on a grid.

    Args:
        sigma_matrix (ndarray): covariance matrix with shape (2, 2).
        grid (ndarray): generated by :func:`mesh_grid`, shape (K, K, 2),
            where K is the kernel size.

    Returns:
        ndarray: un-normalized kernel values, shape (K, K).
    """
    precision = np.linalg.inv(sigma_matrix)
    # quadratic form x^T * Sigma^{-1} * x at every grid point
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    return np.exp(-0.5 * quad_form)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def cdf2(d_matrix, grid):
    """Evaluate the standard bivariate Gaussian CDF on a skewed grid.

    Used to build skewed Gaussian distributions.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`, shape (K, K, 2),
            where K is the kernel size.

    Returns:
        ndarray: skewed CDF values.
    """
    standard_normal = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    skewed_grid = np.dot(grid, d_matrix)
    return standard_normal.cdf(skewed_grid)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a normalized bivariate Gaussian kernel.

    In isotropic mode only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): kernel size.
        sig_x (float): sigma along x.
        sig_y (float): sigma along y (anisotropic mode only).
        theta (float): rotation angle in radians (anisotropic mode only).
        grid (ndarray, optional): generated by :func:`mesh_grid`, shape
            (K, K, 2) with K the kernel size. Default: None (built here).
        isotropic (bool): whether the kernel is isotropic.

    Returns:
        ndarray: kernel normalized to sum to 1.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    raw = pdf2(sigma_matrix, grid)
    return raw / np.sum(raw)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a normalized bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

    In isotropic mode only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): kernel size.
        sig_x (float): sigma along x.
        sig_y (float): sigma along y (anisotropic mode only).
        theta (float): rotation angle in radians (anisotropic mode only).
        beta (float): shape parameter; beta = 1 gives the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`, shape
            (K, K, 2) with K the kernel size. Default: None (built here).
        isotropic (bool): whether the kernel is isotropic.

    Returns:
        ndarray: kernel normalized to sum to 1.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    precision = np.linalg.inv(sigma_matrix)
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    raw = np.exp(-0.5 * np.power(quad_form, beta))
    return raw / np.sum(raw)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a normalized plateau-like anisotropic kernel.

    Profile: 1 / (1 + x^beta).
    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In isotropic mode only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): kernel size.
        sig_x (float): sigma along x.
        sig_y (float): sigma along y (anisotropic mode only).
        theta (float): rotation angle in radians (anisotropic mode only).
        beta (float): shape parameter; controls plateau flatness.
        grid (ndarray, optional): generated by :func:`mesh_grid`, shape
            (K, K, 2) with K the kernel size. Default: None (built here).
        isotropic (bool): whether the kernel is isotropic.

    Returns:
        ndarray: kernel normalized to sum to 1.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    precision = np.linalg.inv(sigma_matrix)
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    raw = np.reciprocal(np.power(quad_form, beta) + 1)
    return raw / np.sum(raw)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly sample a bivariate (isotropic or anisotropic) Gaussian kernel.

    In isotropic mode only `sigma_x_range` is used; `sigma_y_range` and
    `rotation_range` are ignored.

    Args:
        kernel_size (int): kernel size, must be odd.
        sigma_x_range (tuple): e.g. [0.6, 5].
        sigma_y_range (tuple): e.g. [0.6, 5].
        rotation_range (tuple): e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): multiplicative kernel noise range,
            e.g. [0.75, 1.25]. Default: None (no noise).
        isotropic (bool): whether to sample an isotropic kernel.

    Returns:
        ndarray: sampled kernel, normalized to sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # optional multiplicative noise, then renormalize
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)
    return kernel
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly sample a bivariate generalized Gaussian kernel.

    In isotropic mode only `sigma_x_range` is used; `sigma_y_range` and
    `rotation_range` are ignored.

    Args:
        kernel_size (int): kernel size, must be odd.
        sigma_x_range (tuple): e.g. [0.6, 5].
        sigma_y_range (tuple): e.g. [0.6, 5].
        rotation_range (tuple): e.g. [-math.pi, math.pi].
        beta_range (tuple): shape-parameter range, e.g. [0.5, 8]; assumed to
            straddle 1 (beta_range[0] < 1 < beta_range[1]).
        noise_range (tuple, optional): multiplicative kernel noise range,
            e.g. [0.75, 1.25]. Default: None (no noise).
        isotropic (bool): whether to sample an isotropic kernel.

    Returns:
        ndarray: sampled kernel, normalized to sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # sample beta on either side of 1 with equal probability
    # (assumes beta_range[0] < 1 < beta_range[1])
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # optional multiplicative noise, then renormalize
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)
    return kernel
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly sample a bivariate plateau kernel.

    In isotropic mode only `sigma_x_range` is used; `sigma_y_range` and
    `rotation_range` are ignored.

    Args:
        kernel_size (int): kernel size, must be odd.
        sigma_x_range (tuple): e.g. [0.6, 5].
        sigma_y_range (tuple): e.g. [0.6, 5].
        rotation_range (tuple): e.g. [-math.pi/2, math.pi/2].
        beta_range (tuple): shape-parameter range, e.g. [1, 4].
        noise_range (tuple, optional): multiplicative kernel noise range,
            e.g. [0.75, 1.25]. Default: None (no noise).
        isotropic (bool): whether to sample an isotropic kernel.

    Returns:
        ndarray: sampled kernel, normalized to sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # sample beta on either side of 1 with equal probability
    # TODO: this may be not proper for plateau kernels
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # optional multiplicative noise, then renormalize
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return kernel
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate a kernel of a randomly chosen type.

    Args:
        kernel_list (tuple): candidate kernel type names; supported:
            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
             'plateau_iso', 'plateau_aniso'].
        kernel_prob (tuple): sampling probability for each kernel type.
        kernel_size (int): kernel size, must be odd. Default: 21.
        sigma_x_range (tuple): e.g. [0.6, 5].
        sigma_y_range (tuple): e.g. [0.6, 5].
        rotation_range (tuple): e.g. [-math.pi, math.pi].
        betag_range (tuple): beta range for generalized Gaussian kernels.
        betap_range (tuple): beta range for plateau kernels.
        noise_range (tuple, optional): multiplicative kernel noise range,
            e.g. [0.75, 1.25]. Default: None (no noise).

    Returns:
        ndarray: sampled kernel.

    Raises:
        ValueError: if the sampled kernel type is not one of the supported
            names (previously this fell through to an UnboundLocalError).
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        # pass noise_range through: it was previously hard-coded to None,
        # silently ignoring the caller's noise_range for plateau kernels
        kernel = random_bivariate_plateau(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betap_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betap_range,
            noise_range=noise_range,
            isotropic=False)
    else:
        raise ValueError(f'Unsupported kernel type: {kernel_type}')
    return kernel
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Silence divide-by-zero / invalid-value warnings: `circular_lowpass_kernel`
# below deliberately evaluates 0/0 at the kernel center and overwrites the
# result afterwards.
# NOTE(review): this mutates NumPy's *global* floating-point error state for
# the whole process — confirm no downstream code relies on these warnings.
np.seterr(divide='ignore', invalid='ignore')
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): cutoff frequency in radians (pi is max).
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.

    Returns:
        ndarray: normalized sinc kernel, shape (pad_to, pad_to) if
            pad_to > kernel_size, else (kernel_size, kernel_size).
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    center = (kernel_size - 1) / 2
    # the center element evaluates 0/0 (NaN); suppress the warnings locally
    # instead of relying on the module-level np.seterr, and overwrite the
    # center with the analytic limit cutoff^2 / (4*pi) afterwards
    with np.errstate(divide='ignore', invalid='ignore'):
        kernel = np.fromfunction(
            lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
                (x - center)**2 + (y - center)**2)) / (2 * np.pi * np.sqrt(
                    (x - center)**2 + (y - center)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
# ------------------------------------------------------------- #
|
| 413 |
+
# --------------------------- noise --------------------------- #
|
| 414 |
+
# ------------------------------------------------------------- #
|
| 415 |
+
|
| 416 |
+
# ----------------------- Gaussian Noise ----------------------- #
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise with the same spatial size as `img`.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): noise sigma measured in the [0, 255] range. Default: 10.
        gray_noise (bool): if True, the same noise plane is replicated across
            all three channels. Default: False.

    Returns:
        (Numpy array): noise, shape (h, w, c), float32.
    """
    if gray_noise:
        plane = np.float32(np.random.randn(*img.shape[0:2])) * sigma / 255.
        return np.expand_dims(plane, axis=2).repeat(3, axis=2)
    return np.float32(np.random.randn(*img.shape)) * sigma / 255.
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise to an image.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): noise sigma measured in the [0, 255] range. Default: 10.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.
        gray_noise (bool): use gray (channel-shared) noise. Default: False.

    Returns:
        (Numpy array): noisy image, shape (h, w, c), float32.
    """
    out = img + generate_gaussian_noise(img, sigma, gray_noise)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = np.clip(out, 0, 255)
        out = out / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    return out
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise for a batch of images (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): noise sigma measured in the [0, 255] range;
            a Tensor is interpreted per-sample with shape (b). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b);
            0 for color noise, 1 for gray noise. Default: 0.

    Returns:
        (Tensor): noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        # draw the gray noise directly with shape (b, 1, h, w): the previous
        # code drew (h, w) and relied on broadcasting against a tensor sigma
        # before a .view(b, 1, h, w), which crashed for b > 1 with scalar sigma
        noise_gray = torch.randn(b, 1, h, w, dtype=img.dtype, device=img.device) * sigma / 255.

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): noise sigma measured in the [0, 255] range;
            a Tensor is interpreted per-sample with shape (b). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.

    Returns:
        (Tensor): noisy images, shape (b, c, h, w), float32.
    """
    out = img + generate_gaussian_noise_pt(img, sigma, gray_noise)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = torch.clamp(out, 0, 255)
        out = out / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    return out
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# ----------------------- Random Gaussian Noise ----------------------- #
|
| 515 |
+
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    """Generate Gaussian noise with a randomly sampled sigma and gray flag.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        sigma_range (tuple): sigma sampled uniformly from this range.
        gray_prob (float): probability of generating gray noise.

    Returns:
        (Numpy array): noise, shape (h, w, c), float32.
    """
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    gray_noise = np.random.uniform() < gray_prob
    return generate_gaussian_noise(img, sigma, gray_noise)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Gaussian noise with a randomly sampled sigma and gray flag.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        sigma_range (tuple): sigma sampled uniformly from this range.
        gray_prob (float): probability of generating gray noise.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.

    Returns:
        (Numpy array): noisy image, shape (h, w, c), float32.
    """
    out = img + random_generate_gaussian_noise(img, sigma_range, gray_prob)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = np.clip(out, 0, 255)
        out = out / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    return out
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    """Generate per-sample Gaussian noise with random sigma (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        sigma_range (tuple): per-sample sigma sampled uniformly from this range.
        gray_prob (float): per-sample probability of gray noise.

    Returns:
        (Tensor): noise, shape (b, c, h, w), float32.
    """
    span = sigma_range[1] - sigma_range[0]
    sigma = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * span + sigma_range[0]
    gray_noise = (torch.rand(img.size(0), dtype=img.dtype, device=img.device) < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Gaussian noise with per-sample random sigma (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        sigma_range (tuple): per-sample sigma sampled uniformly from this range.
        gray_prob (float): per-sample probability of gray noise.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.

    Returns:
        (Tensor): noisy images, shape (b, c, h, w), float32.
    """
    out = img + random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = torch.clamp(out, 0, 255)
        out = out / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    return out
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# ----------------------- Poisson (Shot) Noise ----------------------- #
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate Poisson (shot) noise.

    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        scale (float): noise scale. Default: 1.0.
        gray_noise (bool): if True, noise is computed on the grayscale image
            and replicated across channels. Default: False.

    Returns:
        (Numpy array): noise, shape (h, w, c), float32.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # round and clip so that the number of unique intensity values is
    # counted correctly below
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    unique_count = len(np.unique(img))
    levels = 2**np.ceil(np.log2(unique_count))
    sampled = np.float32(np.random.poisson(img * levels) / float(levels))
    noise = sampled - img
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add Poisson (shot) noise to an image.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        scale (float): noise scale. Default: 1.0.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.
        gray_noise (bool): use gray (channel-shared) noise. Default: False.

    Returns:
        (Numpy array): noisy image, shape (h, w, c), float32.
    """
    out = img + generate_poisson_noise(img, scale, gray_noise)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = np.clip(out, 0, 255)
        out = out / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    return out
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of Poisson (shot) noise (PyTorch version).

    Args:
        img (Tensor): input images, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): noise scale; number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b);
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    def _poisson_noise(x):
        # round and clip so the per-sample unique-value count is correct,
        # then sample Poisson at that quantization level
        x = torch.clamp((x * 255.0).round(), 0, 255) / 255.
        vals_list = [len(torch.unique(x[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(v)) for v in vals_list]
        vals = x.new_tensor(vals_list).view(b, 1, 1, 1)
        return torch.poisson(x * vals) / vals - x

    if cal_gray_noise:
        img_gray = rgb_to_grayscale(img, num_output_channels=1)
        noise_gray = _poisson_noise(img_gray).expand(b, 3, h, w)

    # always calculate color noise
    noise = _poisson_noise(img)
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add Poisson (shot) noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): input images, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): noise scale; number or Tensor with shape (b).
            Default: 1.0.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).

    Returns:
        (Tensor): noisy images, shape (b, c, h, w), float32.
    """
    out = img + generate_poisson_noise_pt(img, scale, gray_noise)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = torch.clamp(out, 0, 255)
        out = out / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    return out
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    """Generate Poisson noise with a randomly sampled scale and gray flag.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        scale_range (tuple): scale sampled uniformly from this range.
        gray_prob (float): probability of generating gray noise.

    Returns:
        (Numpy array): noise, shape (h, w, c), float32.
    """
    scale = np.random.uniform(scale_range[0], scale_range[1])
    gray_noise = np.random.uniform() < gray_prob
    return generate_poisson_noise(img, scale, gray_noise)
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Poisson noise with a randomly sampled scale and gray flag.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        scale_range (tuple): scale sampled uniformly from this range.
        gray_prob (float): probability of generating gray noise.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.

    Returns:
        (Numpy array): noisy image, shape (h, w, c), float32.
    """
    out = img + random_generate_poisson_noise(img, scale_range, gray_prob)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = np.clip(out, 0, 255)
        out = out / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    return out
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    """Generate per-sample Poisson noise with random scale (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        scale_range (tuple): per-sample scale sampled uniformly from this range.
        gray_prob (float): per-sample probability of gray noise.

    Returns:
        (Tensor): noise, shape (b, c, h, w), float32.
    """
    span = scale_range[1] - scale_range[0]
    scale = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * span + scale_range[0]
    gray_noise = (torch.rand(img.size(0), dtype=img.dtype, device=img.device) < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Poisson noise with per-sample random scale (PyTorch version).

    Args:
        img (Tensor): shape (b, c, h, w), range [0, 1], float32.
        scale_range (tuple): per-sample scale sampled uniformly from this range.
        gray_prob (float): per-sample probability of gray noise.
        clip (bool): clip the result to [0, 1]. Default: True.
        rounds (bool): quantize the result to 1/255 steps. Default: False.

    Returns:
        (Tensor): noisy images, shape (b, c, h, w), float32.
    """
    out = img + random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    if rounds:
        out = (out * 255.0).round()
        if clip:
            out = torch.clamp(out, 0, 255)
        out = out / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    return out
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
# ------------------------------------------------------------------------ #
|
| 727 |
+
# --------------------------- JPEG compression --------------------------- #
|
| 728 |
+
# ------------------------------------------------------------------------ #
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts by an in-memory encode/decode round trip.

    Args:
        img (Numpy array): input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality; 0 for lowest quality, 100
            for best quality. Default: 90.

    Returns:
        (Numpy array): image after JPG round trip, shape (h, w, c),
            range [0, 1], float32.
    """
    img = np.clip(img, 0, 1)
    # encode to a JPEG byte buffer, then decode back to pick up the artifacts
    _, encoded = cv2.imencode('.jpg', img * 255., [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    return np.float32(cv2.imdecode(encoded, 1)) / 255.
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    # Cast to int: OpenCV's IMWRITE_JPEG_QUALITY parameter is integer-typed,
    # and passing a float can raise an error in newer OpenCV versions.
    return add_jpg_compression(img, int(quality))
|
basicsr/data/ffhq_dataset.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import time
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from torch.utils import data as data
|
| 5 |
+
from torchvision.transforms.functional import normalize
|
| 6 |
+
|
| 7 |
+
from basicsr.data.transforms import augment
|
| 8 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 9 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily on first __getitem__ so it
        # works correctly with dataloader worker processes
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                # Fix: the message was missing the f-string prefix, so the
                # placeholder was printed literally instead of the actual path.
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                # Fix: random.randint is inclusive on both ends; the previous
                # upper bound of len(self.paths) could raise IndexError.
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
basicsr/data/paired_image_dataset.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils import data as data
|
| 2 |
+
from torchvision.transforms.functional import normalize
|
| 3 |
+
|
| 4 |
+
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
|
| 5 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 6 |
+
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
|
| 7 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # io backend client is created lazily on the first __getitem__ call
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        if self.io_backend_opt['type'] == 'lmdb':
            # lmdb mode: paths come from the lmdb meta information
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info_file') is not None:
            # meta-info-file mode: paths listed in an external text file
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            # folder mode: scan the two folders for matching pairs
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        record = self.paths[index]
        gt_path = record['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = record['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        # augmentation for training: random crop followed by flip / rotation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # optional color space transform: keep the Y channel only
        if self.opt.get('color') == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
basicsr/data/prefetch_dataloader.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import queue as Queue
|
| 2 |
+
import threading
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Runs the wrapped generator in a daemon thread and buffers its items in a
    bounded queue, so production overlaps with consumption.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True  # do not block interpreter shutdown
        self.start()

    def run(self):
        # Producer side: push every item, then a None sentinel for exhaustion.
        for value in self.generator:
            self.queue.put(value)
        self.queue.put(None)

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item

    def __iter__(self):
        return self
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        # Wrap the regular DataLoader iterator in a background prefetcher.
        base_iter = super().__iter__()
        return PrefetchGenerator(base_iter, self.num_prefetch_queue)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CPUPrefetcher():
    """CPU prefetcher.

    Thin wrapper exposing ``next()``/``reset()`` over a dataloader; ``next()``
    returns None once the epoch is exhausted.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or None when the loader is exhausted."""
        try:
            batch = next(self.loader)
        except StopIteration:
            batch = None
        return batch

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class CUDAPrefetcher():
    """CUDA prefetcher.

    Overlaps host-to-device copies of the next batch with computation on the
    current one by issuing the copies on a side CUDA stream.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. Must contain 'num_gpu' (int); 0 selects CPU.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        # Side stream used for asynchronous host-to-device transfers.
        # NOTE(review): this is created unconditionally, so constructing this
        # class presumably requires CUDA even when num_gpu == 0 — confirm.
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to the device."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            # End of epoch: signal exhaustion via a None batch.
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    # non_blocking allows the copy to overlap with compute.
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        """Return the current batch and start prefetching the following one."""
        # Make sure the async copies issued on the side stream are visible to
        # the default stream before the batch is used.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration from the beginning and prefetch the first batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
|
basicsr/data/realesrgan_dataset.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils import data as data
|
| 10 |
+
|
| 11 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
| 12 |
+
from basicsr.data.transforms import augment
|
| 13 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 14 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        # io backend client is created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                # Fix: logger.warn is a deprecated alias of logger.warning.
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                # Fix: random.randint is inclusive on both ends; the previous
                # upper bound of len(self.paths) could raise IndexError.
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to the fixed 21x21 size expected by the model
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
|
basicsr/data/realesrgan_paired_dataset.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torch.utils import data as data
|
| 3 |
+
from torchvision.transforms.functional import normalize
|
| 4 |
+
|
| 5 |
+
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
|
| 6 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 7 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor
|
| 8 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None  # created lazily in __getitem__
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        if self.io_backend_opt['type'] == 'lmdb':
            # lmdb io backend: paths come from the lmdb meta information
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info') is not None:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            self.paths = []
            with open(self.opt['meta_info']) as fin:
                for line in fin:
                    gt_name, lq_name = line.strip().split(', ')
                    self.paths.append({
                        'gt_path': os.path.join(self.gt_folder, gt_name),
                        'lq_path': os.path.join(self.lq_folder, lq_name)
                    })
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        entry = self.paths[index]
        gt_path = entry['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = entry['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        # augmentation for training: random crop followed by flip / rotation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
basicsr/data/reds_dataset.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from torch.utils import data as data
|
| 6 |
+
|
| 7 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 8 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 9 |
+
from basicsr.utils.flow_util import dequantize_flow
|
| 10 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
        # window must be symmetric around a center frame, hence odd
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        # build keys "clip/frame" for every frame listed in the meta file
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend); created lazily in __getitem__ so each
        # dataloader worker gets its own handle
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        # lazy init; pop('type') mutates io_backend_opt, so run only once
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99; resample the center
        # until the whole temporal window fits inside the clip
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows (dx/dy are stored stacked vertically in one grayscale png)
        if self.flow_root is not None:
            img_flows = []
            # read previous flows
            for i in range(self.num_half_frames, 0, -1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_p{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)
            # read next flows
            for i in range(1, self.num_half_frames + 1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_n{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop (flows, if any, ride along in img_lqs and get the
        # same crop window)
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # last element is the GT appended above
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow (center frame flows to itself)
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        # build keys "clip/frame" for every frame listed in the meta file
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        # test_mode keeps only the validation clips; otherwise train on the rest
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            # NOTE: this class never sets self.flow_root; the hasattr guard
            # keeps the lmdb setup shared with flow-aware variants
            if hasattr(self, 'flow_root') and self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        # lazy init; pop('type') mutates io_backend_opt, so run only once
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders (each clip has 100 frames);
        # resample the start index when the sequence would run past frame 99
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - self.num_frame * interval:
            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            # get LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

            # get GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate (LQ and GT lists are concatenated so
        # every frame receives the same flip/rotation)
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # first half of the concatenated list is LQ, second half is GT
        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
|
basicsr/data/single_image_dataset.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import path as osp
|
| 2 |
+
from torch.utils import data as data
|
| 3 |
+
from torchvision.transforms.functional import normalize
|
| 4 |
+
|
| 5 |
+
from basicsr.data.data_util import paths_from_lmdb
|
| 6 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
|
| 7 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # optional per-channel normalization parameters
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        # three sources of paths: lmdb keys, a meta info file, or a folder scan
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                # first whitespace-separated token on each line is the relative path
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        # lazy init; pop('type') mutates io_backend_opt, so run only once
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image (HWC, BGR, float32 in [0, 1])
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform: keep only the Y (luma) channel if requested
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize in place when mean/std were configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
|
basicsr/data/transforms.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def mod_crop(img, scale):
    """Crop an image so both spatial dims are exact multiples of ``scale``.

    Used during testing so that downscaling/upscaling by ``scale`` aligns.

    Args:
        img (ndarray): Input image, 2-D (grayscale) or 3-D (HWC).
        scale (int): Scale factor.

    Returns:
        ndarray: Cropped copy of the input image.

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.
    """
    cropped = img.copy()
    if cropped.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {cropped.ndim}.')
    height, width = cropped.shape[:2]
    new_h = height - height % scale
    new_w = width - width % scale
    return cropped[:new_h, :new_w, ...]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
        only have one element, just return ndarray.

    Raises:
        ValueError: If GT and LQ sizes are inconsistent with ``scale``, or the
            LQ images are smaller than the requested patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    # Tensors are CHW (spatial dims last); ndarrays are HWC (spatial dims first)
    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # Bug fix: the original passed two comma-separated f-strings to
        # ValueError, which made the exception message a tuple. Use a single
        # implicitly-concatenated string so e.args holds one readable message.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch (same window, scaled up)
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    # unwrap single-element lists back to a bare array/tensor
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Note: cv2.flip is called with the input array as the destination, so the
    flip augmentations mutate the input images in place.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
        results only have one element, just return ndarray.

    """
    # each augmentation is sampled independently with probability 0.5
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        # flipping a flow field also negates the corresponding displacement
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            # transposing h and w swaps the dx/dy channels as well
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate (and optionally scale) an image around a center point.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image with the same spatial size as the input.
    """
    height, width = img.shape[:2]

    if center is None:
        center = (width // 2, height // 2)

    rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)
    return cv2.warpAffine(img, rotation_matrix, (width, height))
|
basicsr/data/video_test_dataset.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import torch
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from torch.utils import data as data
|
| 5 |
+
|
| 6 |
+
from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
|
| 7 |
+
from basicsr.utils import get_root_logger, scandir
|
| 8 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing dataset with following structures:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the
                folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        # per-frame records; all five lists stay index-aligned
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend); created lazily by subclasses that need it
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            # only the first whitespace-separated token of each line is the clip name
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
            subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
            subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                     f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                # mark the first and last num_frame // 2 frames of each clip as
                # border frames (they need frame padding at test time)
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data (decoded tensors) or save the frame path list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            # bugfix: report the dataset name itself; the previous
            # f'...{type(opt["name"])}' always printed "<class 'str'>"
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        """Return one test sample: a window of LQ frames centered on `index`
        plus the corresponding single GT frame.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for the Vimeo90K-Test dataset.

    Only the center frame (im4.png) of each clip is kept as ground truth.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets. Not
                supported by this dataset.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of
                test clips (required).
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # input frame indices, centered on im4 (e.g. num_frame=7 -> im1..im7)
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            clips = [line.split(' ')[0] for line in fin]
        total = len(clips)
        for clip_idx, clip in enumerate(clips):
            self.data_info['gt_path'].append(osp.join(self.gt_root, clip, 'im4.png'))
            self.data_info['lq_path'].append([osp.join(self.lq_root, clip, f'im{i}.png') for i in neighbor_list])
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{clip_idx}/{total}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        """Read one clip: the LQ frame window and its center GT frame."""
        lq_paths = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_paths)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_paths[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@DATASET_REGISTRY.register()
|
| 202 |
+
class VideoTestDUFDataset(VideoTestDataset):
|
| 203 |
+
""" Video test dataset for DUF dataset.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
|
| 207 |
+
It has the following extra keys:
|
| 208 |
+
use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
|
| 209 |
+
scale (bool): Scale, which will be added automatically.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __getitem__(self, index):
|
| 213 |
+
folder = self.data_info['folder'][index]
|
| 214 |
+
idx, max_idx = self.data_info['idx'][index].split('/')
|
| 215 |
+
idx, max_idx = int(idx), int(max_idx)
|
| 216 |
+
border = self.data_info['border'][index]
|
| 217 |
+
lq_path = self.data_info['lq_path'][index]
|
| 218 |
+
|
| 219 |
+
select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
|
| 220 |
+
|
| 221 |
+
if self.cache_data:
|
| 222 |
+
if self.opt['use_duf_downsampling']:
|
| 223 |
+
# read imgs_gt to generate low-resolution frames
|
| 224 |
+
imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
|
| 225 |
+
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
|
| 226 |
+
else:
|
| 227 |
+
imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
|
| 228 |
+
img_gt = self.imgs_gt[folder][idx]
|
| 229 |
+
else:
|
| 230 |
+
if self.opt['use_duf_downsampling']:
|
| 231 |
+
img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
|
| 232 |
+
# read imgs_gt to generate low-resolution frames
|
| 233 |
+
imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
|
| 234 |
+
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
|
| 235 |
+
else:
|
| 236 |
+
img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
|
| 237 |
+
imgs_lq = read_img_seq(img_paths_lq)
|
| 238 |
+
img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
|
| 239 |
+
img_gt.squeeze_(0)
|
| 240 |
+
|
| 241 |
+
return {
|
| 242 |
+
'lq': imgs_lq, # (t, c, h, w)
|
| 243 |
+
'gt': img_gt, # (c, h, w)
|
| 244 |
+
'folder': folder, # folder name
|
| 245 |
+
'idx': self.data_info['idx'][index], # e.g., 0/99
|
| 246 |
+
'border': border, # 1 for border, 0 for non-border
|
| 247 |
+
'lq_path': lq_path # center frame
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which take whole LR
    video clips as input and output the corresponding HR video frames.

    One dataset item is a full clip (folder), not a single frame.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # deduplicate the per-frame folder labels into sorted clip names
        self.folders = sorted(set(self.data_info['folder']))

    def __getitem__(self, index):
        clip = self.folders[index]

        # only the cached-data path is supported for recurrent testing
        if not self.cache_data:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': self.imgs_lq[clip],
            'gt': self.imgs_gt[clip],
            'folder': clip,
        }

    def __len__(self):
        return len(self.folders)
|
basicsr/data/vimeo90k_dataset.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from torch.utils import data as data
|
| 5 |
+
|
| 6 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 7 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 8 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space.

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
        1         | 4
        3         | 3,4,5
        5         | 2,3,4,5,6
        7         | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend); created lazily in __getitem__ so the
        # dataset stays picklable for DataLoader workers
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = self.io_backend_opt['type'] == 'lmdb'
        if self.is_lmdb:
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of the input frames, centered on im4
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        get_root_logger().info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # temporal augmentation: reverse frame order half of the time.
        # NOTE: mutates self.neighbor_list in place (original design).
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # the GT is always the center frame (im4)
        if self.is_lmdb:
            gt_path = f'{key}/im4'
        else:
            gt_path = self.gt_root / clip / seq / 'im4.png'
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)

        # read the neighboring LQ frames
        img_lqs = []
        for frame in self.neighbor_list:
            if self.is_lmdb:
                lq_path = f'{clip}/{seq}/im{frame}'
            else:
                lq_path = self.lq_root / clip / seq / f'im{frame}.png'
            img_lqs.append(imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True))

        # randomly crop aligned GT / LQ patches
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, gt_path)

        # augmentation - flip, rotate (GT appended so it shares the transform)
        img_lqs.append(img_gt)
        img_results = img2tensor(augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']))
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w); img_gt: (c, h, w); key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
    """Vimeo90K dataset for training recurrent networks.

    Unlike Vimeo90KDataset (single center GT frame), this returns all seven
    frames of a clip for both LQ and GT.

    Args:
        opt (dict): Same keys as Vimeo90KDataset, plus:
            flip_sequence (bool): Whether to extend the sequence with its
                temporally flipped copy (7 frames -> 14 frames).
    """

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        # recurrent training always uses the full 7-frame clip
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse (in-place mutation, matching the parent class)
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop aligned GT / LQ patches
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate (GT frames appended so both sequences
        # share the same transform)
        img_lqs.extend(img_gts)
        img_results = img2tensor(augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']))

        # split back into the LQ and GT stacks; use the actual frame count
        # instead of a hard-coded 7 so the split stays correct if
        # neighbor_list ever changes
        num_frames = len(self.neighbor_list)
        img_lqs = torch.stack(img_results[:num_frames], dim=0)
        img_gts = torch.stack(img_results[num_frames:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w); img_gts: (t, c, h, w); key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
|