Commit 8e79984
Parent(s): dac2323
Initial commit of FPro dehazing model
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
- basicsr/.DS_Store +0 -0
- basicsr/__pycache__/version.cpython-37.pyc +0 -0
- basicsr/data/.DS_Store +0 -0
- basicsr/data/__init__.py +126 -0
- basicsr/data/__pycache__/__init__.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/data_sampler.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/data_util.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/reds_dataset.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/transforms.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc +0 -0
- basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc +0 -0
- basicsr/data/data_sampler.py +49 -0
- basicsr/data/data_util.py +388 -0
- basicsr/data/ffhq_dataset.py +65 -0
- basicsr/data/paired_image_dataset.py +824 -0
- basicsr/data/prefetch_dataloader.py +126 -0
- basicsr/data/reds_dataset.py +237 -0
- basicsr/data/single_image_dataset.py +67 -0
- basicsr/data/transforms.py +480 -0
- basicsr/data/video_test_dataset.py +325 -0
- basicsr/data/vimeo90k_dataset.py +130 -0
- basicsr/metrics/__init__.py +4 -0
- basicsr/metrics/__pycache__/__init__.cpython-37.pyc +0 -0
- basicsr/metrics/__pycache__/metric_util.cpython-37.pyc +0 -0
- basicsr/metrics/__pycache__/niqe.cpython-37.pyc +0 -0
- basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc +0 -0
- basicsr/metrics/fid.py +102 -0
- basicsr/metrics/metric_util.py +47 -0
- basicsr/metrics/niqe.py +205 -0
- basicsr/metrics/niqe_pris_params.npz +3 -0
- basicsr/metrics/psnr_ssim.py +303 -0
- basicsr/models/.DS_Store +0 -0
- basicsr/models/__init__.py +42 -0
- basicsr/models/__pycache__/__init__.cpython-37.pyc +0 -0
- basicsr/models/__pycache__/base_model.cpython-37.pyc +0 -0
- basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc +0 -0
- basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc +0 -0
- basicsr/models/archs/FPro_arch.py +545 -0
- basicsr/models/archs/__init__.py +46 -0
- basicsr/models/archs/__pycache__/__init__.cpython-37.pyc +0 -0
- basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc +0 -0
- basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc +0 -0
- basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc +0 -0
- basicsr/models/archs/arch_util.py +255 -0
- basicsr/models/base_model.py +378 -0
- basicsr/models/image_restoration_model.py +361 -0
basicsr/.DS_Store
ADDED
Binary file (10.2 kB)
basicsr/__pycache__/version.cpython-37.pyc
ADDED
Binary file (244 Bytes)
basicsr/data/.DS_Store
ADDED
Binary file (6.15 kB)
basicsr/data/__init__.py
ADDED
@@ -0,0 +1,126 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info

__all__ = ['create_dataset', 'create_dataloader']

# automatically scan and import dataset modules
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
    if v.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{file_name}')
    for file_name in dataset_filenames
]


def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset


def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Create dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank,
            seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(
            dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                    f'num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(
            num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
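The two factory functions above are the public entry points of the data package. A minimal usage sketch, assuming a paired training setup (the option keys follow the docstrings above; the type Dataset_PairedImage is defined in paired_image_dataset.py later in this commit, and the name, data roots, and hyperparameter values below are placeholders, not values prescribed by the commit):

from basicsr.data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'demo-train',                # placeholder name
    'type': 'Dataset_PairedImage',       # resolved by the *_dataset.py module scan above
    'phase': 'train',
    'scale': 1,
    'dataroot_gt': '/data/demo/gt',      # placeholder paths
    'dataroot_lq': '/data/demo/lq',
    'io_backend': {'type': 'disk'},
    'gt_size': 256,
    'geometric_augs': True,
    'batch_size_per_gpu': 8,
    'num_worker_per_gpu': 4,
}
train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(
    train_set, dataset_opt, num_gpu=1, dist=False, seed=10)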
basicsr/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (3.53 kB)
basicsr/data/__pycache__/data_sampler.cpython-37.pyc
ADDED
Binary file (2.14 kB)
basicsr/data/__pycache__/data_util.cpython-37.pyc
ADDED
Binary file (13 kB)
basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc
ADDED
Binary file (2.54 kB)
basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc
ADDED
Binary file (16.3 kB)
basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc
ADDED
Binary file (4.29 kB)
basicsr/data/__pycache__/reds_dataset.cpython-37.pyc
ADDED
Binary file (6.44 kB)
basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc
ADDED
Binary file (2.61 kB)
basicsr/data/__pycache__/transforms.cpython-37.pyc
ADDED
Binary file (9.85 kB)
basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc
ADDED
Binary file (10.7 kB)
basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc
ADDED
Binary file (4.16 kB)
basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,49 @@
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(
            len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
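A sketch of how this sampler is typically wired together with create_dataloader, reusing train_set and dataset_opt from the sketch above (the single-process values num_replicas=1 and rank=0, the ratio, and the epoch count are illustrative; with ratio > 1 one pass of the loader covers the dataset several times, so the loader is restarted less often):

from basicsr.data import create_dataloader
from basicsr.data.data_sampler import EnlargedSampler

sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=100)
train_loader = create_dataloader(
    train_set, dataset_opt, num_gpu=1, dist=False, sampler=sampler, seed=10)
for epoch in range(10):
    sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
    for batch in train_loader:
        pass  # training step goes here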
basicsr/data/data_util.py
ADDED
@@ -0,0 +1,388 @@
import cv2
cv2.setNumThreads(1)
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)
    return imgs


def generate_frame_indices(crt_idx,
                           max_frame_num,
                           num_frames,
                           padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding modes:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle',
                       'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices


def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1) image name (with extension),
    2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(
            f'{input_key} folder and {gt_key} folder should both be in lmdb '
            f'format. But received {input_key}: {input_folder}; '
            f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(
            f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(
                dict([(f'{input_key}_path', lmdb_key),
                      (f'{gt_key}_path', lmdb_key)]))
        return paths


def paired_paths_from_meta_info_file(folders, keys, meta_info_file,
                                     filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(
            dict([(f'{input_key}_path', input_path),
                  (f'{gt_key}_path', gt_path)]))
    return paths


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for idx in range(len(gt_paths)):
        gt_path = gt_paths[idx]
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_path = input_paths[idx]
        basename_input, ext_input = osp.splitext(osp.basename(input_path))
        input_name = f'{filename_tmpl.format(basename)}{ext_input}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, (f'{input_name} is not in '
                                           f'{input_key}_paths.')
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(
            dict([(f'{input_key}_path', input_path),
                  (f'{gt_key}_path', gt_path)]))
    return paths


def paired_DP_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders (dual-pixel variant).

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [inputL_folder, inputR_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lqL', 'lqR', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 3, (
        'The len of folders should be 3 with [inputL_folder, inputR_folder, '
        f'gt_folder]. But got {len(folders)}')
    assert len(keys) == 3, (
        'The len of keys should be 3 with [inputL_key, inputR_key, gt_key]. '
        f'But got {len(keys)}')
    inputL_folder, inputR_folder, gt_folder = folders
    inputL_key, inputR_key, gt_key = keys

    inputL_paths = list(scandir(inputL_folder))
    inputR_paths = list(scandir(inputR_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(inputL_paths) == len(inputR_paths) == len(gt_paths), (
        f'{inputL_key} and {inputR_key} and {gt_key} datasets have different '
        f'number of images: '
        f'{len(inputL_paths)}, {len(inputR_paths)}, {len(gt_paths)}.')
    paths = []
    for idx in range(len(gt_paths)):
        gt_path = gt_paths[idx]
        basename, ext = osp.splitext(osp.basename(gt_path))
        inputL_path = inputL_paths[idx]
        basename_input, ext_input = osp.splitext(osp.basename(inputL_path))
        inputL_name = f'{filename_tmpl.format(basename)}{ext_input}'
        inputL_path = osp.join(inputL_folder, inputL_name)
        assert inputL_name in inputL_paths, (f'{inputL_name} is not in '
                                             f'{inputL_key}_paths.')
        inputR_path = inputR_paths[idx]
        basename_input, ext_input = osp.splitext(osp.basename(inputR_path))
        inputR_name = f'{filename_tmpl.format(basename)}{ext_input}'
        inputR_path = osp.join(inputR_folder, inputR_name)
        assert inputR_name in inputR_paths, (f'{inputR_name} is not in '
                                             f'{inputR_key}_paths.')
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(
            dict([(f'{inputL_key}_path', inputL_path),
                  (f'{inputR_key}_path', inputR_path),
                  (f'{gt_key}_path', gt_path)]))
    return paths


def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    from scipy.ndimage import filters as filters
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return filters.gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(
        0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
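To make the padding modes of generate_frame_indices concrete, a quick check that reproduces the examples from its docstring (a 100-frame sequence with the 5-frame window centered at index 0; the sequence length is arbitrary here):

from basicsr.data.data_util import generate_frame_indices

print(generate_frame_indices(0, 100, 5, padding='replicate'))          # [0, 0, 0, 1, 2]
print(generate_frame_indices(0, 100, 5, padding='reflection'))         # [2, 1, 0, 1, 2]
print(generate_frame_indices(0, 100, 5, padding='reflection_circle'))  # [4, 3, 0, 1, 2]
print(generate_frame_indices(0, 100, 5, padding='circle'))             # [3, 4, 0, 1, 2]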
basicsr/data/ffhq_dataset.py
ADDED
@@ -0,0 +1,65 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, imfrombytes, img2tensor


class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "
                                 f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [
                osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
            ]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
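A minimal instantiation sketch for this dataset with the plain disk backend (the FFHQ root path and the normalization statistics are placeholders, not values prescribed by this commit):

from basicsr.data.ffhq_dataset import FFHQDataset

ffhq_opt = {
    'dataroot_gt': '/data/FFHQ',     # placeholder; expects 00000000.png ... 00069999.png
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'use_hflip': True,
}
ffhq = FFHQDataset(ffhq_opt)
sample = ffhq[0]   # {'gt': normalized CHW float tensor, 'gt_path': str}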
basicsr/data/paired_image_dataset.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils import data as data
|
| 2 |
+
from torchvision.transforms.functional import normalize
|
| 3 |
+
|
| 4 |
+
from basicsr.data.data_util import (paired_paths_from_folder,
|
| 5 |
+
paired_DP_paths_from_folder,
|
| 6 |
+
paired_paths_from_lmdb,
|
| 7 |
+
paired_paths_from_meta_info_file)
|
| 8 |
+
from basicsr.data.transforms import augment, paired_random_crop, paired_random_crop_DP, random_augmentation, paired_center_crop
|
| 9 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, padding_DP, imfrombytesDP
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import cv2
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from scandir import scandir
|
| 18 |
+
|
| 19 |
+
class Dataset_PairedImage_dehazeSOT(data.Dataset):
|
| 20 |
+
"""Paired image dataset for image restoration.
|
| 21 |
+
|
| 22 |
+
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
| 23 |
+
GT image pairs.
|
| 24 |
+
|
| 25 |
+
There are three modes:
|
| 26 |
+
1. 'lmdb': Use lmdb files.
|
| 27 |
+
If opt['io_backend'] == lmdb.
|
| 28 |
+
2. 'meta_info_file': Use meta information file to generate paths.
|
| 29 |
+
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
|
| 30 |
+
3. 'folder': Scan folders to generate paths.
|
| 31 |
+
The rest.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
| 35 |
+
dataroot_gt (str): Data root path for gt.
|
| 36 |
+
dataroot_lq (str): Data root path for lq.
|
| 37 |
+
meta_info_file (str): Path for meta information file.
|
| 38 |
+
io_backend (dict): IO backend type and other kwarg.
|
| 39 |
+
filename_tmpl (str): Template for each filename. Note that the
|
| 40 |
+
template excludes the file extension. Default: '{}'.
|
| 41 |
+
gt_size (int): Cropped patched size for gt patches.
|
| 42 |
+
geometric_augs (bool): Use geometric augmentations.
|
| 43 |
+
|
| 44 |
+
scale (bool): Scale, which will be added automatically.
|
| 45 |
+
phase (str): 'train' or 'val'.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, opt):
|
| 49 |
+
super(Dataset_PairedImage_dehazeSOT, self).__init__()
|
| 50 |
+
self.opt = opt
|
| 51 |
+
# file client (io backend)
|
| 52 |
+
self.file_client = None
|
| 53 |
+
self.io_backend_opt = opt['io_backend']
|
| 54 |
+
self.mean = opt['mean'] if 'mean' in opt else None
|
| 55 |
+
self.std = opt['std'] if 'std' in opt else None
|
| 56 |
+
|
| 57 |
+
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
| 58 |
+
if 'filename_tmpl' in opt:
|
| 59 |
+
self.filename_tmpl = opt['filename_tmpl']
|
| 60 |
+
else:
|
| 61 |
+
self.filename_tmpl = '{123}'
|
| 62 |
+
|
| 63 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
| 64 |
+
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
| 65 |
+
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
| 66 |
+
self.paths = paired_paths_from_lmdb(
|
| 67 |
+
[self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
| 68 |
+
elif 'meta_info_file' in self.opt and self.opt[
|
| 69 |
+
'meta_info_file'] is not None:
|
| 70 |
+
self.paths = paired_paths_from_meta_info_file(
|
| 71 |
+
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
|
| 72 |
+
self.opt['meta_info_file'], self.filename_tmpl)
|
| 73 |
+
else:
|
| 74 |
+
# self.paths = paired_paths_from_folder(
|
| 75 |
+
# [self.lq_folder, self.gt_folder], ['lq', 'gt'],
|
| 76 |
+
# self.filename_tmpl)
|
| 77 |
+
basename = '/mnt/sda/zsh/dataset/haze/promptIR'
|
| 78 |
+
name = ''
|
| 79 |
+
if self.opt['phase'] == 'train':
|
| 80 |
+
name = 'hazy_outside.txt'
|
| 81 |
+
else:
|
| 82 |
+
name = 'haze_test.txt'
|
| 83 |
+
dataset = os.path.join(basename, name)
|
| 84 |
+
paths = []
|
| 85 |
+
if self.opt['phase'] == 'train':
|
| 86 |
+
gt_dir = basename + '/Dehaze/original'
|
| 87 |
+
lq = basename + '/Dehaze'
|
| 88 |
+
with open(dataset, 'r') as fin:
|
| 89 |
+
#synthetic/part4/8961_0.95_0.08.jpg
|
| 90 |
+
for line in fin:
|
| 91 |
+
gt_path = os.path.join(gt_dir, line.split('/')[-1].split('_')[0]+ '.jpg')
|
| 92 |
+
# print('train gt',gt_path)
|
| 93 |
+
input_path = os.path.join(lq, line.strip())
|
| 94 |
+
# print('train input',input_path)
|
| 95 |
+
paths.append(
|
| 96 |
+
dict([(f'lq_path', input_path),
|
| 97 |
+
(f'gt_path', gt_path)]))
|
| 98 |
+
else:
|
| 99 |
+
gt_dir = basename + '/outdoor/gt'
|
| 100 |
+
lq = basename + '/outdoor/hazy'
|
| 101 |
+
#1917_0.95_0.2.jpg
|
| 102 |
+
# print('performing val dataset organize')
|
| 103 |
+
with open(dataset, 'r') as fin:
|
| 104 |
+
for line in fin:
|
| 105 |
+
gt_path = os.path.join(gt_dir, line.split('_')[0]+ '.png')
|
| 106 |
+
# print('valid gt',gt_path)
|
| 107 |
+
input_path = os.path.join(lq, line.strip())
|
| 108 |
+
# print('valid input',input_path)
|
| 109 |
+
paths.append(
|
| 110 |
+
dict([(f'lq_path', input_path),
|
| 111 |
+
(f'gt_path', gt_path)]))
|
| 112 |
+
self.paths = paths
|
| 113 |
+
# self.paths = [
|
| 114 |
+
# osp.join(self.gt_folder,
|
| 115 |
+
# line.split(' ')[0]) for line in fin
|
| 116 |
+
# ]
|
| 117 |
+
|
| 118 |
+
if self.opt['phase'] == 'train':
|
| 119 |
+
self.geometric_augs = opt['geometric_augs']
|
| 120 |
+
|
| 121 |
+
def __getitem__(self, index):
|
| 122 |
+
if self.file_client is None:
|
| 123 |
+
self.file_client = FileClient(
|
| 124 |
+
self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
| 125 |
+
|
| 126 |
+
scale = self.opt['scale']
|
| 127 |
+
index = index % len(self.paths)
|
| 128 |
+
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
|
| 129 |
+
# image range: [0, 1], float32.
|
| 130 |
+
gt_path = self.paths[index]['gt_path']
|
| 131 |
+
img_bytes = self.file_client.get(gt_path, 'gt')
|
| 132 |
+
try:
|
| 133 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
| 134 |
+
except:
|
| 135 |
+
raise Exception("gt path {} not working".format(gt_path))
|
| 136 |
+
|
| 137 |
+
lq_path = self.paths[index]['lq_path']
|
| 138 |
+
img_bytes = self.file_client.get(lq_path, 'lq')
|
| 139 |
+
try:
|
| 140 |
+
img_lq = imfrombytes(img_bytes, float32=True)
|
| 141 |
+
except:
|
| 142 |
+
raise Exception("lq path {} not working".format(lq_path))
|
| 143 |
+
|
| 144 |
+
# augmentation for training
|
| 145 |
+
if self.opt['phase'] == 'train':
|
| 146 |
+
gt_size = self.opt['gt_size']
|
| 147 |
+
# padding
|
| 148 |
+
img_gt, img_lq = padding(img_gt, img_lq, gt_size)
|
| 149 |
+
|
| 150 |
+
# random crop
|
| 151 |
+
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
|
| 152 |
+
gt_path)
|
| 153 |
+
|
| 154 |
+
# flip, rotation augmentations
|
| 155 |
+
if self.geometric_augs:
|
| 156 |
+
img_gt, img_lq = random_augmentation(img_gt, img_lq)
|
| 157 |
+
elif self.opt['phase'] == 'val':
|
| 158 |
+
# print('entering val processing')
|
| 159 |
+
|
| 160 |
+
#centerCrop for validation
|
| 161 |
+
gt_size = self.opt['gt_size']
|
| 162 |
+
img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
|
| 163 |
+
gt_path)
|
| 164 |
+
elif self.opt['phase'] == 'test':
|
| 165 |
+
#doingNothing
|
| 166 |
+
print('Test on Full Image')
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
| 171 |
+
img_gt, img_lq = img2tensor([img_gt, img_lq],
|
| 172 |
+
bgr2rgb=True,
|
| 173 |
+
float32=True)
|
| 174 |
+
# normalize
|
| 175 |
+
if self.mean is not None or self.std is not None:
|
| 176 |
+
normalize(img_lq, self.mean, self.std, inplace=True)
|
| 177 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
| 178 |
+
|
| 179 |
+
return {
|
| 180 |
+
'lq': img_lq,
|
| 181 |
+
'gt': img_gt,
|
| 182 |
+
'lq_path': lq_path,
|
| 183 |
+
'gt_path': gt_path
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
def __len__(self):
|
| 187 |
+
return len(self.paths)
|
| 188 |
+
|
| 189 |
+
class Dataset_PairedImage_denseHaze(data.Dataset):
|
| 190 |
+
"""Paired image dataset for image restoration.
|
| 191 |
+
|
| 192 |
+
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
| 193 |
+
GT image pairs.
|
| 194 |
+
|
| 195 |
+
There are three modes:
|
| 196 |
+
1. 'lmdb': Use lmdb files.
|
| 197 |
+
If opt['io_backend'] == lmdb.
|
| 198 |
+
2. 'meta_info_file': Use meta information file to generate paths.
|
| 199 |
+
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
|
| 200 |
+
3. 'folder': Scan folders to generate paths.
|
| 201 |
+
The rest.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
| 205 |
+
dataroot_gt (str): Data root path for gt.
|
| 206 |
+
dataroot_lq (str): Data root path for lq.
|
| 207 |
+
meta_info_file (str): Path for meta information file.
|
| 208 |
+
io_backend (dict): IO backend type and other kwarg.
|
| 209 |
+
filename_tmpl (str): Template for each filename. Note that the
|
| 210 |
+
template excludes the file extension. Default: '{}'.
|
| 211 |
+
gt_size (int): Cropped patched size for gt patches.
|
| 212 |
+
geometric_augs (bool): Use geometric augmentations.
|
| 213 |
+
|
| 214 |
+
scale (bool): Scale, which will be added automatically.
|
| 215 |
+
phase (str): 'train' or 'val'.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def __init__(self, opt):
|
| 219 |
+
super(Dataset_PairedImage_denseHaze, self).__init__()
|
| 220 |
+
self.opt = opt
|
| 221 |
+
# file client (io backend)
|
| 222 |
+
self.file_client = None
|
| 223 |
+
self.io_backend_opt = opt['io_backend']
|
| 224 |
+
self.mean = opt['mean'] if 'mean' in opt else None
|
| 225 |
+
self.std = opt['std'] if 'std' in opt else None
|
| 226 |
+
|
| 227 |
+
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
| 228 |
+
if 'filename_tmpl' in opt:
|
| 229 |
+
self.filename_tmpl = opt['filename_tmpl']
|
| 230 |
+
else:
|
| 231 |
+
self.filename_tmpl = '{}'
|
| 232 |
+
|
| 233 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
| 234 |
+
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
| 235 |
+
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
| 236 |
+
self.paths = paired_paths_from_lmdb(
|
| 237 |
+
[self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
| 238 |
+
elif 'meta_info_file' in self.opt and self.opt[
|
| 239 |
+
'meta_info_file'] is not None:
|
| 240 |
+
self.paths = paired_paths_from_meta_info_file(
|
| 241 |
+
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
|
| 242 |
+
self.opt['meta_info_file'], self.filename_tmpl)
|
| 243 |
+
else:
|
| 244 |
+
self.paths = paired_paths_from_folder(
|
| 245 |
+
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
|
| 246 |
+
self.filename_tmpl)
|
| 247 |
+
|
| 248 |
+
if self.opt['phase'] == 'train':
|
| 249 |
+
self.geometric_augs = opt['geometric_augs']
|
| 250 |
+
|
| 251 |
+
def __getitem__(self, index):
|
| 252 |
+
if self.file_client is None:
|
| 253 |
+
self.file_client = FileClient(
|
| 254 |
+
self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
| 255 |
+
|
| 256 |
+
scale = self.opt['scale']
|
| 257 |
+
index = index % len(self.paths)
|
| 258 |
+
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
|
| 259 |
+
# image range: [0, 1], float32.
|
| 260 |
+
gt_path = self.paths[index]['gt_path']
|
| 261 |
+
img_bytes = self.file_client.get(gt_path, 'gt')
|
| 262 |
+
try:
|
| 263 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
| 264 |
+
except:
|
| 265 |
+
raise Exception("gt path {} not working".format(gt_path))
|
| 266 |
+
|
| 267 |
+
lq_path = self.paths[index]['lq_path']
|
| 268 |
+
img_bytes = self.file_client.get(lq_path, 'lq')
|
| 269 |
+
try:
|
| 270 |
+
img_lq = imfrombytes(img_bytes, float32=True)
|
| 271 |
+
except:
|
| 272 |
+
raise Exception("lq path {} not working".format(lq_path))
|
| 273 |
+
|
| 274 |
+
# augmentation for training
|
| 275 |
+
if self.opt['phase'] == 'train':
|
| 276 |
+
gt_size = self.opt['gt_size']
|
| 277 |
+
# padding
|
| 278 |
+
img_gt, img_lq = padding(img_gt, img_lq, gt_size)
|
| 279 |
+
|
| 280 |
+
# random crop
|
| 281 |
+
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
|
| 282 |
+
gt_path)
|
| 283 |
+
|
| 284 |
+
# flip, rotation augmentations
|
| 285 |
+
if self.geometric_augs:
|
| 286 |
+
img_gt, img_lq = random_augmentation(img_gt, img_lq)
|
| 287 |
+
|
| 288 |
+
elif self.opt['phase'] == 'val':
|
| 289 |
+
# print('entering val processing')
|
| 290 |
+
|
| 291 |
+
#centerCrop for validation
|
| 292 |
+
gt_size = self.opt['gt_size']
|
| 293 |
+
img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
|
| 294 |
+
gt_path)
|
| 295 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
| 296 |
+
img_gt, img_lq = img2tensor([img_gt, img_lq],
|
| 297 |
+
bgr2rgb=True,
|
| 298 |
+
float32=True)
|
| 299 |
+
# normalize
|
| 300 |
+
if self.mean is not None or self.std is not None:
|
| 301 |
+
normalize(img_lq, self.mean, self.std, inplace=True)
|
| 302 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
| 303 |
+
|
| 304 |
+
return {
|
| 305 |
+
'lq': img_lq,
|
| 306 |
+
'gt': img_gt,
|
| 307 |
+
'lq_path': lq_path,
|
| 308 |
+
'gt_path': gt_path
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
def __len__(self):
|
| 312 |
+
return len(self.paths)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class Dataset_PairedImage(data.Dataset):
|
| 316 |
+
"""Paired image dataset for image restoration.
|
| 317 |
+
|
| 318 |
+
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
| 319 |
+
GT image pairs.
|
| 320 |
+
|
| 321 |
+
There are three modes:
|
| 322 |
+
1. 'lmdb': Use lmdb files.
|
| 323 |
+
If opt['io_backend'] == lmdb.
|
| 324 |
+
2. 'meta_info_file': Use meta information file to generate paths.
|
| 325 |
+
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
|
| 326 |
+
3. 'folder': Scan folders to generate paths.
|
| 327 |
+
The rest.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
| 331 |
+
dataroot_gt (str): Data root path for gt.
|
| 332 |
+
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            geometric_augs (bool): Use geometric augmentations.

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(Dataset_PairedImage, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt[
                'meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.filename_tmpl)

        if self.opt['phase'] == 'train':
            self.geometric_augs = opt['geometric_augs']

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        index = index % len(self.paths)
        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        try:
            img_gt = imfrombytes(img_bytes, float32=True)
        except:
            raise Exception("gt path {} not working".format(gt_path))

        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        try:
            img_lq = imfrombytes(img_bytes, float32=True)
        except:
            raise Exception("lq path {} not working".format(lq_path))

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # padding
            img_gt, img_lq = padding(img_gt, img_lq, gt_size)

            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)

            # flip, rotation augmentations
            if self.geometric_augs:
                img_gt, img_lq = random_augmentation(img_gt, img_lq)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
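
# Usage sketch (not from the committed file): a minimal opt dict for
# Dataset_PairedImage. The dataroot paths, gt_size and batch size below are
# illustrative placeholders; 'disk' is the plain-filesystem io_backend.
from torch.utils.data import DataLoader

opt = {
    'phase': 'train',
    'scale': 1,                           # normally filled in by the framework
    'gt_size': 256,
    'geometric_augs': True,
    'io_backend': {'type': 'disk'},
    'dataroot_gt': '/path/to/train/gt',   # placeholder
    'dataroot_lq': '/path/to/train/lq',   # placeholder
}
dataset = Dataset_PairedImage(opt)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch['lq'].shape, batch['gt'].shape)   # torch.Size([4, 3, 256, 256]) each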


class Dataset_PairedImage_derainSpad(data.Dataset):
    """Paired image dataset for image restoration (used here for the SPAD
    real-rain set).

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            geometric_augs (bool): Use geometric augmentations.

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train', 'val' or 'test'.
    """

    def __init__(self, opt):
        super(Dataset_PairedImage_derainSpad, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'  # was '{123}', which breaks str.format

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt[
                'meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.opt['meta_info_file'], self.filename_tmpl)
        else:
            # lq/gt pairs are read from a fixed meta file rather than scanned
            # from folders; note the hard-coded dataset root.
            basename = '/home/ubuntu/zsh/datasets/derain'
            if self.opt['phase'] == 'train':
                name = 'real_world.txt'
            else:
                name = 'real_test_1000.txt'
            dataset = os.path.join(basename, name)
            paths = []
            with open(dataset, 'r') as fin:
                for line in fin:
                    gt_path = os.path.join(basename, line.split(' ')[1][1:-1])
                    input_path = os.path.join(basename, line.split(' ')[0][1:])
                    paths.append(
                        dict([('lq_path', input_path),
                              ('gt_path', gt_path)]))
            self.paths = paths

        if self.opt['phase'] == 'train':
            self.geometric_augs = opt['geometric_augs']

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        index = index % len(self.paths)
        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        try:
            img_gt = imfrombytes(img_bytes, float32=True)
        except:
            raise Exception("gt path {} not working".format(gt_path))

        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        try:
            img_lq = imfrombytes(img_bytes, float32=True)
        except:
            raise Exception("lq path {} not working".format(lq_path))

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # padding
            img_gt, img_lq = padding(img_gt, img_lq, gt_size)

            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)

            # flip, rotation augmentations
            if self.geometric_augs:
                img_gt, img_lq = random_augmentation(img_gt, img_lq)
        elif self.opt['phase'] == 'val':
            # center crop for validation
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)
        elif self.opt['phase'] == 'test':
            # no cropping: test on the full image
            print('Test on Full Image')

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
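
# The slicing in the meta-file loop above implies a specific line format.
# This sketch spells out that assumption (inferred from the indexing, not
# documented in the commit): "<lq_rel> <gt_rel>\n", each path starting
# with '/'.
line = '/real_world/001.png /real_world_gt/001.png\n'   # assumed format
lq_rel = line.split(' ')[0][1:]     # drop the leading '/'
gt_rel = line.split(' ')[1][1:-1]   # drop the leading '/' and trailing newline
print(lq_rel)   # real_world/001.png
print(gt_rel)   # real_world_gt/001.png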


class Dataset_GaussianDenoising(data.Dataset):
    """Paired image dataset for Gaussian denoising.

    Only GT images are read; the noisy LQ input is synthesized on the fly
    by adding Gaussian noise to the GT.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(Dataset_GaussianDenoising, self).__init__()
        self.opt = opt

        if self.opt['phase'] == 'train':
            self.sigma_type = opt['sigma_type']
            self.sigma_range = opt['sigma_range']
            assert self.sigma_type in ['constant', 'random', 'choice']
        else:
            self.sigma_test = opt['sigma_test']
        self.in_ch = opt['in_ch']

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder = opt['dataroot_gt']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            self.paths = paths_from_lmdb(self.gt_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [
                    osp.join(self.gt_folder,
                             line.split(' ')[0]) for line in fin
                ]
        else:
            # scandir yields path strings, so keep full paths here; the
            # previous `list(scandir(self.gt_folder))` variant was indexed
            # with `.path`, which does not exist on str.
            self.paths = sorted(list(scandir(self.gt_folder, full_path=True)))

        if self.opt['phase'] == 'train':
            self.geometric_augs = self.opt['geometric_augs']

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        index = index % len(self.paths)
        # Load the gt image. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path, 'gt')

        if self.in_ch == 3:
            try:
                img_gt = imfrombytes(img_bytes, float32=True)
            except:
                raise Exception("gt path {} not working".format(gt_path))

            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
        else:
            try:
                img_gt = imfrombytes(img_bytes, flag='grayscale', float32=True)
            except:
                raise Exception("gt path {} not working".format(gt_path))

            img_gt = np.expand_dims(img_gt, axis=2)
        img_lq = img_gt.copy()

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # padding
            img_gt, img_lq = padding(img_gt, img_lq, gt_size)

            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)
            # flip, rotation
            if self.geometric_augs:
                img_gt, img_lq = random_augmentation(img_gt, img_lq)

            img_gt, img_lq = img2tensor([img_gt, img_lq],
                                        bgr2rgb=False,
                                        float32=True)

            if self.sigma_type == 'constant':
                sigma_value = self.sigma_range
            elif self.sigma_type == 'random':
                sigma_value = random.uniform(self.sigma_range[0],
                                             self.sigma_range[1])
            elif self.sigma_type == 'choice':
                sigma_value = random.choice(self.sigma_range)

            noise_level = torch.FloatTensor([sigma_value]) / 255.0
            noise = torch.randn(img_lq.size()).mul_(noise_level).float()
            img_lq.add_(noise)

        else:
            # center crop for validation/test
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)

            np.random.seed(seed=0)  # fixed seed: deterministic test noise
            img_lq += np.random.normal(0, self.sigma_test / 255.0, img_lq.shape)

            img_gt, img_lq = img2tensor([img_gt, img_lq],
                                        bgr2rgb=False,
                                        float32=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': gt_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
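
# Standalone sketch of the three sigma_type policies used above. The sigma
# values in the comments are common choices from the denoising literature,
# not values mandated by this commit.
import random
import torch

def sample_sigma(sigma_type, sigma_range):
    if sigma_type == 'constant':   # e.g. sigma_range = 25
        return sigma_range
    if sigma_type == 'random':     # e.g. sigma_range = [0, 50]
        return random.uniform(sigma_range[0], sigma_range[1])
    if sigma_type == 'choice':     # e.g. sigma_range = [15, 25, 50]
        return random.choice(sigma_range)
    raise ValueError(f'unknown sigma_type {sigma_type}')

clean = torch.rand(3, 128, 128)                        # image in [0, 1]
sigma = sample_sigma('choice', [15, 25, 50])
noisy = clean + torch.randn_like(clean) * (sigma / 255.0)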


class Dataset_DefocusDeblur_DualPixel_16bit(data.Dataset):
    """Dual-pixel defocus-deblurring dataset (16-bit inputs).

    Reads the left/right dual-pixel views plus the sharp GT image and
    returns the two views concatenated into a 6-channel 'lq' tensor.
    """

    def __init__(self, opt):
        super(Dataset_DefocusDeblur_DualPixel_16bit, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lqL_folder, self.lqR_folder = opt[
            'dataroot_gt'], opt['dataroot_lqL'], opt['dataroot_lqR']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        self.paths = paired_DP_paths_from_folder(
            [self.lqL_folder, self.lqR_folder, self.gt_folder],
            ['lqL', 'lqR', 'gt'],
            self.filename_tmpl)

        if self.opt['phase'] == 'train':
            self.geometric_augs = self.opt['geometric_augs']

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        index = index % len(self.paths)
        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        try:
            img_gt = imfrombytesDP(img_bytes, float32=True)
        except:
            raise Exception("gt path {} not working".format(gt_path))

        lqL_path = self.paths[index]['lqL_path']
        img_bytes = self.file_client.get(lqL_path, 'lqL')
        try:
            img_lqL = imfrombytesDP(img_bytes, float32=True)
        except:
            raise Exception("lqL path {} not working".format(lqL_path))

        lqR_path = self.paths[index]['lqR_path']
        img_bytes = self.file_client.get(lqR_path, 'lqR')
        try:
            img_lqR = imfrombytesDP(img_bytes, float32=True)
        except:
            raise Exception("lqR path {} not working".format(lqR_path))

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # padding
            img_lqL, img_lqR, img_gt = padding_DP(img_lqL, img_lqR, img_gt,
                                                  gt_size)

            # random crop
            img_lqL, img_lqR, img_gt = paired_random_crop_DP(
                img_lqL, img_lqR, img_gt, gt_size, scale, gt_path)

            # flip, rotation
            if self.geometric_augs:
                img_lqL, img_lqR, img_gt = random_augmentation(
                    img_lqL, img_lqR, img_gt)
        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lqL, img_lqR, img_gt = img2tensor([img_lqL, img_lqR, img_gt],
                                              bgr2rgb=True,
                                              float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lqL, self.mean, self.std, inplace=True)
            normalize(img_lqR, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        img_lq = torch.cat([img_lqL, img_lqR], 0)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lqL_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
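
# Consumer-side sketch (not from the commit): the dual-pixel loader returns a
# 6-channel 'lq' tensor, so downstream code has to split it back into views.
import torch

lq = torch.rand(6, 256, 256)        # one sample's 'lq' from the dataset above
img_lqL, img_lqR = lq[:3], lq[3:]   # recover the left/right 3-channel views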
basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,126 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(
                        device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
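
# Sketch of driving one epoch with CUDAPrefetcher (not from the commit):
# `loader` is any PyTorch DataLoader yielding dicts of tensors, and
# `train_step` is a hypothetical stand-in for the model update.
prefetcher = CUDAPrefetcher(loader, {'num_gpu': 1})
batch = prefetcher.next()
while batch is not None:
    # tensors arrive already on the GPU, copied on the side stream
    train_step(batch['lq'], batch['gt'])   # hypothetical training step
    batch = prefetcher.next()
prefetcher.reset()                         # rewind for the next epoch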
basicsr/data/reds_dataset.py
ADDED
@@ -0,0 +1,237 @@
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow


class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or
                'official'.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])
        self.flow_root = Path(
            opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
        assert opt['num_frame'] % 2 == 1, (
            f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend(
                    [f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(
                f'Wrong validation partition {opt["val_partition"]}. '
                f"Supported ones are ['official', 'REDS4'].")
        self.keys = [
            v for v in self.keys if v.split('/')[0] not in val_partition
        ]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [
                    self.lq_root, self.gt_root, self.flow_root
                ]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (
                center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(
            range(center_frame_idx - self.num_half_frames * interval,
                  center_frame_idx + self.num_half_frames * interval + 1,
                  interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            img_flows = []
            # read previous flows
            for i in range(self.num_half_frames, 0, -1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_p{i}'
                else:
                    flow_path = (
                        self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(
                    img_bytes, flag='grayscale',
                    float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(
                    dx, dy, max_val=20,
                    denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)
            # read next flows
            for i in range(1, self.num_half_frames + 1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_n{i}'
                else:
                    flow_path = (
                        self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(
                    img_bytes, flag='grayscale',
                    float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(
                    dx, dy, max_val=20,
                    denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = (img_lqs[:self.num_frame],
                                  img_lqs[self.num_frame:])

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_flip'],
                                             self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_flip'],
                                  self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames,
                             torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
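
# Worked example of the temporal window above (illustrative numbers): with
# num_frame = 5 (so num_half_frames = 2), center frame 34 and interval 2:
neighbor_list = list(range(34 - 2 * 2, 34 + 2 * 2 + 1, 2))
print(neighbor_list)   # [30, 32, 34, 36, 38]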
basicsr/data/single_image_dataset.py
ADDED
@@ -0,0 +1,67 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir


class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [
                    osp.join(self.lq_folder,
                             line.split(' ')[0]) for line in fin
                ]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
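
# Minimal test-phase sketch for SingleImageDataset, assuming disk I/O
# (the folder path is a placeholder, not a path from this repo):
opt = {'io_backend': {'type': 'disk'}, 'dataroot_lq': '/path/to/test/lq'}
testset = SingleImageDataset(opt)
sample = testset[0]
print(sample['lq'].shape, sample['lq_path'])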
basicsr/data/transforms.py
ADDED
@@ -0,0 +1,480 @@
import cv2
import random
import numpy as np
from PIL import Image  # only used by the commented-out tip18 variants below


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        lq_patch_size (int): LQ patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    gt_patch_size = int(lq_patch_size * scale)

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
            f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def paired_center_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path):
    """Paired center crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        lq_patch_size (int): LQ patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    gt_patch_size = int(lq_patch_size * scale)

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
            f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # take the center top and left coordinates for the lq patch
    # (instead of random.randint as in paired_random_crop)
    top = (h_lq - lq_patch_size) // 2
    left = (w_lq - lq_patch_size) // 2

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def paired_random_crop_DP(img_lqLs, img_lqRs, img_gts, gt_patch_size, scale,
                          gt_path):
    """Paired random crop for dual-pixel triplets (lqL, lqR, gt)."""
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqLs, list):
        img_lqLs = [img_lqLs]
    if not isinstance(img_lqRs, list):
        img_lqRs = [img_lqRs]

    h_lq, w_lq, _ = img_lqLs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
            f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqLs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqLs
    ]

    img_lqRs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqRs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqLs) == 1:
        img_lqLs = img_lqLs[0]
    if len(img_lqRs) == 1:
        img_lqRs = img_lqRs[0]
    return img_lqLs, img_lqRs, img_gts


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img


def data_augmentation(image, mode):
    """
    Performs data augmentation of the input image.
    Input:
        image: a cv2 (OpenCV) image
        mode: int. Choice of transformation to apply to the image
            0 - no transformation
            1 - flip up and down
            2 - rotate counterclockwise 90 degrees
            3 - rotate 90 degrees and flip up and down
            4 - rotate 180 degrees
            5 - rotate 180 degrees and flip
            6 - rotate 270 degrees
            7 - rotate 270 degrees and flip
    """
    if mode == 0:
        # original
        out = image
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(image)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(image)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(image, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(image, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')

    return out


def random_augmentation(*args):
    out = []
    flag_aug = random.randint(0, 7)
    for data in args:
        out.append(data_augmentation(data, flag_aug).copy())
    return out
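
# Sanity-check sketch for paired_random_crop at scale 2 (dummy arrays, not
# repo data): a 64-pixel LQ patch must pair with a 128-pixel GT patch.
import numpy as np

img_lq = np.zeros((100, 120, 3), dtype=np.float32)
img_gt = np.zeros((200, 240, 3), dtype=np.float32)   # exactly 2x the LQ size
gt, lq = paired_random_crop(img_gt, img_lq, 64, 2, 'dummy_gt_path')
print(lq.shape, gt.shape)   # (64, 64, 3) (128, 128, 3)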

# def paired_random_crop_tip18(img_gts, img_lqs, lq_patch_size, scale, gt_path):
#     """Paired random crop.

#     It crops lists of lq and gt images with corresponding locations.

#     Args:
#         img_gts (list[ndarray] | ndarray): GT images. Note that all images
#             should have the same shape. If the input is an ndarray, it will
#             be transformed to a list containing itself.
#         img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
#             should have the same shape. If the input is an ndarray, it will
#             be transformed to a list containing itself.
#         lq_patch_size (int): LQ patch size.
#         scale (int): Scale factor.
#         gt_path (str): Path to ground-truth.

#     Returns:
#         list[ndarray] | ndarray: GT images and LQ images. If returned results
#             only have one element, just return ndarray.
#     """

#     if not isinstance(img_gts, list):
#         img_gts = [img_gts]
#     if not isinstance(img_lqs, list):
#         img_lqs = [img_lqs]

#     h_lq, w_lq, _ = img_lqs[0].shape
#     h_gt, w_gt, _ = img_gts[0].shape
#     gt_patch_size = int(lq_patch_size * scale)

#     if h_gt != h_lq * scale or w_gt != w_lq * scale:
#         raise ValueError(
#             f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
#             f'multiplication of LQ ({h_lq}, {w_lq}).')
#     if h_lq < lq_patch_size or w_lq < lq_patch_size:
#         raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
#                          f'({lq_patch_size}, {lq_patch_size}). '
#                          f'Please remove {gt_path}.')

#     # pre process
#     # w, h = img.size
#     # region = img.crop((1 + int(0.15 * w), 1 + int(0.15 * h), int(0.85 * w), int(0.85 * h)))
#     # region = region.resize((286, 286), Image.BILINEAR)
#     # crop lq patch
#     w = w_lq, h = h_lq
#     img_lqs = [
#         # v[(1 + int(0.15 * h)):int(0.85 * h), (1 + int(0.15 * w)):int(0.85 * w), ...]
#         for v in img_lqs:
#         # v[(1 + int(0.15 * h)):int(0.85 * h), (1 + int(0.15 * w)):int(0.85 * w), ...]
#         img = Image.fromarray(v[(1 + int(0.15 * h)):int(0.85 * h), (1 + int(0.15 * w)):int(0.85 * w), ...])
#         img = img.resize((286, 286), Image.BILINEAR)

#     ]
#     img_gts = [
#         v[(1 + int(0.15 * h)):int(0.85 * h), (1 + int(0.15 * w)):int(0.85 * w), ...]
#         for v in img_gts
#     ]

#     # randomly choose top and left coordinates for lq patch
#     top = random.randint(0, h_lq - lq_patch_size)
#     left = random.randint(0, w_lq - lq_patch_size)

#     # crop lq patch
#     img_lqs = [
#         v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
#         for v in img_lqs
#     ]

#     # crop corresponding gt patch
#     top_gt, left_gt = int(top * scale), int(left * scale)
#     img_gts = [
#         v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
#         for v in img_gts
#     ]
#     if len(img_gts) == 1:
#         img_gts = img_gts[0]
#     if len(img_lqs) == 1:
#         img_lqs = img_lqs[0]
#     return img_gts, img_lqs
|
| 421 |
+
# def paired_center_crop_tip18(img_gts, img_lqs, lq_patch_size, scale, gt_path):
|
| 422 |
+
# """Paired random crop.
|
| 423 |
+
|
| 424 |
+
# It crops lists of lq and gt images with corresponding locations.
|
| 425 |
+
|
| 426 |
+
# Args:
|
| 427 |
+
# img_gts (list[ndarray] | ndarray): GT images. Note that all images
|
| 428 |
+
# should have the same shape. If the input is an ndarray, it will
|
| 429 |
+
# be transformed to a list containing itself.
|
| 430 |
+
# img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
|
| 431 |
+
# should have the same shape. If the input is an ndarray, it will
|
| 432 |
+
# be transformed to a list containing itself.
|
| 433 |
+
# lq_patch_size (int): LQ patch size.
|
| 434 |
+
# scale (int): Scale factor.
|
| 435 |
+
# gt_path (str): Path to ground-truth.
|
| 436 |
+
|
| 437 |
+
# Returns:
|
| 438 |
+
# list[ndarray] | ndarray: GT images and LQ images. If returned results
|
| 439 |
+
# only have one element, just return ndarray.
|
| 440 |
+
# """
|
| 441 |
+
|
| 442 |
+
# if not isinstance(img_gts, list):
|
| 443 |
+
# img_gts = [img_gts]
|
| 444 |
+
# if not isinstance(img_lqs, list):
|
| 445 |
+
# img_lqs = [img_lqs]
|
| 446 |
+
|
| 447 |
+
# h_lq, w_lq, _ = img_lqs[0].shape
|
| 448 |
+
# h_gt, w_gt, _ = img_gts[0].shape
|
| 449 |
+
# gt_patch_size = int(lq_patch_size * scale)
|
| 450 |
+
|
| 451 |
+
# if h_gt != h_lq * scale or w_gt != w_lq * scale:
|
| 452 |
+
# raise ValueError(
|
| 453 |
+
# f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
|
| 454 |
+
# f'multiplication of LQ ({h_lq}, {w_lq}).')
|
| 455 |
+
# if h_lq < lq_patch_size or w_lq < lq_patch_size:
|
| 456 |
+
# raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
|
| 457 |
+
# f'({lq_patch_size}, {lq_patch_size}). '
|
| 458 |
+
# f'Please remove {gt_path}.')
|
| 459 |
+
|
| 460 |
+
# # randomly choose top and left coordinates for lq patch
|
| 461 |
+
# top = (h_lq - lq_patch_size)//2#random.randint(0, h_lq - lq_patch_size)
|
| 462 |
+
# left = (w_lq - lq_patch_size)//2#random.randint(0, w_lq - lq_patch_size)
|
| 463 |
+
|
| 464 |
+
# # crop lq patch
|
| 465 |
+
# img_lqs = [
|
| 466 |
+
# v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
|
| 467 |
+
# for v in img_lqs
|
| 468 |
+
# ]
|
| 469 |
+
|
| 470 |
+
# # crop corresponding gt patch
|
| 471 |
+
# top_gt, left_gt = int(top * scale), int(left * scale)
|
| 472 |
+
# img_gts = [
|
| 473 |
+
# v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
|
| 474 |
+
# for v in img_gts
|
| 475 |
+
# ]
|
| 476 |
+
# if len(img_gts) == 1:
|
| 477 |
+
# img_gts = img_gts[0]
|
| 478 |
+
# if len(img_lqs) == 1:
|
| 479 |
+
# img_lqs = img_lqs[0]
|
| 480 |
+
# return img_gts, img_lqs
|
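
For reference, the datasets in this commit consume the crop helper that remains active in this file like so — a minimal sketch on dummy arrays, assuming the standard basicsr signature in which the third positional argument is the GT patch size:

import numpy as np
from basicsr.data.transforms import paired_random_crop

scale = 4
img_lq = np.random.rand(64, 64, 3).astype(np.float32)                  # LQ frame (h, w, c)
img_gt = np.random.rand(64 * scale, 64 * scale, 3).astype(np.float32)  # matching GT frame

# A 128px GT patch corresponds to a 128 // scale = 32px LQ patch cropped at
# the scaled-down location ('dummy.png' only labels error messages).
img_gt_patch, img_lq_patch = paired_random_crop(img_gt, img_lq, 128, scale,
                                                'dummy.png')
assert img_gt_patch.shape == (128, 128, 3)
assert img_lq_patch.shape == (32, 32, 3)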
basicsr/data/video_test_dataset.py
ADDED
@@ -0,0 +1,325 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import (duf_downsample, generate_frame_indices,
                                    read_img_seq)
from basicsr.utils import get_root_logger, scandir


class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing datasets with the following structure:

    dataroot
    ├── subfolder1
        ├── frame000
        ├── frame001
        ├── ...
    ├── subfolder2
        ├── frame000
        ├── frame001
        ├── ...
    ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. If not provided, all the folders in the dataroot will
                be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {
            'lq_path': [],
            'gt_path': [],
            'folder': [],
            'idx': [],
            'border': []
        }
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt[
            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [
                    osp.join(self.lq_root, key) for key in subfolders
                ]
                subfolders_gt = [
                    osp.join(self.gt_root, key) for key in subfolders
                ]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq,
                                                  subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(
                    list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(
                    list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (
                    f'Different number of images in lq ({max_idx})'
                    f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(
                        f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            raise ValueError(
                f'Non-supported video test dataset: {type(opt["name"])}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(
            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(
                0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90K-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test
                folders. If not provided, all the folders in the dataroot will
                be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError(
                'cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {
            'lq_path': [],
            'gt_path': [],
            'folder': [],
            'idx': [],
            'border': []
        }
        neighbor_list = [
            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
        ]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt[
            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [
                osp.join(self.lq_root, subfolder, f'im{i}.png')
                for i in neighbor_list
            ]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for train dataset.
            Most of the keys are the same as VideoTestDataset.
            It has the following extra keys:

            use_duf_downsampling (bool): Whether to use duf downsampling to
                generate low-resolution frames.
            scale (bool): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(
            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(
                    0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(
                    imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(
                    0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames
                imgs_lq = read_img_seq(
                    img_paths_lq,
                    require_mod_crop=True,
                    scale=self.opt['scale'])
                imgs_lq = duf_downsample(
                    imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]],
                                  require_mod_crop=True,
                                  scale=self.opt['scale'])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }


class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and outputs corresponding HR video frames.

    Args:
        Same as VideoTestDataset.
        Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Find unique folder strings
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
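
A minimal sketch of the option dict these test datasets expect (paths are placeholders; the padding value is one of the modes accepted by `generate_frame_indices`):

from basicsr.data.video_test_dataset import VideoTestDataset

opt = {
    'name': 'Vid4',                       # must map to vid4/reds4/redsofficial
    'dataroot_gt': 'datasets/Vid4/GT',    # placeholder path
    'dataroot_lq': 'datasets/Vid4/BIx4',  # placeholder path
    'io_backend': {'type': 'disk'},       # lmdb is rejected at test time
    'cache_data': False,
    'num_frame': 7,
    'padding': 'reflection_circle',
}
dataset = VideoTestDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape)  # (t, c, h, w) and (c, h, w)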
basicsr/data/vimeo90k_dataset.py
ADDED
@@ -0,0 +1,130 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor


class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
        num_frame | frame list
        1         | 4
        3         | 3,4,5
        5         | 2,3,4,5,6
        7         | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [
            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
        ]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_flip'],
                              self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
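
The `neighbor_list` expression centers the temporal window on frame `im4` of each 7-frame clip; evaluating it for the window sizes in the docstring reproduces the table above:

for num_frame in (1, 3, 5, 7):
    print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]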
basicsr/metrics/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim

__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
basicsr/metrics/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (311 Bytes). View file

basicsr/metrics/__pycache__/metric_util.cpython-37.pyc
ADDED
Binary file (1.5 kB). View file

basicsr/metrics/__pycache__/niqe.cpython-37.pyc
ADDED
Binary file (6.46 kB). View file

basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc
ADDED
Binary file (7.67 kB). View file
basicsr/metrics/fid.py
ADDED
@@ -0,0 +1,102 @@
import numpy as np
import torch
import torch.nn as nn
from scipy import linalg
from tqdm import tqdm

from basicsr.models.archs.inception import InceptionV3


def load_patched_inception_v3(device='cuda',
                              resize_input=True,
                              normalize_input=False):
    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
    # does resize the input.
    inception = InceptionV3([3],
                            resize_input=resize_input,
                            normalize_input=normalize_input)
    inception = nn.DataParallel(inception).eval().to(device)
    return inception


@torch.no_grad()
def extract_inception_features(data_generator,
                               inception,
                               len_generator=None,
                               device='cuda'):
    """Extract inception features.

    Args:
        data_generator (generator): A data generator.
        inception (nn.Module): Inception model.
        len_generator (int): Length of the data_generator to show the
            progressbar. Default: None.
        device (str): Device. Default: cuda.

    Returns:
        Tensor: Extracted features.
    """
    if len_generator is not None:
        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
    else:
        pbar = None
    features = []

    for data in data_generator:
        if pbar:
            pbar.update(1)
        data = data.to(device)
        feature = inception(data)[0].view(data.shape[0], -1)
        features.append(feature.to('cpu'))
    if pbar:
        pbar.close()
    features = torch.cat(features, 0)
    return features


def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on a representative data set.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
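
How these pieces compose into an FID computation — a sketch using a toy random-batch generator (in real use the generators iterate a dataloader, and you need far more samples than this for the 2048-dim covariance estimates to be well conditioned; with the toy sizes below the near-singular branch of calculate_fid may kick in):

import numpy as np
import torch
from basicsr.metrics.fid import (calculate_fid, extract_inception_features,
                                 load_patched_inception_v3)

def toy_batches(n_batches=4, batch_size=8):  # stand-in for a real dataloader
    for _ in range(n_batches):
        yield torch.rand(batch_size, 3, 299, 299)

inception = load_patched_inception_v3(device='cuda')
feats_a = extract_inception_features(toy_batches(), inception,
                                     len_generator=4).numpy()
feats_b = extract_inception_features(toy_batches(), inception,
                                     len_generator=4).numpy()

mu_a, sigma_a = feats_a.mean(axis=0), np.cov(feats_a, rowvar=False)
mu_b, sigma_b = feats_b.mean(axis=0), np.cov(feats_b, rowvar=False)
print('FID:', calculate_fid(mu_a, sigma_a, mu_b, sigma_b))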
basicsr/metrics/metric_util.py
ADDED
@@ -0,0 +1,47 @@
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: Reordered image.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.
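
These two helpers normalize inputs for the metrics below; a small sketch of their effect on shapes and channels (dummy data):

import numpy as np
from basicsr.metrics.metric_util import reorder_image, to_y_channel

chw = np.random.randint(0, 256, (3, 32, 32)).astype(np.float32)
hwc = reorder_image(chw, input_order='CHW')
print(hwc.shape)   # (32, 32, 3)

y = to_y_channel(hwc)  # BGR -> Y of YCbCr, values kept in the [0, 255] range
print(y.shape)     # (32, 32, 1)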
basicsr/metrics/niqe.py
ADDED
@@ -0,0 +1,205 @@
import cv2
import math
import numpy as np
from scipy.ndimage.filters import convolve
from scipy.special import gamma

from basicsr.metrics.metric_util import reorder_image, to_y_channel


def estimate_aggd_param(block):
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
            distribution (estimating the params in Equation 7 in the paper).
    """
    block = block.flatten()
    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
    gam_reciprocal = np.reciprocal(gam)
    r_gam = np.square(gamma(gam_reciprocal * 2)) / (
        gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

    left_std = np.sqrt(np.mean(block[block < 0]**2))
    right_std = np.sqrt(np.mean(block[block > 0]**2))
    gammahat = left_std / right_std
    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
    rhatnorm = (rhat * (gammahat**3 + 1) *
                (gammahat + 1)) / ((gammahat**2 + 1)**2)
    array_position = np.argmin((r_gam - rhatnorm)**2)

    alpha = gam[array_position]
    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    return (alpha, beta_l, beta_r)


def compute_feature(block):
    """Compute features.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        list: Features with length of 18.
    """
    feat = []
    alpha, beta_l, beta_r = estimate_aggd_param(block)
    feat.extend([alpha, (beta_l + beta_r) / 2])

    # Distortions disturb the fairly regular structure of natural images.
    # This deviation can be captured by analyzing the sample distribution of
    # the products of pairs of adjacent coefficients computed along
    # horizontal, vertical and diagonal orientations.
    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
    for i in range(len(shifts)):
        shifted_block = np.roll(block, shifts[i], axis=(0, 1))
        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
        # Eq. 8
        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
        feat.extend([alpha, mean, beta_l, beta_r])
    return feat
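

# A quick sanity check of the fit above (an illustrative sketch, not part of
# the metric): zero-mean Gaussian samples are a special case of the AGGD with
# shape parameter 2 and symmetric scales, so the estimate should land near
# alpha = 2 with beta_l close to beta_r.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    gaussian_block = rng.normal(0.0, 1.0, size=(96, 96))
    alpha, beta_l, beta_r = estimate_aggd_param(gaussian_block)
    print(f'alpha={alpha:.3f} beta_l={beta_l:.3f} beta_r={beta_r:.3f}')
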
+
|
| 66 |
+
|
| 67 |
+
def niqe(img,
|
| 68 |
+
mu_pris_param,
|
| 69 |
+
cov_pris_param,
|
| 70 |
+
gaussian_window,
|
| 71 |
+
block_size_h=96,
|
| 72 |
+
block_size_w=96):
|
| 73 |
+
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
|
| 74 |
+
|
| 75 |
+
Ref: Making a "Completely Blind" Image Quality Analyzer.
|
| 76 |
+
This implementation could produce almost the same results as the official
|
| 77 |
+
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
|
| 78 |
+
|
| 79 |
+
Note that we do not include block overlap height and width, since they are
|
| 80 |
+
always 0 in the official implementation.
|
| 81 |
+
|
| 82 |
+
For good performance, it is advisable by the official implemtation to
|
| 83 |
+
divide the distorted image in to the same size patched as used for the
|
| 84 |
+
construction of multivariate Gaussian model.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
img (ndarray): Input image whose quality needs to be computed. The
|
| 88 |
+
image must be a gray or Y (of YCbCr) image with shape (h, w).
|
| 89 |
+
Range [0, 255] with float type.
|
| 90 |
+
mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
|
| 91 |
+
model calculated on the pristine dataset.
|
| 92 |
+
cov_pris_param (ndarray): Covariance of a pre-defined multivariate
|
| 93 |
+
Gaussian model calculated on the pristine dataset.
|
| 94 |
+
gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
|
| 95 |
+
image.
|
| 96 |
+
block_size_h (int): Height of the blocks in to which image is divided.
|
| 97 |
+
Default: 96 (the official recommended value).
|
| 98 |
+
block_size_w (int): Width of the blocks in to which image is divided.
|
| 99 |
+
Default: 96 (the official recommended value).
|
| 100 |
+
"""
|
| 101 |
+
assert img.ndim == 2, (
|
| 102 |
+
'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
|
| 103 |
+
# crop image
|
| 104 |
+
h, w = img.shape
|
| 105 |
+
num_block_h = math.floor(h / block_size_h)
|
| 106 |
+
num_block_w = math.floor(w / block_size_w)
|
| 107 |
+
img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
|
| 108 |
+
|
| 109 |
+
distparam = [] # dist param is actually the multiscale features
|
| 110 |
+
for scale in (1, 2): # perform on two scales (1, 2)
|
| 111 |
+
mu = convolve(img, gaussian_window, mode='nearest')
|
| 112 |
+
sigma = np.sqrt(
|
| 113 |
+
np.abs(
|
| 114 |
+
convolve(np.square(img), gaussian_window, mode='nearest') -
|
| 115 |
+
np.square(mu)))
|
| 116 |
+
# normalize, as in Eq. 1 in the paper
|
| 117 |
+
img_nomalized = (img - mu) / (sigma + 1)
|
| 118 |
+
|
| 119 |
+
feat = []
|
| 120 |
+
for idx_w in range(num_block_w):
|
| 121 |
+
for idx_h in range(num_block_h):
|
| 122 |
+
# process ecah block
|
| 123 |
+
block = img_nomalized[idx_h * block_size_h //
|
| 124 |
+
scale:(idx_h + 1) * block_size_h //
|
| 125 |
+
scale, idx_w * block_size_w //
|
| 126 |
+
scale:(idx_w + 1) * block_size_w //
|
| 127 |
+
scale]
|
| 128 |
+
feat.append(compute_feature(block))
|
| 129 |
+
|
| 130 |
+
distparam.append(np.array(feat))
|
| 131 |
+
# TODO: matlab bicubic downsample with anti-aliasing
|
| 132 |
+
# for simplicity, now we use opencv instead, which will result in
|
| 133 |
+
# a slight difference.
|
| 134 |
+
if scale == 1:
|
| 135 |
+
h, w = img.shape
|
| 136 |
+
img = cv2.resize(
|
| 137 |
+
img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
|
| 138 |
+
img = img * 255.
|
| 139 |
+
|
| 140 |
+
distparam = np.concatenate(distparam, axis=1)
|
| 141 |
+
|
| 142 |
+
# fit a MVG (multivariate Gaussian) model to distorted patch features
|
| 143 |
+
mu_distparam = np.nanmean(distparam, axis=0)
|
| 144 |
+
# use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
|
| 145 |
+
distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
|
| 146 |
+
cov_distparam = np.cov(distparam_no_nan, rowvar=False)
|
| 147 |
+
|
| 148 |
+
# compute niqe quality, Eq. 10 in the paper
|
| 149 |
+
invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
|
| 150 |
+
quality = np.matmul(
|
| 151 |
+
np.matmul((mu_pris_param - mu_distparam), invcov_param),
|
| 152 |
+
np.transpose((mu_pris_param - mu_distparam)))
|
| 153 |
+
quality = np.sqrt(quality)
|
| 154 |
+
|
| 155 |
+
return quality
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
|
| 159 |
+
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
|
| 160 |
+
|
| 161 |
+
Ref: Making a "Completely Blind" Image Quality Analyzer.
|
| 162 |
+
This implementation could produce almost the same results as the official
|
| 163 |
+
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
|
| 164 |
+
|
| 165 |
+
We use the official params estimated from the pristine dataset.
|
| 166 |
+
We use the recommended block size (96, 96) without overlaps.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
img (ndarray): Input image whose quality needs to be computed.
|
| 170 |
+
The input image must be in range [0, 255] with float/int type.
|
| 171 |
+
The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
|
| 172 |
+
If the input order is 'HWC' or 'CHW', it will be converted to gray
|
| 173 |
+
or Y (of YCbCr) image according to the ``convert_to`` argument.
|
| 174 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
| 175 |
+
pixels are not involved in the metric calculation.
|
| 176 |
+
input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
|
| 177 |
+
Default: 'HWC'.
|
| 178 |
+
convert_to (str): Whether coverted to 'y' (of MATLAB YCbCr) or 'gray'.
|
| 179 |
+
Default: 'y'.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
float: NIQE result.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
# we use the official params estimated from the pristine dataset.
|
| 186 |
+
niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
|
| 187 |
+
mu_pris_param = niqe_pris_params['mu_pris_param']
|
| 188 |
+
cov_pris_param = niqe_pris_params['cov_pris_param']
|
| 189 |
+
gaussian_window = niqe_pris_params['gaussian_window']
|
| 190 |
+
|
| 191 |
+
img = img.astype(np.float32)
|
| 192 |
+
if input_order != 'HW':
|
| 193 |
+
img = reorder_image(img, input_order=input_order)
|
| 194 |
+
if convert_to == 'y':
|
| 195 |
+
img = to_y_channel(img)
|
| 196 |
+
elif convert_to == 'gray':
|
| 197 |
+
img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
|
| 198 |
+
img = np.squeeze(img)
|
| 199 |
+
|
| 200 |
+
if crop_border != 0:
|
| 201 |
+
img = img[crop_border:-crop_border, crop_border:-crop_border]
|
| 202 |
+
|
| 203 |
+
niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
|
| 204 |
+
|
| 205 |
+
return niqe_result
|
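
Typical usage of the public entry point (a sketch; the image path is a placeholder, and note that `calculate_niqe` loads `basicsr/metrics/niqe_pris_params.npz` via a relative path, so it has to be run from the repository root):

import cv2
from basicsr.metrics import calculate_niqe

img = cv2.imread('results/restored.png')  # BGR, uint8; placeholder path
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(f'NIQE: {score:.4f}')  # lower is better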
basicsr/metrics/niqe_pris_params.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
size 11850
basicsr/metrics/psnr_ssim.py
ADDED
@@ -0,0 +1,303 @@
import cv2
import numpy as np

from basicsr.metrics.metric_util import reorder_image, to_y_channel
import skimage.metrics
import torch


def calculate_psnr(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')
    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    max_value = 1. if img1.max() <= 1 else 255.
    return 20. * np.log10(max_value / np.sqrt(mse))


def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def prepare_for_ssim(img, k):
    with torch.no_grad():
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
        conv = torch.nn.Conv2d(
            1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        img = conv(img)

        img = img.squeeze(0).squeeze(0)
        img = img[0::k, 0::k]
    return img.detach().cpu().numpy()


def prepare_for_ssim_rgb(img, k):
    with torch.no_grad():
        img = torch.from_numpy(img).float()  # HxWx3

        conv = torch.nn.Conv2d(
            1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        new_img = []

        for i in range(3):
            new_img.append(
                conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(
                    0).squeeze(0)[0::k, 0::k])

    return torch.stack(new_img, dim=2).detach().cpu().numpy()


def _3d_gaussian_calculator(img, conv3d):
    out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
    return out


def _generate_3d_gaussian_kernel():
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    kernel_3 = cv2.getGaussianKernel(11, 1.5)
    kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
    conv3d = torch.nn.Conv3d(
        1, 1, (11, 11, 11),
        stride=1,
        padding=(5, 5, 5),
        bias=False,
        padding_mode='replicate')
    conv3d.weight.requires_grad = False
    conv3d.weight[0, 0, :, :, :] = kernel
    return conv3d
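

# Sanity-check sketch (illustrative, not part of the metric): the 3D window
# built above is the outer product of three normalized 1D Gaussians, so its
# weights should sum to ~1, matching cv2.getGaussianKernel's normalization.
if __name__ == '__main__':
    _conv3d = _generate_3d_gaussian_kernel()
    print(float(_conv3d.weight.sum()))  # expect ~1.0
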

def _ssim_3d(img1, img2, max_value):
    """Calculate SSIM (structural similarity) over all channels at once with a
    3D Gaussian window.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    assert len(img1.shape) == 3 and len(img2.shape) == 3

    C1 = (0.01 * max_value)**2
    C2 = (0.03 * max_value)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    kernel = _generate_3d_gaussian_kernel().cuda()

    img1 = torch.tensor(img1).float().cuda()
    img2 = torch.tensor(img2).float().cuda()

    mu1 = _3d_gaussian_calculator(img1, kernel)
    mu2 = _3d_gaussian_calculator(img2, kernel)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
    sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return float(ssim_map.mean())


def _ssim_cly(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    assert len(img1.shape) == 2 and len(img2.shape) == 2

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    bt = cv2.BORDER_REPLICATE

    mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
    mu2 = cv2.filter2D(img2, -1, window, borderType=bt)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as those of the officially released MATLAB code
    in https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)
        return _ssim_cly(img1[..., 0], img2[..., 0])

    ssims = []
    # ssims_before = []

    # skimage_before = skimage.metrics.structural_similarity(
    #     img1, img2, data_range=255., multichannel=True)
    # print('.._skimage', skimage.metrics.structural_similarity(
    #     img1, img2, data_range=255., multichannel=True))
    max_value = 1 if img1.max() <= 1 else 255
    with torch.no_grad():
        final_ssim = _ssim_3d(img1, img2, max_value)
        ssims.append(final_ssim)

    # for i in range(img1.shape[2]):
    #     ssims_before.append(_ssim(img1, img2))
    #     ssims.append(skimage.metrics.structural_similarity(
    #         img1[..., i], img2[..., i], multichannel=False))
    # print('..ssim mean, new {:.4f} and before {:.4f} .... skimage before '
    #       '{:.4f}'.format(np.array(ssims).mean(),
    #                       np.array(ssims_before).mean(), skimage_before))

    return np.array(ssims).mean()
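
End-to-end usage of the two public metrics (a sketch with placeholder paths; note that `calculate_ssim` routes through `_ssim_3d` and therefore needs a CUDA device unless `test_y_channel=True`, which takes the CPU `_ssim_cly` path):

import cv2
from basicsr.metrics import calculate_psnr, calculate_ssim

restored = cv2.imread('results/restored.png')  # placeholder paths, BGR uint8
gt = cv2.imread('datasets/gt.png')

psnr = calculate_psnr(restored, gt, crop_border=0, input_order='HWC')
ssim = calculate_ssim(restored, gt, crop_border=0, input_order='HWC',
                      test_y_channel=True)     # Y-channel path runs on CPU
print(f'PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}')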
basicsr/models/.DS_Store
ADDED
Binary file (8.2 kB). View file
basicsr/models/__init__.py
ADDED
@@ -0,0 +1,42 @@
| 1 |
+
import importlib
|
| 2 |
+
from os import path as osp
|
| 3 |
+
|
| 4 |
+
from basicsr.utils import get_root_logger, scandir
|
| 5 |
+
|
| 6 |
+
# automatically scan and import model modules
|
| 7 |
+
# scan all the files under the 'models' folder and collect files ending with
|
| 8 |
+
# '_model.py'
|
| 9 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
| 10 |
+
model_filenames = [
|
| 11 |
+
osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
|
| 12 |
+
if v.endswith('_model.py')
|
| 13 |
+
]
|
| 14 |
+
# import all the model modules
|
| 15 |
+
_model_modules = [
|
| 16 |
+
importlib.import_module(f'basicsr.models.{file_name}')
|
| 17 |
+
for file_name in model_filenames
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def create_model(opt):
|
| 22 |
+
"""Create model.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
opt (dict): Configuration. It constains:
|
| 26 |
+
model_type (str): Model type.
|
| 27 |
+
"""
|
| 28 |
+
model_type = opt['model_type']
|
| 29 |
+
|
| 30 |
+
# dynamic instantiation
|
| 31 |
+
for module in _model_modules:
|
| 32 |
+
model_cls = getattr(module, model_type, None)
|
| 33 |
+
if model_cls is not None:
|
| 34 |
+
break
|
| 35 |
+
if model_cls is None:
|
| 36 |
+
raise ValueError(f'Model {model_type} is not found.')
|
| 37 |
+
|
| 38 |
+
model = model_cls(opt)
|
| 39 |
+
|
| 40 |
+
logger = get_root_logger()
|
| 41 |
+
logger.info(f'Model [{model.__class__.__name__}] is created.')
|
| 42 |
+
return model
|
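`create_model` resolves `opt['model_type']` to a class by probing each imported `*_model.py` module with `getattr`. A self-contained sketch of that same lookup pattern, run against the standard library purely to illustrate the mechanism (not code from this repo):

import importlib

module = importlib.import_module('collections')  # stands in for a scanned *_model.py module
cls_type = 'OrderedDict'                         # stands in for opt['model_type']
cls_ = getattr(module, cls_type, None)           # None when the module lacks that class
if cls_ is None:
    raise ValueError(f'Model {cls_type} is not found.')
instance = cls_()
print(type(instance))  # <class 'collections.OrderedDict'>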
basicsr/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.24 kB)

basicsr/models/__pycache__/base_model.cpython-37.pyc
ADDED
Binary file (12.9 kB)

basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc
ADDED
Binary file (9.52 kB)

basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc
ADDED
Binary file (8.91 kB)
basicsr/models/archs/FPro_arch.py
ADDED
@@ -0,0 +1,545 @@
## Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration
import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
import numbers

from einops import rearrange

##########################################################################
## Layer Norm

def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1).contiguous()) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out


##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, isAtt):
        super(TransformerBlock, self).__init__()
        self.isAtt = isAtt
        if self.isAtt:
            self.norm1 = LayerNorm(dim, LayerNorm_type)
            self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        if self.isAtt:
            x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x

########### window operation #############
def window_partition(x, win_size, dilation_rate=1):
    B, H, W, C = x.shape
    if dilation_rate != 1:
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        assert type(dilation_rate) is int, 'dilation_rate should be a int'
        x = F.unfold(x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size)  # B, C*Wh*Ww, H/Wh*W/Ww
        windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size)  # B', C, Wh, Ww
        windows = windows.permute(0, 2, 3, 1).contiguous()  # B', Wh, Ww, C
    else:
        x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)  # B', Wh, Ww, C
    return windows

def window_reverse(windows, win_size, H, W, dilation_rate=1):
    # B', Wh, Ww, C
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    if dilation_rate != 1:
        x = windows.permute(0, 5, 3, 4, 1, 2).contiguous()  # B, C*Wh*Ww, H/Wh*W/Ww
        x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size)
    else:
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class lowFrequencyPromptFusion(nn.Module):
    def __init__(self, dim, dim_bak, num_heads, win_size=8, bias=False):
        super(lowFrequencyPromptFusion, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.ap_kv = nn.AdaptiveAvgPool2d(1)
        self.kv = nn.Conv2d(dim_bak, dim * 2, kernel_size=1, bias=bias)

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, feature, prompt_feature):
        b, c1, h, w = feature.shape
        _, c2, _, _ = prompt_feature.shape

        query = self.q(feature).reshape(b, h * w, self.num_heads, c1 // self.num_heads).permute(0, 2, 1, 3).contiguous()

        prompt_feature = self.ap_kv(prompt_feature)  # .reshape(b, c2, -1).permute(0, 2, 1)
        key_value = self.kv(prompt_feature).reshape(b, 2 * c1, -1).permute(0, 2, 1).contiguous().reshape(b, -1, 2, self.num_heads, c1 // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        key, value = key_value[0], key_value[1]

        attn = (query @ key.transpose(-2, -1).contiguous()) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ value)

        out = rearrange(out, 'b head (h w) c -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)

        return out

class LinearProjection(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True, isQuery=True):
        super().__init__()
        self.isQuery = isQuery
        inner_dim = dim_head * heads
        self.heads = heads
        if self.isQuery:
            self.to_q = nn.Linear(dim, inner_dim, bias=bias)
        else:
            self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.dim = dim
        self.inner_dim = inner_dim

    def forward(self, x, attn_kv=None):
        B_, N, C = x.shape
        if attn_kv is not None:
            attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
        else:
            attn_kv = x
        N_kv = attn_kv.size(1)
        if self.isQuery:
            q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
            q = q[0]
            return q
        else:
            C = self.inner_dim
            kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
            k, v = kv[0], kv[1]
            return k, v

class highFrequencyPromptFusion(nn.Module):
    def __init__(self, dim, dim_bak, win_size, num_heads, qkv_bias=True, qk_scale=None, bias=False):
        super(highFrequencyPromptFusion, self).__init__()
        self.num_heads = num_heads
        self.win_size = win_size  # Wh, Ww
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.to_q = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias, isQuery=True)
        self.to_kv = LinearProjection(dim_bak, num_heads, dim // num_heads, bias=qkv_bias, isQuery=False)

        self.kv_dwconv = nn.Conv2d(dim_bak, dim_bak, kernel_size=3, stride=1, padding=1, groups=dim_bak, bias=bias)

        self.softmax = nn.Softmax(dim=-1)

        self.project_out = nn.Linear(dim, dim)

    def forward(self, query_feature, key_value_feature):

        b, c, h, w = query_feature.shape
        _, c_2, _, _ = key_value_feature.shape

        key_value_feature = self.kv_dwconv(key_value_feature)

        # partition windows
        query_feature = rearrange(query_feature, 'b c1 h w -> b h w c1', h=h, w=w)
        query_feature_windows = window_partition(query_feature, self.win_size)  # nW*B, win_size, win_size, C
        query_feature_windows = query_feature_windows.view(-1, self.win_size * self.win_size, c)  # nW*B, win_size*win_size, C

        key_value_feature = rearrange(key_value_feature, 'b c2 h w -> b h w c2', h=h, w=w)
        key_value_feature_windows = window_partition(key_value_feature, self.win_size)  # nW*B, win_size, win_size, C
        key_value_feature_windows = key_value_feature_windows.view(-1, self.win_size * self.win_size, c_2)  # nW*B, win_size*win_size, C

        B_, N, C = query_feature_windows.shape

        query = self.to_q(query_feature_windows)
        query = query * self.scale

        key, value = self.to_kv(key_value_feature_windows)
        attn = (query @ key.transpose(-2, -1).contiguous())
        attn = attn.softmax(dim=-1)

        out = (attn @ value).transpose(1, 2).contiguous().reshape(B_, N, C)

        out = self.project_out(out)

        # merge windows
        attn_windows = out.view(-1, self.win_size, self.win_size, C)
        attn_windows = window_reverse(attn_windows, self.win_size, h, w)  # B H' W' C
        return rearrange(attn_windows, 'b h w c -> b c h w', h=h, w=w)

##########################################################################
## channel dynamic filters
class dynamic_filter_channel(nn.Module):
    def __init__(self, inchannels, kernel_size=3, stride=1, group=8):
        super(dynamic_filter_channel, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.group = group

        self.conv = nn.Conv2d(inchannels, group * kernel_size**2, kernel_size=1, stride=1, bias=False)
        self.conv_gate = nn.Conv2d(group * kernel_size**2, group * kernel_size**2, kernel_size=1, stride=1, bias=False)
        self.act_gate = nn.Sigmoid()
        self.bn = nn.BatchNorm2d(group * kernel_size**2)
        self.act = nn.Softmax(dim=-2)
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

        self.pad = nn.ReflectionPad2d(kernel_size // 2)

        self.ap_1 = nn.AdaptiveAvgPool2d((1, 1))
        # self.ap_2 = nn.AdaptiveMaxPool2d((1, 1))

    def forward(self, x):
        identity_input = x
        low_filter1 = self.ap_1(x)
        # low_filter2 = self.ap_2(x)
        low_filter = self.conv(low_filter1)
        low_filter = low_filter * self.act_gate(self.conv_gate(low_filter))
        low_filter = self.bn(low_filter)

        n, c, h, w = x.shape
        x = F.unfold(self.pad(x), kernel_size=self.kernel_size).reshape(n, self.group, c // self.group, self.kernel_size**2, h * w)

        n, c1, p, q = low_filter.shape
        low_filter = low_filter.reshape(n, c1 // self.kernel_size**2, self.kernel_size**2, p * q).unsqueeze(2)

        low_filter = self.act(low_filter)

        low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)

        out_high = identity_input - low_part
        return low_part, out_high


class frequenctSpecificPromptGenetator(nn.Module):
    def __init__(self, dim=3, h=128, w=65, flag_highF=True):
        super().__init__()
        self.flag_highF = flag_highF
        k_size = 3
        if flag_highF:
            w = (w - 1) * 2
            self.w = w
            self.h = h
            self.weight = nn.Parameter(torch.randn(1, dim, h, w, dtype=torch.float32) * 0.02)
            self.body = nn.Sequential(nn.Conv2d(dim, dim, (1, k_size), padding=(0, k_size // 2), groups=dim),
                                      nn.Conv2d(dim, dim, (k_size, 1), padding=(k_size // 2, 0), groups=dim),
                                      nn.GELU())
        else:
            self.complex_weight = nn.Parameter(torch.randn(1, dim, h, w, 2, dtype=torch.float32) * 0.02)
            self.body = nn.Sequential(nn.Conv2d(2 * dim, 2 * dim, kernel_size=1, stride=1),
                                      nn.GELU(),
                                      )

    def forward(self, ffm, H, W):
        if self.flag_highF:
            ffm = F.interpolate(ffm, size=(H, W), mode='bilinear')
            y_att = self.body(ffm)

            y_f = y_att * ffm
            y = y_f * self.weight

        else:
            ffm = F.interpolate(ffm, size=(H, W), mode='bicubic')
            y = torch.fft.rfft2(ffm.to(torch.float32).cuda())
            y_imag = y.imag
            y_real = y.real
            y_f = torch.cat([y_real, y_imag], dim=1)
            weight = torch.complex(self.complex_weight[..., 0], self.complex_weight[..., 1])
            y_att = self.body(y_f)
            y_f = y_f * y_att
            y_real, y_imag = torch.chunk(y_f, 2, dim=1)
            y = torch.complex(y_real, y_imag)
            y = y * weight
            y = torch.fft.irfft2(y, s=(H, W))

        return y

##########################################################################
## PromptModule
class PromptModule(nn.Module):
    def __init__(self, basic_dim=32, dim=32, input_resolution=128):
        super().__init__()
        h = input_resolution
        w = input_resolution // 2 + 1
        self.simple_Fusion = nn.Conv2d(2 * dim, dim, kernel_size=1, stride=1)

        self.FSPG_high = frequenctSpecificPromptGenetator(basic_dim, h, w, flag_highF=True)
        self.FSPG_low = frequenctSpecificPromptGenetator(basic_dim, h, w, flag_highF=False)

        self.modulator_hi = highFrequencyPromptFusion(dim, basic_dim, win_size=8, num_heads=2, bias=False)
        self.modulator_lo = lowFrequencyPromptFusion(dim, basic_dim, win_size=8, num_heads=2, bias=False)

    def forward(self, low_part, out_high, x):
        b, c, h, w = x.shape

        y_h = self.FSPG_high(out_high, h, w)
        y_l = self.FSPG_low(low_part, h, w)

        y_h = self.modulator_hi(x, y_h)
        y_l = self.modulator_lo(x, y_l)

        x = self.simple_Fusion(torch.cat([y_h, y_l], dim=1))

        return x

## Frequency split module
class splitFrequencyModule(nn.Module):
    def __init__(self, basic_dim=32, dim=32, input_resolution=128):
        super().__init__()

        self.dyna_channel = dynamic_filter_channel(inchannels=basic_dim)

    def forward(self, F_low):
        _, c_basic, h_ori, w_ori = F_low.shape

        low_part, out_high = self.dyna_channel(F_low)

        return low_part, out_high


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

##########################################################################
##---------- FPro -----------------------
class FPro(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[4, 6, 6, 8],
                 num_refinement_blocks=4,
                 heads=[1, 2, 4, 8],
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias',  ## Other option 'BiasFree'
                 dual_pixel_task=False  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
                 ):

        super(FPro, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim*2**1))  ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[2])])

        self.splitFre = splitFrequencyModule(basic_dim=dim, dim=int(dim*2**2), input_resolution=32)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[2])])
        self.prompt_d3 = PromptModule(basic_dim=dim, dim=int(dim*2**2), input_resolution=64)

        self.up3_2 = Upsample(int(dim*2**2))  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[1])])
        self.prompt_d2 = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=128)

        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[0])])
        self.prompt_d1 = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=256)

        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_refinement_blocks)])
        self.prompt_r = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=256)

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        out_dec_level3 = self.decoder_level3(out_enc_level3)
        low_part, out_high = self.splitFre(inp_enc_level1)
        out_dec_level3 = self.prompt_d3(low_part, out_high, out_dec_level3) + out_dec_level3

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        out_dec_level2 = self.prompt_d2(low_part, out_high, out_dec_level2) + out_dec_level2

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        out_dec_level1 = self.prompt_d1(low_part, out_high, out_dec_level1) + out_dec_level1

        out_dec_level1 = self.refinement(out_dec_level1)
        out_dec_level1 = self.prompt_r(low_part, out_high, out_dec_level1) + out_dec_level1

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1
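A minimal forward-pass sketch for the architecture above. The constructor defaults come from the `FPro` signature; the 256x256 input matches the `input_resolution` values hard-wired into the `PromptModule` instances, and a CUDA device is assumed because `frequenctSpecificPromptGenetator` calls `.cuda()` inside its low-frequency branch:

import torch
# from basicsr.models.archs.FPro_arch import FPro  # module path as laid out in this commit

model = FPro().cuda().eval()
x = torch.randn(1, 3, 256, 256).cuda()  # B, C, H, W; the prompt modules expect 256 at level 1
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]); the output adds a global residual to inp_img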
basicsr/models/archs/__init__.py
ADDED
@@ -0,0 +1,46 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
    if v.endswith('_arch.py')
]
# import all the arch modules
_arch_modules = [
    importlib.import_module(f'basicsr.models.archs.{file_name}')
    for file_name in arch_filenames
]


def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.
    """
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)


def define_network(opt):
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net
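A hedged example of feeding `define_network`; the keys mirror the `FPro.__init__` signature in `FPro_arch.py`, but this particular dict is illustrative rather than a config shipped in this commit. Note that `define_network` pops `'type'` from the dict in place:

# from basicsr.models.archs import define_network

opt = {
    'type': 'FPro',    # class name discovered by the '_arch.py' scan above
    'inp_channels': 3,
    'out_channels': 3,
    'dim': 48,
}
net = define_network(opt)  # the remaining keys are passed through as FPro(**opt)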
basicsr/models/archs/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.43 kB)

basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc
ADDED
Binary file (7.17 kB)

basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc
ADDED
Binary file (6.01 kB)

basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc
ADDED
Binary file (6.42 kB)
basicsr/models/archs/arch_util.py
ADDED
@@ -0,0 +1,255 @@
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.utils import get_root_logger

# try:
#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
#                                         modulated_deform_conv)
# except ImportError:
#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
#     #       'Otherwise install BasicSR with compiling dcn.')
#

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


# class DCNv2Pack(ModulatedDeformConvPack):
#     """Modulated deformable conv for deformable alignment.
#
#     Different from the official DCNv2Pack, which generates offsets and masks
#     from the preceding features, this DCNv2Pack takes another different
#     features to generate offsets and masks.
#
#     Ref:
#         Delving Deep into Deformable Alignment in Video Super-Resolution.
#     """
#
#     def forward(self, x, feat):
#         out = self.conv_offset(feat)
#         o1, o2, mask = torch.chunk(out, 3, dim=1)
#         offset = torch.cat((o1, o2), dim=1)
#         mask = torch.sigmoid(mask)
#
#         offset_absmean = torch.mean(torch.abs(offset))
#         if offset_absmean > 50:
#             logger = get_root_logger()
#             logger.warning(
#                 f'Offset abs mean is {offset_absmean}, larger than 50.')
#
#         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
#                                      self.stride, self.padding, self.dilation,
#                                      self.groups, self.deformable_groups)
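A quick sanity check of `pixel_unshuffle`, which trades spatial resolution for channels. The shapes are arbitrary, and the cross-check against `nn.PixelUnshuffle` assumes a PyTorch release (1.8+) that provides it:

import torch

x = torch.randn(2, 3, 8, 8)
y = pixel_unshuffle(x, scale=2)
print(y.shape)  # torch.Size([2, 12, 4, 4]): c * scale**2 channels, h // scale, w // scale
assert torch.allclose(y, torch.nn.PixelUnshuffle(2)(x))  # same element ordering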
basicsr/models/base_model.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
| 7 |
+
|
| 8 |
+
from basicsr.models import lr_scheduler as lr_scheduler
|
| 9 |
+
from basicsr.utils.dist_util import master_only
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger('basicsr')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseModel():
|
| 15 |
+
"""Base model."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, opt):
|
| 18 |
+
self.opt = opt
|
| 19 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
| 20 |
+
self.is_train = opt['is_train']
|
| 21 |
+
self.schedulers = []
|
| 22 |
+
self.optimizers = []
|
| 23 |
+
|
| 24 |
+
def feed_data(self, data):
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
def optimize_parameters(self):
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
def get_current_visuals(self):
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
def save(self, epoch, current_iter):
|
| 34 |
+
"""Save networks and training state."""
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True):
|
| 38 |
+
"""Validation function.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
dataloader (torch.utils.data.DataLoader): Validation dataloader.
|
| 42 |
+
current_iter (int): Current iteration.
|
| 43 |
+
tb_logger (tensorboard logger): Tensorboard logger.
|
| 44 |
+
save_img (bool): Whether to save images. Default: False.
|
| 45 |
+
rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True
|
| 46 |
+
use_image (bool): Whether to use saved images to compute metrics (PSNR, SSIM), if not, then use data directly from network' output. Default: True
|
| 47 |
+
"""
|
| 48 |
+
if self.opt['dist']:
|
| 49 |
+
return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
|
| 50 |
+
else:
|
| 51 |
+
return self.nondist_validation(dataloader, current_iter, tb_logger,
|
| 52 |
+
save_img, rgb2bgr, use_image)
|
| 53 |
+
|
| 54 |
+
def model_ema(self, decay=0.999):
|
| 55 |
+
net_g = self.get_bare_model(self.net_g)
|
| 56 |
+
|
| 57 |
+
net_g_params = dict(net_g.named_parameters())
|
| 58 |
+
net_g_ema_params = dict(self.net_g_ema.named_parameters())
|
| 59 |
+
|
| 60 |
+
for k in net_g_ema_params.keys():
|
| 61 |
+
net_g_ema_params[k].data.mul_(decay).add_(
|
| 62 |
+
net_g_params[k].data, alpha=1 - decay)
|
| 63 |
+
|
| 64 |
+
def get_current_log(self):
|
| 65 |
+
return self.log_dict
|
| 66 |
+
|
| 67 |
+
def model_to_device(self, net):
|
| 68 |
+
"""Model to device. It also warps models with DistributedDataParallel
|
| 69 |
+
or DataParallel.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
net (nn.Module)
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
net = net.to(self.device)
|
| 76 |
+
if self.opt['dist']:
|
| 77 |
+
find_unused_parameters = self.opt.get('find_unused_parameters',
|
| 78 |
+
False)
|
| 79 |
+
net = DistributedDataParallel(
|
| 80 |
+
net,
|
| 81 |
+
device_ids=[torch.cuda.current_device()],
|
| 82 |
+
find_unused_parameters=find_unused_parameters)
|
| 83 |
+
elif self.opt['num_gpu'] > 1:
|
| 84 |
+
net = DataParallel(net)
|
| 85 |
+
return net
|
| 86 |
+
|
| 87 |
+
def setup_schedulers(self):
|
| 88 |
+
"""Set up schedulers."""
|
| 89 |
+
train_opt = self.opt['train']
|
| 90 |
+
scheduler_type = train_opt['scheduler'].pop('type')
|
| 91 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
| 92 |
+
for optimizer in self.optimizers:
|
| 93 |
+
self.schedulers.append(
|
| 94 |
+
lr_scheduler.MultiStepRestartLR(optimizer,
|
| 95 |
+
**train_opt['scheduler']))
|
| 96 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
| 97 |
+
for optimizer in self.optimizers:
|
| 98 |
+
self.schedulers.append(
|
| 99 |
+
lr_scheduler.CosineAnnealingRestartLR(
|
| 100 |
+
optimizer, **train_opt['scheduler']))
|
| 101 |
+
elif scheduler_type == 'CosineAnnealingWarmupRestarts':
|
| 102 |
+
for optimizer in self.optimizers:
|
| 103 |
+
self.schedulers.append(
|
| 104 |
+
lr_scheduler.CosineAnnealingWarmupRestarts(
|
| 105 |
+
optimizer, **train_opt['scheduler']))
|
| 106 |
+
elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
|
| 107 |
+
for optimizer in self.optimizers:
|
| 108 |
+
self.schedulers.append(
|
| 109 |
+
lr_scheduler.CosineAnnealingRestartCyclicLR(
|
| 110 |
+
optimizer, **train_opt['scheduler']))
|
| 111 |
+
elif scheduler_type == 'TrueCosineAnnealingLR':
|
| 112 |
+
print('..', 'cosineannealingLR')
|
| 113 |
+
for optimizer in self.optimizers:
|
| 114 |
+
self.schedulers.append(
|
| 115 |
+
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler']))
|
| 116 |
+
elif scheduler_type == 'CosineAnnealingLRWithRestart':
|
| 117 |
+
print('..', 'CosineAnnealingLR_With_Restart')
|
| 118 |
+
for optimizer in self.optimizers:
|
| 119 |
+
self.schedulers.append(
|
| 120 |
+
lr_scheduler.CosineAnnealingLRWithRestart(optimizer, **train_opt['scheduler']))
|
| 121 |
+
elif scheduler_type == 'LinearLR':
|
| 122 |
+
for optimizer in self.optimizers:
|
| 123 |
+
self.schedulers.append(
|
| 124 |
+
lr_scheduler.LinearLR(
|
| 125 |
+
optimizer, train_opt['total_iter']))
|
| 126 |
+
elif scheduler_type == 'VibrateLR':
|
| 127 |
+
for optimizer in self.optimizers:
|
| 128 |
+
self.schedulers.append(
|
| 129 |
+
lr_scheduler.VibrateLR(
|
| 130 |
+
optimizer, train_opt['total_iter']))
|
| 131 |
+
else:
|
| 132 |
+
raise NotImplementedError(
|
| 133 |
+
f'Scheduler {scheduler_type} is not implemented yet.')
|
| 134 |
+
|
| 135 |
+
def get_bare_model(self, net):
|
| 136 |
+
"""Get bare model, especially under wrapping with
|
| 137 |
+
DistributedDataParallel or DataParallel.
|
| 138 |
+
"""
|
| 139 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
| 140 |
+
net = net.module
|
| 141 |
+
return net
|
| 142 |
+
|
| 143 |
+
@master_only
|
| 144 |
+
def print_network(self, net):
|
| 145 |
+
"""Print the str and parameter number of a network.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
net (nn.Module)
|
| 149 |
+
"""
|
| 150 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
| 151 |
+
net_cls_str = (f'{net.__class__.__name__} - '
|
| 152 |
+
f'{net.module.__class__.__name__}')
|
| 153 |
+
else:
|
| 154 |
+
net_cls_str = f'{net.__class__.__name__}'
|
| 155 |
+
|
| 156 |
+
net = self.get_bare_model(net)
|
| 157 |
+
net_str = str(net)
|
| 158 |
+
net_params = sum(map(lambda x: x.numel(), net.parameters()))
|
| 159 |
+
|
| 160 |
+
logger.info(
|
| 161 |
+
f'Network: {net_cls_str}, with parameters: {net_params:,d}')
|
| 162 |
+
logger.info(net_str)
|
| 163 |
+
|
| 164 |
+
def _set_lr(self, lr_groups_l):
|
| 165 |
+
"""Set learning rate for warmup.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
lr_groups_l (list): List for lr_groups, each for an optimizer.
|
| 169 |
+
"""
|
| 170 |
+
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
|
| 171 |
+
for param_group, lr in zip(optimizer.param_groups, lr_groups):
|
| 172 |
+
param_group['lr'] = lr
|
| 173 |
+
|
| 174 |
+
def _get_init_lr(self):
|
| 175 |
+
"""Get the initial lr, which is set by the scheduler.
|
| 176 |
+
"""
|
| 177 |
+
init_lr_groups_l = []
|
| 178 |
+
for optimizer in self.optimizers:
|
| 179 |
+
init_lr_groups_l.append(
|
| 180 |
+
[v['initial_lr'] for v in optimizer.param_groups])
|
| 181 |
+
return init_lr_groups_l
|
| 182 |
+
|
| 183 |
+
def update_learning_rate(self, current_iter, warmup_iter=-1):
|
| 184 |
+
"""Update learning rate.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
current_iter (int): Current iteration.
|
| 188 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
| 189 |
+
Default: -1.
|
| 190 |
+
"""
|
| 191 |
+
if current_iter > 1:
|
| 192 |
+
for scheduler in self.schedulers:
|
| 193 |
+
scheduler.step()
|
| 194 |
+
# set up warm-up learning rate
|
| 195 |
+
if current_iter < warmup_iter:
|
| 196 |
+
# get initial lr for each group
|
| 197 |
+
init_lr_g_l = self._get_init_lr()
|
| 198 |
+
# modify warming-up learning rates
|
| 199 |
+
# currently only support linearly warm up
|
| 200 |
+
warm_up_lr_l = []
|
| 201 |
+
for init_lr_g in init_lr_g_l:
|
| 202 |
+
warm_up_lr_l.append(
|
| 203 |
+
[v / warmup_iter * current_iter for v in init_lr_g])
|
| 204 |
+
# set learning rate
|
| 205 |
+
self._set_lr(warm_up_lr_l)
|
| 206 |
+
|
| 207 |
+
def get_current_learning_rate(self):
|
| 208 |
+
return [
|
| 209 |
+
param_group['lr']
|
| 210 |
+
for param_group in self.optimizers[0].param_groups
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
@master_only
|
| 214 |
+
def save_network(self, net, net_label, current_iter, param_key='params'):
|
| 215 |
+
"""Save networks.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
net (nn.Module | list[nn.Module]): Network(s) to be saved.
|
| 219 |
+
net_label (str): Network label.
|
| 220 |
+
current_iter (int): Current iter number.
|
| 221 |
+
param_key (str | list[str]): The parameter key(s) to save network.
|
| 222 |
+
Default: 'params'.
|
| 223 |
+
"""
|
| 224 |
+
if current_iter == -1:
|
| 225 |
+
current_iter = 'latest'
|
| 226 |
+
save_filename = f'{net_label}_{current_iter}.pth'
|
| 227 |
+
save_path = os.path.join(self.opt['path']['models'], save_filename)
|
| 228 |
+
|
| 229 |
+
net = net if isinstance(net, list) else [net]
|
| 230 |
+
param_key = param_key if isinstance(param_key, list) else [param_key]
|
| 231 |
+
assert len(net) == len(
|
| 232 |
+
param_key), 'The lengths of net and param_key should be the same.'
|
| 233 |
+
|
| 234 |
+
save_dict = {}
|
| 235 |
+
for net_, param_key_ in zip(net, param_key):
|
| 236 |
+
net_ = self.get_bare_model(net_)
|
| 237 |
+
state_dict = net_.state_dict()
|
| 238 |
+
for key, param in state_dict.items():
|
| 239 |
+
if key.startswith('module.'): # remove unnecessary 'module.'
|
| 240 |
+
key = key[7:]
|
| 241 |
+
state_dict[key] = param.cpu()
|
| 242 |
+
save_dict[param_key_] = state_dict
|
| 243 |
+
|
| 244 |
+
torch.save(save_dict, save_path)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different names or sizes when loading models.

        1. Print keys with different names.
        2. If strict=False, print keys that match but have different tensor
           sizes; these keys are also ignored (not loaded).

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                logger.warning(f'  {v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                logger.warning(f'  {v}')

        # check the sizes of the keys both nets share
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(
                        f'Size different, ignore [{k}]: crt_net: '
                        f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)
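
    # Renaming a mismatched key to '<key>.ignore' removes it from the set of
    # keys the current network recognizes, so the subsequent
    # load_state_dict(..., strict=False) in load_network simply skips it
    # instead of raising a size-mismatch error.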

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            load_path (str): The path of the network to be loaded.
            net (nn.Module): Network.
            strict (bool): Whether strictly loaded.
            param_key (str): The parameter key of the loaded network. If set
                to None, use the loaded dict as-is. Default: 'params'.
        """
        net = self.get_bare_model(net)
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}.')
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                logger.info(f'Loading: {param_key} does not exist, use params.')
                param_key = 'params'
            load_net = load_net[param_key]
        # remove unnecessary 'module.' prefixes
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

    @master_only
    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for
        resuming.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter != -1:
            state = {
                'epoch': epoch,
                'iter': current_iter,
                'optimizers': [],
                'schedulers': []
            }
            for o in self.optimizers:
                state['optimizers'].append(o.state_dict())
            for s in self.schedulers:
                state['schedulers'].append(s.state_dict())
            save_filename = f'{current_iter}.state'
            save_path = os.path.join(self.opt['path']['training_states'],
                                     save_filename)
            torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """Reduce loss dict.

        In distributed training, it averages the losses among different GPUs.

        Args:
            loss_dict (OrderedDict): Loss dict.
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()

            return log_dict
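
    # Example of the distributed branch (illustrative numbers): with
    # world_size=2 and l_pix = 0.10 on rank 0 and 0.14 on rank 1,
    # torch.distributed.reduce sums the stacked losses onto rank 0, which
    # then divides by world_size and logs l_pix = 0.12; the other ranks keep
    # their local values.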

basicsr/models/image_restoration_model.py
ADDED
@@ -0,0 +1,361 @@
import importlib
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from tqdm import tqdm

from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img

loss_module = importlib.import_module('basicsr.models.losses')
metric_module = importlib.import_module('basicsr.metrics')

import os
import random
import numpy as np
import cv2
import torch.nn.functional as F
from functools import partial

class Mixing_Augment:
    def __init__(self, mixup_beta, use_identity, device):
        self.dist = torch.distributions.beta.Beta(
            torch.tensor([mixup_beta]), torch.tensor([mixup_beta]))
        self.device = device

        self.use_identity = use_identity

        self.augments = [self.mixup]

    def mixup(self, target, input_):
        lam = self.dist.rsample((1, 1)).item()

        r_index = torch.randperm(target.size(0)).to(self.device)

        target = lam * target + (1 - lam) * target[r_index, :]
        input_ = lam * input_ + (1 - lam) * input_[r_index, :]

        return target, input_

    def __call__(self, target, input_):
        if self.use_identity:
            # randint is inclusive, so drawing len(self.augments) means
            # "apply no augmentation" (identity)
            augment = random.randint(0, len(self.augments))
            if augment < len(self.augments):
                target, input_ = self.augments[augment](target, input_)
        else:
            augment = random.randint(0, len(self.augments) - 1)
            target, input_ = self.augments[augment](target, input_)
        return target, input_
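
# A minimal usage sketch (shapes and device are assumptions for illustration):
# a single lam ~ Beta(beta, beta) and a single batch permutation are drawn,
# and the same convex combination is applied to target and input_ so that
# each (gt, lq) pair stays aligned after mixing:
#   aug = Mixing_Augment(mixup_beta=1.2, use_identity=False, device='cuda')
#   gt = torch.rand(8, 3, 128, 128).to('cuda')
#   lq = torch.rand(8, 3, 128, 128).to('cuda')
#   gt, lq = aug(gt, lq)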

class ImageCleanModel(BaseModel):
    """Base restoration model for single-image restoration."""

    def __init__(self, opt):
        super(ImageCleanModel, self).__init__(opt)

        # mixing augmentation (MixUp) settings
        self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False)
        if self.mixing_flag:
            mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2)
            use_identity = self.opt['train']['mixing_augs'].get('use_identity', False)
            self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device)

        # define network
        self.net_g = define_network(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g, load_path,
                              self.opt['path'].get('strict_load_g', True),
                              param_key=self.opt['path'].get('param_key', 'params'))

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(
                f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA).
            # net_g_ema is used only for testing on one GPU and for saving;
            # there is no need to wrap it with DistributedDataParallel.
            self.net_g_ema = define_network(self.opt['network_g']).to(
                self.device)
            # load the pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path,
                                  self.opt['path'].get('strict_load_g',
                                                       True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weights
            self.net_g_ema.eval()

        # define losses
        if train_opt.get('pixel_opt'):
            pixel_type = train_opt['pixel_opt'].pop('type')
            cri_pix_cls = getattr(loss_module, pixel_type)
            self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
                self.device)
        else:
            raise ValueError('pixel loss is None.')

        if train_opt.get('fft_loss_opt'):
            fft_type = train_opt['fft_loss_opt'].pop('type')
            cri_fft_cls = getattr(loss_module, fft_type)
            self.cri_fft = cri_fft_cls(**train_opt['fft_loss_opt']).to(
                self.device)
        else:
            self.cri_fft = None

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        optim_params = []

        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')

        optim_type = train_opt['optim_g'].pop('type')
        if optim_type == 'Adam':
            self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
        elif optim_type == 'AdamW':
            self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')
        self.optimizers.append(self.optimizer_g)
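
    # A sketch of the matching YAML options (the key layout follows this
    # file; the lr/betas values below are placeholders, not the released
    # config):
    #   train:
    #     optim_g:
    #       type: AdamW
    #       lr: !!float 3e-4
    #       betas: [0.9, 0.999]
    # Everything except `type` is forwarded as keyword arguments to
    # torch.optim.Adam / torch.optim.AdamW.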

    def feed_train_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

        if self.mixing_flag:
            self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        preds = self.net_g(self.lq)
        if not isinstance(preds, list):
            preds = [preds]

        self.output = preds[-1]

        l_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix:
            l_pix = 0.
            for pred in preds:
                l_pix += self.cri_pix(pred, self.gt)
            l_total += l_pix
            loss_dict['l_pix'] = l_pix

        # fft loss
        if self.cri_fft:
            l_fft = self.cri_fft(preds[-1], self.gt)
            l_total += l_fft
            loss_dict['l_fft'] = l_fft

        # add a zero-weighted sum over all parameters so that every parameter
        # takes part in the graph (avoids DDP unused-parameter errors)
        l_total = l_total + 0. * sum(p.sum() for p in self.net_g.parameters())

        l_total.backward()

        if self.opt['train']['use_grad_clip']:
            torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01)
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def pad_test(self, window_size):
        scale = self.opt.get('scale', 1)
        mod_pad_h, mod_pad_w = 0, 0
        _, _, h, w = self.lq.size()
        if h % window_size != 0:
            mod_pad_h = window_size - h % window_size
        if w % window_size != 0:
            mod_pad_w = window_size - w % window_size
        img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        self.nonpad_test(img)
        _, _, h, w = self.output.size()
        self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
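
    # Worked example (illustrative sizes): for a 257x300 input with
    # window_size=8, mod_pad_h = 8 - 257 % 8 = 7 and mod_pad_w = 8 - 300 % 8
    # = 4, so the image is reflect-padded to 264x304 before inference and the
    # output is cropped back to 257x300 (times `scale`, which defaults to 1).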

    def nonpad_test(self, img=None):
        if img is None:
            img = self.lq
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                pred = self.net_g_ema(img)
            if isinstance(pred, list):
                pred = pred[-1]
            self.output = pred
        else:
            self.net_g.eval()
            with torch.no_grad():
                pred = self.net_g(img)
            if isinstance(pred, list):
                pred = pred[-1]
            self.output = pred
            self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image):
        # run validation only on the rank-0 process
        if os.environ['LOCAL_RANK'] == '0':
            return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
        else:
            return 0.

    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img, rgb2bgr, use_image):
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {
                metric: 0
                for metric in self.opt['val']['metrics'].keys()
            }

        window_size = self.opt['val'].get('window_size', 0)

        if window_size:
            test = partial(self.pad_test, window_size)
        else:
            test = self.nonpad_test

        cnt = 0

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]

            self.feed_data(val_data)
            test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
                del self.gt

            # free tensors to avoid running out of GPU memory (tentative)
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                    save_gt_img_path = osp.join(self.opt['path']['visualization'],
                                                img_name,
                                                f'{img_name}_{current_iter}_gt.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}.png')
                    save_gt_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_gt.png')

                imwrite(sr_img, save_img_path)
                imwrite(gt_img, save_gt_img_path)

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                if use_image:
                    for name, opt_ in opt_metric.items():
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(sr_img, gt_img, **opt_)
                else:
                    for name, opt_ in opt_metric.items():
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)

            cnt += 1

        current_metric = 0.
        if with_metrics:
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= cnt
                current_metric = self.metric_results[metric]

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
        return current_metric

    def _log_validation_metric_values(self, current_iter, dataset_name,
                                      tb_logger):
        log_str = f'Validation {dataset_name},\t'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema],
                              'net_g',
                              current_iter,
                              param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)