yssszzzzzzzzy committed
Commit 8e79984 · 1 Parent(s): dac2323

Initial commit of FPro dehazing model

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. basicsr/.DS_Store +0 -0
  2. basicsr/__pycache__/version.cpython-37.pyc +0 -0
  3. basicsr/data/.DS_Store +0 -0
  4. basicsr/data/__init__.py +126 -0
  5. basicsr/data/__pycache__/__init__.cpython-37.pyc +0 -0
  6. basicsr/data/__pycache__/data_sampler.cpython-37.pyc +0 -0
  7. basicsr/data/__pycache__/data_util.cpython-37.pyc +0 -0
  8. basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc +0 -0
  9. basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc +0 -0
  10. basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc +0 -0
  11. basicsr/data/__pycache__/reds_dataset.cpython-37.pyc +0 -0
  12. basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc +0 -0
  13. basicsr/data/__pycache__/transforms.cpython-37.pyc +0 -0
  14. basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc +0 -0
  15. basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc +0 -0
  16. basicsr/data/data_sampler.py +49 -0
  17. basicsr/data/data_util.py +388 -0
  18. basicsr/data/ffhq_dataset.py +65 -0
  19. basicsr/data/paired_image_dataset.py +824 -0
  20. basicsr/data/prefetch_dataloader.py +126 -0
  21. basicsr/data/reds_dataset.py +237 -0
  22. basicsr/data/single_image_dataset.py +67 -0
  23. basicsr/data/transforms.py +480 -0
  24. basicsr/data/video_test_dataset.py +325 -0
  25. basicsr/data/vimeo90k_dataset.py +130 -0
  26. basicsr/metrics/__init__.py +4 -0
  27. basicsr/metrics/__pycache__/__init__.cpython-37.pyc +0 -0
  28. basicsr/metrics/__pycache__/metric_util.cpython-37.pyc +0 -0
  29. basicsr/metrics/__pycache__/niqe.cpython-37.pyc +0 -0
  30. basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc +0 -0
  31. basicsr/metrics/fid.py +102 -0
  32. basicsr/metrics/metric_util.py +47 -0
  33. basicsr/metrics/niqe.py +205 -0
  34. basicsr/metrics/niqe_pris_params.npz +3 -0
  35. basicsr/metrics/psnr_ssim.py +303 -0
  36. basicsr/models/.DS_Store +0 -0
  37. basicsr/models/__init__.py +42 -0
  38. basicsr/models/__pycache__/__init__.cpython-37.pyc +0 -0
  39. basicsr/models/__pycache__/base_model.cpython-37.pyc +0 -0
  40. basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc +0 -0
  41. basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc +0 -0
  42. basicsr/models/archs/FPro_arch.py +545 -0
  43. basicsr/models/archs/__init__.py +46 -0
  44. basicsr/models/archs/__pycache__/__init__.cpython-37.pyc +0 -0
  45. basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc +0 -0
  46. basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc +0 -0
  47. basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc +0 -0
  48. basicsr/models/archs/arch_util.py +255 -0
  49. basicsr/models/base_model.py +378 -0
  50. basicsr/models/image_restoration_model.py +361 -0
basicsr/.DS_Store ADDED
Binary file (10.2 kB).
 
basicsr/__pycache__/version.cpython-37.pyc ADDED
Binary file (244 Bytes).
 
basicsr/data/.DS_Store ADDED
Binary file (6.15 kB).
 
basicsr/data/__init__.py ADDED
@@ -0,0 +1,126 @@
+ import importlib
+ import numpy as np
+ import random
+ import torch
+ import torch.utils.data
+ from functools import partial
+ from os import path as osp
+
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+ from basicsr.utils import get_root_logger, scandir
+ from basicsr.utils.dist_util import get_dist_info
+
+ __all__ = ['create_dataset', 'create_dataloader']
+
+ # automatically scan and import dataset modules
+ # scan all the files under the data folder with '_dataset' in file names
+ data_folder = osp.dirname(osp.abspath(__file__))
+ dataset_filenames = [
+     osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
+     if v.endswith('_dataset.py')
+ ]
+ # import all the dataset modules
+ _dataset_modules = [
+     importlib.import_module(f'basicsr.data.{file_name}')
+     for file_name in dataset_filenames
+ ]
+
+
+ def create_dataset(dataset_opt):
+     """Create dataset.
+
+     Args:
+         dataset_opt (dict): Configuration for dataset. It contains:
+             name (str): Dataset name.
+             type (str): Dataset type.
+     """
+     dataset_type = dataset_opt['type']
+
+     # dynamic instantiation
+     for module in _dataset_modules:
+         dataset_cls = getattr(module, dataset_type, None)
+         if dataset_cls is not None:
+             break
+     if dataset_cls is None:
+         raise ValueError(f'Dataset {dataset_type} is not found.')
+
+     dataset = dataset_cls(dataset_opt)
+
+     logger = get_root_logger()
+     logger.info(
+         f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
+         'is created.')
+     return dataset
+
+
+ def create_dataloader(dataset,
+                       dataset_opt,
+                       num_gpu=1,
+                       dist=False,
+                       sampler=None,
+                       seed=None):
+     """Create dataloader.
+
+     Args:
+         dataset (torch.utils.data.Dataset): Dataset.
+         dataset_opt (dict): Dataset options. It contains the following keys:
+             phase (str): 'train' or 'val'.
+             num_worker_per_gpu (int): Number of workers for each GPU.
+             batch_size_per_gpu (int): Training batch size for each GPU.
+         num_gpu (int): Number of GPUs. Used only in the train phase.
+             Default: 1.
+         dist (bool): Whether in distributed training. Used only in the train
+             phase. Default: False.
+         sampler (torch.utils.data.sampler): Data sampler. Default: None.
+         seed (int | None): Seed. Default: None.
+     """
+     phase = dataset_opt['phase']
+     rank, _ = get_dist_info()
+     if phase == 'train':
+         if dist:  # distributed training
+             batch_size = dataset_opt['batch_size_per_gpu']
+             num_workers = dataset_opt['num_worker_per_gpu']
+         else:  # non-distributed training
+             multiplier = 1 if num_gpu == 0 else num_gpu
+             batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+             num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+         dataloader_args = dict(
+             dataset=dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+             sampler=sampler,
+             drop_last=True)
+         if sampler is None:
+             dataloader_args['shuffle'] = True
+         dataloader_args['worker_init_fn'] = partial(
+             worker_init_fn, num_workers=num_workers, rank=rank,
+             seed=seed) if seed is not None else None
+     elif phase in ['val', 'test']:  # validation
+         dataloader_args = dict(
+             dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+     else:
+         raise ValueError(f'Wrong dataset phase: {phase}. '
+                          "Supported ones are 'train', 'val' and 'test'.")
+
+     dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+     prefetch_mode = dataset_opt.get('prefetch_mode')
+     if prefetch_mode == 'cpu':  # CPUPrefetcher
+         num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+         logger = get_root_logger()
+         logger.info(f'Use {prefetch_mode} prefetch dataloader: '
+                     f'num_prefetch_queue = {num_prefetch_queue}')
+         return PrefetchDataLoader(
+             num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+     else:
+         # prefetch_mode=None: Normal dataloader
+         # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+         return torch.utils.data.DataLoader(**dataloader_args)
+
+
+ def worker_init_fn(worker_id, num_workers, rank, seed):
+     # Set the worker seed to num_workers * rank + worker_id + seed
+     worker_seed = num_workers * rank + worker_id + seed
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
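For orientation, here is a minimal usage sketch of the two factory functions above, as they would typically be driven from a training script. The option keys follow the docstrings; `Dataset_PairedImage` is one of the classes registered later in this commit, and the dataroot paths are placeholders, not paths from this repo.

# Minimal sketch: building a train dataset and dataloader from an option dict.
from basicsr.data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'demo-train',
    'type': 'Dataset_PairedImage',   # resolved by the getattr scan above
    'phase': 'train',
    'dataroot_gt': '/path/to/gt',    # placeholder path
    'dataroot_lq': '/path/to/lq',    # placeholder path
    'io_backend': {'type': 'disk'},
    'filename_tmpl': '{}',
    'gt_size': 128,
    'geometric_augs': True,
    'scale': 1,
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
}

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(
    train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=10)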
basicsr/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (3.53 kB).
 
basicsr/data/__pycache__/data_sampler.cpython-37.pyc ADDED
Binary file (2.14 kB).
 
basicsr/data/__pycache__/data_util.cpython-37.pyc ADDED
Binary file (13 kB).
 
basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc ADDED
Binary file (2.54 kB).
 
basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc ADDED
Binary file (16.3 kB).
 
basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc ADDED
Binary file (4.29 kB).
 
basicsr/data/__pycache__/reds_dataset.cpython-37.pyc ADDED
Binary file (6.44 kB).
 
basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc ADDED
Binary file (2.61 kB).
 
basicsr/data/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (9.85 kB).
 
basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc ADDED
Binary file (10.7 kB).
 
basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc ADDED
Binary file (4.16 kB).
 
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,49 @@
+ import math
+ import torch
+ from torch.utils.data.sampler import Sampler
+
+
+ class EnlargedSampler(Sampler):
+     """Sampler that restricts data loading to a subset of the dataset.
+
+     Modified from torch.utils.data.distributed.DistributedSampler.
+     Supports enlarging the dataset for iteration-based training, saving
+     time when restarting the dataloader after each epoch.
+
+     Args:
+         dataset (torch.utils.data.Dataset): Dataset used for sampling.
+         num_replicas (int | None): Number of processes participating in
+             the training. It is usually the world_size.
+         rank (int | None): Rank of the current process within num_replicas.
+         ratio (int): Enlarging ratio. Default: 1.
+     """
+
+     def __init__(self, dataset, num_replicas, rank, ratio=1):
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = math.ceil(
+             len(self.dataset) * ratio / self.num_replicas)
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+         indices = torch.randperm(self.total_size, generator=g).tolist()
+
+         dataset_size = len(self.dataset)
+         indices = [v % dataset_size for v in indices]
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
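A short sketch of how this sampler pairs with create_dataloader, in the non-distributed single-process case (num_replicas=1, rank=0); `train_set` and `dataset_opt` are assumed to exist as in the earlier snippet.

from basicsr.data import create_dataloader
from basicsr.data.data_sampler import EnlargedSampler

# ratio=100 repeats the index space 100x, so one loader "epoch" covers many
# passes over the data and the dataloader is restarted far less often.
train_sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=100)
train_loader = create_dataloader(
    train_set, dataset_opt, num_gpu=1, dist=False,
    sampler=train_sampler, seed=10)

for epoch in range(2):
    train_sampler.set_epoch(epoch)  # reseeds the deterministic shuffle
    for batch in train_loader:
        pass  # training step goes here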
basicsr/data/data_util.py ADDED
@@ -0,0 +1,388 @@
+ import cv2
+ cv2.setNumThreads(1)
+ import numpy as np
+ import torch
+ from os import path as osp
+ from torch.nn import functional as F
+
+ from basicsr.data.transforms import mod_crop
+ from basicsr.utils import img2tensor, scandir
+
+
+ def read_img_seq(path, require_mod_crop=False, scale=1):
+     """Read a sequence of images from a given folder path.
+
+     Args:
+         path (list[str] | str): List of image paths or image folder path.
+         require_mod_crop (bool): Require mod crop for each image.
+             Default: False.
+         scale (int): Scale factor for mod_crop. Default: 1.
+
+     Returns:
+         Tensor: size (t, c, h, w), RGB, [0, 1].
+     """
+     if isinstance(path, list):
+         img_paths = path
+     else:
+         img_paths = sorted(list(scandir(path, full_path=True)))
+     imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+     if require_mod_crop:
+         imgs = [mod_crop(img, scale) for img in imgs]
+     imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+     imgs = torch.stack(imgs, dim=0)
+     return imgs
+
+
+ def generate_frame_indices(crt_idx,
+                            max_frame_num,
+                            num_frames,
+                            padding='reflection'):
+     """Generate an index list for reading `num_frames` frames from a sequence
+     of images.
+
+     Args:
+         crt_idx (int): Current center index.
+         max_frame_num (int): Max number of the sequence of images (from 1).
+         num_frames (int): Reading num_frames frames.
+         padding (str): Padding mode, one of
+             'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+             Examples: current_idx = 0, num_frames = 5
+             The generated frame indices under different padding modes:
+             replicate: [0, 0, 0, 1, 2]
+             reflection: [2, 1, 0, 1, 2]
+             reflection_circle: [4, 3, 0, 1, 2]
+             circle: [3, 4, 0, 1, 2]
+
+     Returns:
+         list[int]: A list of indices.
+     """
+     assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+     assert padding in ('replicate', 'reflection', 'reflection_circle',
+                        'circle'), f'Wrong padding mode: {padding}.'
+
+     max_frame_num = max_frame_num - 1  # start from 0
+     num_pad = num_frames // 2
+
+     indices = []
+     for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+         if i < 0:
+             if padding == 'replicate':
+                 pad_idx = 0
+             elif padding == 'reflection':
+                 pad_idx = -i
+             elif padding == 'reflection_circle':
+                 pad_idx = crt_idx + num_pad - i
+             else:
+                 pad_idx = num_frames + i
+         elif i > max_frame_num:
+             if padding == 'replicate':
+                 pad_idx = max_frame_num
+             elif padding == 'reflection':
+                 pad_idx = max_frame_num * 2 - i
+             elif padding == 'reflection_circle':
+                 pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+             else:
+                 pad_idx = i - num_frames
+         else:
+             pad_idx = i
+         indices.append(pad_idx)
+     return indices
+
+
+ def paired_paths_from_lmdb(folders, keys):
+     """Generate paired paths from lmdb files.
+
+     Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+     lq.lmdb
+     ├── data.mdb
+     ├── lock.mdb
+     ├── meta_info.txt
+
+     The data.mdb and lock.mdb are standard lmdb files and you can refer to
+     https://lmdb.readthedocs.io/en/release/ for more details.
+
+     The meta_info.txt is a specified txt file to record the meta information
+     of our datasets. It will be automatically created when preparing
+     datasets by our provided dataset tools.
+     Each line in the txt file records
+     1) image name (with extension),
+     2) image shape,
+     3) compression level, separated by a white space.
+     Example: `baboon.png (120,125,3) 1`
+
+     We use the image name without extension as the lmdb key.
+     Note that we use the same key for the corresponding lq and gt images.
+
+     Args:
+         folders (list[str]): A list of folder path. The order of list should
+             be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+             Note that this key is different from lmdb keys.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, (
+         'The len of folders should be 2 with [input_folder, gt_folder]. '
+         f'But got {len(folders)}')
+     assert len(keys) == 2, (
+         'The len of keys should be 2 with [input_key, gt_key]. '
+         f'But got {len(keys)}')
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+         raise ValueError(
+             f'{input_key} folder and {gt_key} folder should both be in lmdb '
+             f'format. But received {input_key}: {input_folder}; '
+             f'{gt_key}: {gt_folder}')
+     # ensure that the two meta_info files are the same
+     with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+         input_lmdb_keys = [line.split('.')[0] for line in fin]
+     with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+         gt_lmdb_keys = [line.split('.')[0] for line in fin]
+     if set(input_lmdb_keys) != set(gt_lmdb_keys):
+         raise ValueError(
+             f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+     else:
+         paths = []
+         for lmdb_key in sorted(input_lmdb_keys):
+             paths.append(
+                 dict([(f'{input_key}_path', lmdb_key),
+                       (f'{gt_key}_path', lmdb_key)]))
+         return paths
+
+
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file,
+                                      filename_tmpl):
+     """Generate paired paths from a meta information file.
+
+     Each line in the meta information file contains the image names and
+     image shape (usually for gt), separated by a white space.
+
+     Example of a meta information file:
+     ```
+     0001_s001.png (480,480,3)
+     0001_s002.png (480,480,3)
+     ```
+
+     Args:
+         folders (list[str]): A list of folder path. The order of list should
+             be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+         meta_info_file (str): Path to the meta information file.
+         filename_tmpl (str): Template for each filename. Note that the
+             template excludes the file extension. Usually the filename_tmpl is
+             for files in the input folder.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, (
+         'The len of folders should be 2 with [input_folder, gt_folder]. '
+         f'But got {len(folders)}')
+     assert len(keys) == 2, (
+         'The len of keys should be 2 with [input_key, gt_key]. '
+         f'But got {len(keys)}')
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     with open(meta_info_file, 'r') as fin:
+         gt_names = [line.split(' ')[0] for line in fin]
+
+     paths = []
+     for gt_name in gt_names:
+         basename, ext = osp.splitext(osp.basename(gt_name))
+         input_name = f'{filename_tmpl.format(basename)}{ext}'
+         input_path = osp.join(input_folder, input_name)
+         gt_path = osp.join(gt_folder, gt_name)
+         paths.append(
+             dict([(f'{input_key}_path', input_path),
+                   (f'{gt_key}_path', gt_path)]))
+     return paths
+
+
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
+     """Generate paired paths from folders.
+
+     Args:
+         folders (list[str]): A list of folder path. The order of list should
+             be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+         filename_tmpl (str): Template for each filename. Note that the
+             template excludes the file extension. Usually the filename_tmpl is
+             for files in the input folder.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, (
+         'The len of folders should be 2 with [input_folder, gt_folder]. '
+         f'But got {len(folders)}')
+     assert len(keys) == 2, (
+         'The len of keys should be 2 with [input_key, gt_key]. '
+         f'But got {len(keys)}')
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     input_paths = list(scandir(input_folder))
+     gt_paths = list(scandir(gt_folder))
+     assert len(input_paths) == len(gt_paths), (
+         f'{input_key} and {gt_key} datasets have different number of images: '
+         f'{len(input_paths)}, {len(gt_paths)}.')
+     paths = []
+     for idx in range(len(gt_paths)):
+         gt_path = gt_paths[idx]
+         basename, ext = osp.splitext(osp.basename(gt_path))
+         input_path = input_paths[idx]
+         basename_input, ext_input = osp.splitext(osp.basename(input_path))
+         input_name = f'{filename_tmpl.format(basename)}{ext_input}'
+         input_path = osp.join(input_folder, input_name)
+         assert input_name in input_paths, (f'{input_name} is not in '
+                                            f'{input_key}_paths.')
+         gt_path = osp.join(gt_folder, gt_path)
+         paths.append(
+             dict([(f'{input_key}_path', input_path),
+                   (f'{gt_key}_path', gt_path)]))
+     return paths
+
+ def paired_DP_paths_from_folder(folders, keys, filename_tmpl):
+     """Generate paired dual-pixel paths from folders.
+
+     Args:
+         folders (list[str]): A list of folder path. The order of list should
+             be [inputL_folder, inputR_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lqL', 'lqR', 'gt'].
+         filename_tmpl (str): Template for each filename. Note that the
+             template excludes the file extension. Usually the filename_tmpl is
+             for files in the input folders.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 3, (
+         'The len of folders should be 3 with [inputL_folder, inputR_folder, gt_folder]. '
+         f'But got {len(folders)}')
+     assert len(keys) == 3, (
+         'The len of keys should be 3 with [inputL_key, inputR_key, gt_key]. '
+         f'But got {len(keys)}')
+     inputL_folder, inputR_folder, gt_folder = folders
+     inputL_key, inputR_key, gt_key = keys
+
+     inputL_paths = list(scandir(inputL_folder))
+     inputR_paths = list(scandir(inputR_folder))
+     gt_paths = list(scandir(gt_folder))
+     assert len(inputL_paths) == len(inputR_paths) == len(gt_paths), (
+         f'{inputL_key} and {inputR_key} and {gt_key} datasets have different number of images: '
+         f'{len(inputL_paths)}, {len(inputR_paths)}, {len(gt_paths)}.')
+     paths = []
+     for idx in range(len(gt_paths)):
+         gt_path = gt_paths[idx]
+         basename, ext = osp.splitext(osp.basename(gt_path))
+         inputL_path = inputL_paths[idx]
+         basename_input, ext_input = osp.splitext(osp.basename(inputL_path))
+         inputL_name = f'{filename_tmpl.format(basename)}{ext_input}'
+         inputL_path = osp.join(inputL_folder, inputL_name)
+         assert inputL_name in inputL_paths, (f'{inputL_name} is not in '
+                                              f'{inputL_key}_paths.')
+         inputR_path = inputR_paths[idx]
+         basename_input, ext_input = osp.splitext(osp.basename(inputR_path))
+         inputR_name = f'{filename_tmpl.format(basename)}{ext_input}'
+         inputR_path = osp.join(inputR_folder, inputR_name)
+         assert inputR_name in inputR_paths, (f'{inputR_name} is not in '
+                                              f'{inputR_key}_paths.')
+         gt_path = osp.join(gt_folder, gt_path)
+         paths.append(
+             dict([(f'{inputL_key}_path', inputL_path),
+                   (f'{inputR_key}_path', inputR_path),
+                   (f'{gt_key}_path', gt_path)]))
+     return paths
+
+
+ def paths_from_folder(folder):
+     """Generate paths from folder.
+
+     Args:
+         folder (str): Folder path.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+
+     paths = list(scandir(folder))
+     paths = [osp.join(folder, path) for path in paths]
+     return paths
+
+
+ def paths_from_lmdb(folder):
+     """Generate paths from lmdb.
+
+     Args:
+         folder (str): Folder path.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     if not folder.endswith('.lmdb'):
+         raise ValueError(f'Folder {folder} should be in lmdb format.')
+     with open(osp.join(folder, 'meta_info.txt')) as fin:
+         paths = [line.split('.')[0] for line in fin]
+     return paths
+
+
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+     """Generate Gaussian kernel used in `duf_downsample`.
+
+     Args:
+         kernel_size (int): Kernel size. Default: 13.
+         sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+     Returns:
+         np.array: The Gaussian kernel.
+     """
+     from scipy.ndimage import filters as filters
+     kernel = np.zeros((kernel_size, kernel_size))
+     # set element at the middle to one, a dirac delta
+     kernel[kernel_size // 2, kernel_size // 2] = 1
+     # gaussian-smooth the dirac, resulting in a gaussian filter
+     return filters.gaussian_filter(kernel, sigma)
+
+
+ def duf_downsample(x, kernel_size=13, scale=4):
+     """Downsampling with Gaussian kernel used in the DUF official code.
+
+     Args:
+         x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+         kernel_size (int): Kernel size. Default: 13.
+         scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+             Default: 4.
+
+     Returns:
+         Tensor: DUF downsampled frames.
+     """
+     assert scale in (2, 3,
+                      4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+     squeeze_flag = False
+     if x.ndim == 4:
+         squeeze_flag = True
+         x = x.unsqueeze(0)
+     b, t, c, h, w = x.size()
+     x = x.view(-1, 1, h, w)
+     pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+     x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+     gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+     gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(
+         0).unsqueeze(0)
+     x = F.conv2d(x, gaussian_filter, stride=scale)
+     x = x[:, :, 2:-2, 2:-2]
+     x = x.view(b, t, c, x.size(2), x.size(3))
+     if squeeze_flag:
+         x = x.squeeze(0)
+     return x
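The padding logic in generate_frame_indices is easiest to see on a concrete case; this quick check (a sketch, assuming the module above is importable) reproduces the table from its docstring:

from basicsr.data.data_util import generate_frame_indices

# crt_idx=0 with a 5-frame window only exercises the left-boundary branch.
for mode in ('replicate', 'reflection', 'reflection_circle', 'circle'):
    print(mode, generate_frame_indices(0, 100, 5, padding=mode))
# replicate          [0, 0, 0, 1, 2]
# reflection         [2, 1, 0, 1, 2]
# reflection_circle  [4, 3, 0, 1, 2]
# circle             [3, 4, 0, 1, 2]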
basicsr/data/ffhq_dataset.py ADDED
@@ -0,0 +1,65 @@
+ from os import path as osp
+ from torch.utils import data as data
+ from torchvision.transforms.functional import normalize
+
+ from basicsr.data.transforms import augment
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
+
+
+ class FFHQDataset(data.Dataset):
+     """FFHQ dataset for StyleGAN.
+
+     Args:
+         opt (dict): Config for train datasets. It contains the following keys:
+             dataroot_gt (str): Data root path for gt.
+             io_backend (dict): IO backend type and other kwargs.
+             mean (list | tuple): Image mean.
+             std (list | tuple): Image std.
+             use_hflip (bool): Whether to horizontally flip.
+
+     """
+
+     def __init__(self, opt):
+         super(FFHQDataset, self).__init__()
+         self.opt = opt
+         # file client (io backend)
+         self.file_client = None
+         self.io_backend_opt = opt['io_backend']
+
+         self.gt_folder = opt['dataroot_gt']
+         self.mean = opt['mean']
+         self.std = opt['std']
+
+         if self.io_backend_opt['type'] == 'lmdb':
+             self.io_backend_opt['db_paths'] = self.gt_folder
+             if not self.gt_folder.endswith('.lmdb'):
+                 raise ValueError("'dataroot_gt' should end with '.lmdb', "
+                                  f'but received {self.gt_folder}')
+             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                 self.paths = [line.split('.')[0] for line in fin]
+         else:
+             # FFHQ has 70000 images in total
+             self.paths = [
+                 osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
+             ]
+
+     def __getitem__(self, index):
+         if self.file_client is None:
+             self.file_client = FileClient(
+                 self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+         # load gt image
+         gt_path = self.paths[index]
+         img_bytes = self.file_client.get(gt_path)
+         img_gt = imfrombytes(img_bytes, float32=True)
+
+         # random horizontal flip
+         img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+         # normalize
+         normalize(img_gt, self.mean, self.std, inplace=True)
+         return {'gt': img_gt, 'gt_path': gt_path}
+
+     def __len__(self):
+         return len(self.paths)
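A minimal sketch of an option dict that drives this class (disk backend; the dataroot is a placeholder, with FFHQ images expected as 00000000.png through 00069999.png):

from basicsr.data.ffhq_dataset import FFHQDataset

ffhq_opt = {
    'dataroot_gt': '/path/to/ffhq',  # placeholder path
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'use_hflip': True,
}
ffhq_set = FFHQDataset(ffhq_opt)  # len(ffhq_set) == 70000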
basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from basicsr.data.data_util import (paired_paths_from_folder,
5
+ paired_DP_paths_from_folder,
6
+ paired_paths_from_lmdb,
7
+ paired_paths_from_meta_info_file)
8
+ from basicsr.data.transforms import augment, paired_random_crop, paired_random_crop_DP, random_augmentation, paired_center_crop
9
+ from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, padding_DP, imfrombytesDP
10
+
11
+ import random
12
+ import numpy as np
13
+ import torch
14
+ import cv2
15
+
16
+ import os
17
+ from scandir import scandir
18
+
19
+ class Dataset_PairedImage_dehazeSOT(data.Dataset):
20
+ """Paired image dataset for image restoration.
21
+
22
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
23
+ GT image pairs.
24
+
25
+ There are three modes:
26
+ 1. 'lmdb': Use lmdb files.
27
+ If opt['io_backend'] == lmdb.
28
+ 2. 'meta_info_file': Use meta information file to generate paths.
29
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
30
+ 3. 'folder': Scan folders to generate paths.
31
+ The rest.
32
+
33
+ Args:
34
+ opt (dict): Config for train datasets. It contains the following keys:
35
+ dataroot_gt (str): Data root path for gt.
36
+ dataroot_lq (str): Data root path for lq.
37
+ meta_info_file (str): Path for meta information file.
38
+ io_backend (dict): IO backend type and other kwarg.
39
+ filename_tmpl (str): Template for each filename. Note that the
40
+ template excludes the file extension. Default: '{}'.
41
+ gt_size (int): Cropped patched size for gt patches.
42
+ geometric_augs (bool): Use geometric augmentations.
43
+
44
+ scale (bool): Scale, which will be added automatically.
45
+ phase (str): 'train' or 'val'.
46
+ """
47
+
48
+ def __init__(self, opt):
49
+ super(Dataset_PairedImage_dehazeSOT, self).__init__()
50
+ self.opt = opt
51
+ # file client (io backend)
52
+ self.file_client = None
53
+ self.io_backend_opt = opt['io_backend']
54
+ self.mean = opt['mean'] if 'mean' in opt else None
55
+ self.std = opt['std'] if 'std' in opt else None
56
+
57
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
58
+ if 'filename_tmpl' in opt:
59
+ self.filename_tmpl = opt['filename_tmpl']
60
+ else:
61
+ self.filename_tmpl = '{123}'
62
+
63
+ if self.io_backend_opt['type'] == 'lmdb':
64
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
65
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
66
+ self.paths = paired_paths_from_lmdb(
67
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'])
68
+ elif 'meta_info_file' in self.opt and self.opt[
69
+ 'meta_info_file'] is not None:
70
+ self.paths = paired_paths_from_meta_info_file(
71
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
72
+ self.opt['meta_info_file'], self.filename_tmpl)
73
+ else:
74
+ # self.paths = paired_paths_from_folder(
75
+ # [self.lq_folder, self.gt_folder], ['lq', 'gt'],
76
+ # self.filename_tmpl)
77
+ basename = '/mnt/sda/zsh/dataset/haze/promptIR'
78
+ name = ''
79
+ if self.opt['phase'] == 'train':
80
+ name = 'hazy_outside.txt'
81
+ else:
82
+ name = 'haze_test.txt'
83
+ dataset = os.path.join(basename, name)
84
+ paths = []
85
+ if self.opt['phase'] == 'train':
86
+ gt_dir = basename + '/Dehaze/original'
87
+ lq = basename + '/Dehaze'
88
+ with open(dataset, 'r') as fin:
89
+ #synthetic/part4/8961_0.95_0.08.jpg
90
+ for line in fin:
91
+ gt_path = os.path.join(gt_dir, line.split('/')[-1].split('_')[0]+ '.jpg')
92
+ # print('train gt',gt_path)
93
+ input_path = os.path.join(lq, line.strip())
94
+ # print('train input',input_path)
95
+ paths.append(
96
+ dict([(f'lq_path', input_path),
97
+ (f'gt_path', gt_path)]))
98
+ else:
99
+ gt_dir = basename + '/outdoor/gt'
100
+ lq = basename + '/outdoor/hazy'
101
+ #1917_0.95_0.2.jpg
102
+ # print('performing val dataset organize')
103
+ with open(dataset, 'r') as fin:
104
+ for line in fin:
105
+ gt_path = os.path.join(gt_dir, line.split('_')[0]+ '.png')
106
+ # print('valid gt',gt_path)
107
+ input_path = os.path.join(lq, line.strip())
108
+ # print('valid input',input_path)
109
+ paths.append(
110
+ dict([(f'lq_path', input_path),
111
+ (f'gt_path', gt_path)]))
112
+ self.paths = paths
113
+ # self.paths = [
114
+ # osp.join(self.gt_folder,
115
+ # line.split(' ')[0]) for line in fin
116
+ # ]
117
+
118
+ if self.opt['phase'] == 'train':
119
+ self.geometric_augs = opt['geometric_augs']
120
+
121
+ def __getitem__(self, index):
122
+ if self.file_client is None:
123
+ self.file_client = FileClient(
124
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
125
+
126
+ scale = self.opt['scale']
127
+ index = index % len(self.paths)
128
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
129
+ # image range: [0, 1], float32.
130
+ gt_path = self.paths[index]['gt_path']
131
+ img_bytes = self.file_client.get(gt_path, 'gt')
132
+ try:
133
+ img_gt = imfrombytes(img_bytes, float32=True)
134
+ except:
135
+ raise Exception("gt path {} not working".format(gt_path))
136
+
137
+ lq_path = self.paths[index]['lq_path']
138
+ img_bytes = self.file_client.get(lq_path, 'lq')
139
+ try:
140
+ img_lq = imfrombytes(img_bytes, float32=True)
141
+ except:
142
+ raise Exception("lq path {} not working".format(lq_path))
143
+
144
+ # augmentation for training
145
+ if self.opt['phase'] == 'train':
146
+ gt_size = self.opt['gt_size']
147
+ # padding
148
+ img_gt, img_lq = padding(img_gt, img_lq, gt_size)
149
+
150
+ # random crop
151
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
152
+ gt_path)
153
+
154
+ # flip, rotation augmentations
155
+ if self.geometric_augs:
156
+ img_gt, img_lq = random_augmentation(img_gt, img_lq)
157
+ elif self.opt['phase'] == 'val':
158
+ # print('entering val processing')
159
+
160
+ #centerCrop for validation
161
+ gt_size = self.opt['gt_size']
162
+ img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
163
+ gt_path)
164
+ elif self.opt['phase'] == 'test':
165
+ #doingNothing
166
+ print('Test on Full Image')
167
+
168
+
169
+
170
+ # BGR to RGB, HWC to CHW, numpy to tensor
171
+ img_gt, img_lq = img2tensor([img_gt, img_lq],
172
+ bgr2rgb=True,
173
+ float32=True)
174
+ # normalize
175
+ if self.mean is not None or self.std is not None:
176
+ normalize(img_lq, self.mean, self.std, inplace=True)
177
+ normalize(img_gt, self.mean, self.std, inplace=True)
178
+
179
+ return {
180
+ 'lq': img_lq,
181
+ 'gt': img_gt,
182
+ 'lq_path': lq_path,
183
+ 'gt_path': gt_path
184
+ }
185
+
186
+ def __len__(self):
187
+ return len(self.paths)
188
+
189
+ class Dataset_PairedImage_denseHaze(data.Dataset):
190
+ """Paired image dataset for image restoration.
191
+
192
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
193
+ GT image pairs.
194
+
195
+ There are three modes:
196
+ 1. 'lmdb': Use lmdb files.
197
+ If opt['io_backend'] == lmdb.
198
+ 2. 'meta_info_file': Use meta information file to generate paths.
199
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
200
+ 3. 'folder': Scan folders to generate paths.
201
+ The rest.
202
+
203
+ Args:
204
+ opt (dict): Config for train datasets. It contains the following keys:
205
+ dataroot_gt (str): Data root path for gt.
206
+ dataroot_lq (str): Data root path for lq.
207
+ meta_info_file (str): Path for meta information file.
208
+ io_backend (dict): IO backend type and other kwarg.
209
+ filename_tmpl (str): Template for each filename. Note that the
210
+ template excludes the file extension. Default: '{}'.
211
+ gt_size (int): Cropped patched size for gt patches.
212
+ geometric_augs (bool): Use geometric augmentations.
213
+
214
+ scale (bool): Scale, which will be added automatically.
215
+ phase (str): 'train' or 'val'.
216
+ """
217
+
218
+ def __init__(self, opt):
219
+ super(Dataset_PairedImage_denseHaze, self).__init__()
220
+ self.opt = opt
221
+ # file client (io backend)
222
+ self.file_client = None
223
+ self.io_backend_opt = opt['io_backend']
224
+ self.mean = opt['mean'] if 'mean' in opt else None
225
+ self.std = opt['std'] if 'std' in opt else None
226
+
227
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
228
+ if 'filename_tmpl' in opt:
229
+ self.filename_tmpl = opt['filename_tmpl']
230
+ else:
231
+ self.filename_tmpl = '{}'
232
+
233
+ if self.io_backend_opt['type'] == 'lmdb':
234
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
235
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
236
+ self.paths = paired_paths_from_lmdb(
237
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'])
238
+ elif 'meta_info_file' in self.opt and self.opt[
239
+ 'meta_info_file'] is not None:
240
+ self.paths = paired_paths_from_meta_info_file(
241
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
242
+ self.opt['meta_info_file'], self.filename_tmpl)
243
+ else:
244
+ self.paths = paired_paths_from_folder(
245
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
246
+ self.filename_tmpl)
247
+
248
+ if self.opt['phase'] == 'train':
249
+ self.geometric_augs = opt['geometric_augs']
250
+
251
+ def __getitem__(self, index):
252
+ if self.file_client is None:
253
+ self.file_client = FileClient(
254
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
255
+
256
+ scale = self.opt['scale']
257
+ index = index % len(self.paths)
258
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
259
+ # image range: [0, 1], float32.
260
+ gt_path = self.paths[index]['gt_path']
261
+ img_bytes = self.file_client.get(gt_path, 'gt')
262
+ try:
263
+ img_gt = imfrombytes(img_bytes, float32=True)
264
+ except:
265
+ raise Exception("gt path {} not working".format(gt_path))
266
+
267
+ lq_path = self.paths[index]['lq_path']
268
+ img_bytes = self.file_client.get(lq_path, 'lq')
269
+ try:
270
+ img_lq = imfrombytes(img_bytes, float32=True)
271
+ except:
272
+ raise Exception("lq path {} not working".format(lq_path))
273
+
274
+ # augmentation for training
275
+ if self.opt['phase'] == 'train':
276
+ gt_size = self.opt['gt_size']
277
+ # padding
278
+ img_gt, img_lq = padding(img_gt, img_lq, gt_size)
279
+
280
+ # random crop
281
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
282
+ gt_path)
283
+
284
+ # flip, rotation augmentations
285
+ if self.geometric_augs:
286
+ img_gt, img_lq = random_augmentation(img_gt, img_lq)
287
+
288
+ elif self.opt['phase'] == 'val':
289
+ # print('entering val processing')
290
+
291
+ #centerCrop for validation
292
+ gt_size = self.opt['gt_size']
293
+ img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
294
+ gt_path)
295
+ # BGR to RGB, HWC to CHW, numpy to tensor
296
+ img_gt, img_lq = img2tensor([img_gt, img_lq],
297
+ bgr2rgb=True,
298
+ float32=True)
299
+ # normalize
300
+ if self.mean is not None or self.std is not None:
301
+ normalize(img_lq, self.mean, self.std, inplace=True)
302
+ normalize(img_gt, self.mean, self.std, inplace=True)
303
+
304
+ return {
305
+ 'lq': img_lq,
306
+ 'gt': img_gt,
307
+ 'lq_path': lq_path,
308
+ 'gt_path': gt_path
309
+ }
310
+
311
+ def __len__(self):
312
+ return len(self.paths)
313
+
314
+
315
+ class Dataset_PairedImage(data.Dataset):
316
+ """Paired image dataset for image restoration.
317
+
318
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
319
+ GT image pairs.
320
+
321
+ There are three modes:
322
+ 1. 'lmdb': Use lmdb files.
323
+ If opt['io_backend'] == lmdb.
324
+ 2. 'meta_info_file': Use meta information file to generate paths.
325
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
326
+ 3. 'folder': Scan folders to generate paths.
327
+ The rest.
328
+
329
+ Args:
330
+ opt (dict): Config for train datasets. It contains the following keys:
331
+ dataroot_gt (str): Data root path for gt.
332
+ dataroot_lq (str): Data root path for lq.
333
+ meta_info_file (str): Path for meta information file.
334
+ io_backend (dict): IO backend type and other kwarg.
335
+ filename_tmpl (str): Template for each filename. Note that the
336
+ template excludes the file extension. Default: '{}'.
337
+ gt_size (int): Cropped patched size for gt patches.
338
+ geometric_augs (bool): Use geometric augmentations.
339
+
340
+ scale (bool): Scale, which will be added automatically.
341
+ phase (str): 'train' or 'val'.
342
+ """
343
+
344
+ def __init__(self, opt):
345
+ super(Dataset_PairedImage, self).__init__()
346
+ self.opt = opt
347
+ # file client (io backend)
348
+ self.file_client = None
349
+ self.io_backend_opt = opt['io_backend']
350
+ self.mean = opt['mean'] if 'mean' in opt else None
351
+ self.std = opt['std'] if 'std' in opt else None
352
+
353
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
354
+ if 'filename_tmpl' in opt:
355
+ self.filename_tmpl = opt['filename_tmpl']
356
+ else:
357
+ self.filename_tmpl = '{}'
358
+
359
+ if self.io_backend_opt['type'] == 'lmdb':
360
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
361
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
362
+ self.paths = paired_paths_from_lmdb(
363
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'])
364
+ elif 'meta_info_file' in self.opt and self.opt[
365
+ 'meta_info_file'] is not None:
366
+ self.paths = paired_paths_from_meta_info_file(
367
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
368
+ self.opt['meta_info_file'], self.filename_tmpl)
369
+ else:
370
+ self.paths = paired_paths_from_folder(
371
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
372
+ self.filename_tmpl)
373
+
374
+ if self.opt['phase'] == 'train':
375
+ self.geometric_augs = opt['geometric_augs']
376
+
377
+ def __getitem__(self, index):
378
+ if self.file_client is None:
379
+ self.file_client = FileClient(
380
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
381
+
382
+ scale = self.opt['scale']
383
+ index = index % len(self.paths)
384
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
385
+ # image range: [0, 1], float32.
386
+ gt_path = self.paths[index]['gt_path']
387
+ img_bytes = self.file_client.get(gt_path, 'gt')
388
+ try:
389
+ img_gt = imfrombytes(img_bytes, float32=True)
390
+ except:
391
+ raise Exception("gt path {} not working".format(gt_path))
392
+
393
+ lq_path = self.paths[index]['lq_path']
394
+ img_bytes = self.file_client.get(lq_path, 'lq')
395
+ try:
396
+ img_lq = imfrombytes(img_bytes, float32=True)
397
+ except:
398
+ raise Exception("lq path {} not working".format(lq_path))
399
+
400
+ # augmentation for training
401
+ if self.opt['phase'] == 'train':
402
+ gt_size = self.opt['gt_size']
403
+ # padding
404
+ img_gt, img_lq = padding(img_gt, img_lq, gt_size)
405
+
406
+ # random crop
407
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
408
+ gt_path)
409
+
410
+ # flip, rotation augmentations
411
+ if self.geometric_augs:
412
+ img_gt, img_lq = random_augmentation(img_gt, img_lq)
413
+
414
+
415
+ # BGR to RGB, HWC to CHW, numpy to tensor
416
+ img_gt, img_lq = img2tensor([img_gt, img_lq],
417
+ bgr2rgb=True,
418
+ float32=True)
419
+ # normalize
420
+ if self.mean is not None or self.std is not None:
421
+ normalize(img_lq, self.mean, self.std, inplace=True)
422
+ normalize(img_gt, self.mean, self.std, inplace=True)
423
+
424
+ return {
425
+ 'lq': img_lq,
426
+ 'gt': img_gt,
427
+ 'lq_path': lq_path,
428
+ 'gt_path': gt_path
429
+ }
430
+
431
+ def __len__(self):
432
+ return len(self.paths)
433
+
434
+
435
+ class Dataset_PairedImage_derainSpad(data.Dataset):
436
+ """Paired image dataset for image restoration.
437
+
438
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
439
+ GT image pairs.
440
+
441
+ There are three modes:
442
+ 1. 'lmdb': Use lmdb files.
443
+ If opt['io_backend'] == lmdb.
444
+ 2. 'meta_info_file': Use meta information file to generate paths.
445
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
446
+ 3. 'folder': Scan folders to generate paths.
447
+ The rest.
448
+
449
+ Args:
450
+ opt (dict): Config for train datasets. It contains the following keys:
451
+ dataroot_gt (str): Data root path for gt.
452
+ dataroot_lq (str): Data root path for lq.
453
+ meta_info_file (str): Path for meta information file.
454
+ io_backend (dict): IO backend type and other kwarg.
455
+ filename_tmpl (str): Template for each filename. Note that the
456
+ template excludes the file extension. Default: '{}'.
457
+ gt_size (int): Cropped patched size for gt patches.
458
+ geometric_augs (bool): Use geometric augmentations.
459
+
460
+ scale (bool): Scale, which will be added automatically.
461
+ phase (str): 'train' or 'val'.
462
+ """
463
+
464
+ def __init__(self, opt):
465
+ super(Dataset_PairedImage_derainSpad, self).__init__()
466
+ self.opt = opt
467
+ # file client (io backend)
468
+ self.file_client = None
469
+ self.io_backend_opt = opt['io_backend']
470
+ self.mean = opt['mean'] if 'mean' in opt else None
471
+ self.std = opt['std'] if 'std' in opt else None
472
+
473
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
474
+ if 'filename_tmpl' in opt:
475
+ self.filename_tmpl = opt['filename_tmpl']
476
+ else:
477
+ self.filename_tmpl = '{123}'
478
+
479
+ if self.io_backend_opt['type'] == 'lmdb':
480
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
481
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
482
+ self.paths = paired_paths_from_lmdb(
483
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'])
484
+ elif 'meta_info_file' in self.opt and self.opt[
485
+ 'meta_info_file'] is not None:
486
+ self.paths = paired_paths_from_meta_info_file(
487
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'],
488
+ self.opt['meta_info_file'], self.filename_tmpl)
489
+ else:
490
+ # self.paths = paired_paths_from_folder(
491
+ # [self.lq_folder, self.gt_folder], ['lq', 'gt'],
492
+ # self.filename_tmpl)
493
+ basename = '/home/ubuntu/zsh/datasets/derain'
494
+ name = ''
495
+ if self.opt['phase'] == 'train':
496
+ name = 'real_world.txt'
497
+ else:
498
+ name = 'real_test_1000.txt'
499
+ dataset = os.path.join(basename, name)
500
+ paths = []
501
+ with open(dataset, 'r') as fin:
502
+ for line in fin:
503
+ gt_path = os.path.join(basename, line.split(' ')[1][1:-1])
504
+ input_path = os.path.join(basename, line.split(' ')[0][1:])
505
+ paths.append(
506
+ dict([(f'lq_path', input_path),
507
+ (f'gt_path', gt_path)]))
508
+ self.paths = paths
509
+ # self.paths = [
510
+ # osp.join(self.gt_folder,
511
+ # line.split(' ')[0]) for line in fin
512
+ # ]
513
+
514
+ if self.opt['phase'] == 'train':
515
+ self.geometric_augs = opt['geometric_augs']
516
+
517
+ def __getitem__(self, index):
518
+ if self.file_client is None:
519
+ self.file_client = FileClient(
520
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
521
+
522
+ scale = self.opt['scale']
523
+ index = index % len(self.paths)
524
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
525
+ # image range: [0, 1], float32.
526
+ gt_path = self.paths[index]['gt_path']
527
+ img_bytes = self.file_client.get(gt_path, 'gt')
528
+ try:
529
+ img_gt = imfrombytes(img_bytes, float32=True)
530
+ except:
531
+ raise Exception("gt path {} not working".format(gt_path))
532
+
533
+ lq_path = self.paths[index]['lq_path']
534
+ img_bytes = self.file_client.get(lq_path, 'lq')
535
+ try:
536
+ img_lq = imfrombytes(img_bytes, float32=True)
537
+ except:
538
+ raise Exception("lq path {} not working".format(lq_path))
539
+
540
+ # augmentation for training
541
+ if self.opt['phase'] == 'train':
542
+ gt_size = self.opt['gt_size']
543
+ # padding
544
+ img_gt, img_lq = padding(img_gt, img_lq, gt_size)
545
+
546
+ # random crop
547
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
548
+ gt_path)
549
+
550
+ # flip, rotation augmentations
551
+ if self.geometric_augs:
552
+ img_gt, img_lq = random_augmentation(img_gt, img_lq)
553
+ elif self.opt['phase'] == 'val':
554
+ # print('entering val processing')
555
+
556
+ #centerCrop for validation
557
+ gt_size = self.opt['gt_size']
558
+ img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale,
559
+ gt_path)
560
+ elif self.opt['phase'] == 'test':
561
+ #doingNothing
562
+ print('Test on Full Image')
563
+
564
+
565
+
566
+ # BGR to RGB, HWC to CHW, numpy to tensor
567
+ img_gt, img_lq = img2tensor([img_gt, img_lq],
568
+ bgr2rgb=True,
569
+ float32=True)
570
+ # normalize
571
+ if self.mean is not None or self.std is not None:
572
+ normalize(img_lq, self.mean, self.std, inplace=True)
573
+ normalize(img_gt, self.mean, self.std, inplace=True)
574
+
575
+ return {
576
+ 'lq': img_lq,
577
+ 'gt': img_gt,
578
+ 'lq_path': lq_path,
579
+ 'gt_path': gt_path
580
+ }
581
+
582
+ def __len__(self):
583
+ return len(self.paths)
584
+
585
+ class Dataset_GaussianDenoising(data.Dataset):
586
+ """Paired image dataset for image restoration.
587
+
588
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
589
+ GT image pairs.
590
+
591
+ There are three modes:
592
+ 1. 'lmdb': Use lmdb files.
593
+ If opt['io_backend'] == lmdb.
594
+ 2. 'meta_info_file': Use meta information file to generate paths.
595
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
596
+ 3. 'folder': Scan folders to generate paths.
597
+ The rest.
598
+
599
+ Args:
600
+ opt (dict): Config for train datasets. It contains the following keys:
601
+ dataroot_gt (str): Data root path for gt.
602
+ meta_info_file (str): Path for meta information file.
603
+ io_backend (dict): IO backend type and other kwarg.
604
+ gt_size (int): Cropped patched size for gt patches.
605
+ use_flip (bool): Use horizontal flips.
606
+ use_rot (bool): Use rotation (use vertical flip and transposing h
607
+ and w for implementation).
608
+
609
+ scale (bool): Scale, which will be added automatically.
610
+ phase (str): 'train' or 'val'.
611
+ """
612
+
613
+ def __init__(self, opt):
614
+ super(Dataset_GaussianDenoising, self).__init__()
615
+ self.opt = opt
616
+
617
+ if self.opt['phase'] == 'train':
618
+ self.sigma_type = opt['sigma_type']
619
+ self.sigma_range = opt['sigma_range']
620
+ assert self.sigma_type in ['constant', 'random', 'choice']
621
+ else:
622
+ self.sigma_test = opt['sigma_test']
623
+ self.in_ch = opt['in_ch']
624
+
625
+ # file client (io backend)
626
+ self.file_client = None
627
+ self.io_backend_opt = opt['io_backend']
628
+ self.mean = opt['mean'] if 'mean' in opt else None
629
+ self.std = opt['std'] if 'std' in opt else None
630
+
631
+ self.gt_folder = opt['dataroot_gt']
632
+
633
+ if self.io_backend_opt['type'] == 'lmdb':
634
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
635
+ self.io_backend_opt['client_keys'] = ['gt']
636
+ self.paths = paths_from_lmdb(self.gt_folder)
637
+ elif 'meta_info_file' in self.opt:
638
+ with open(self.opt['meta_info_file'], 'r') as fin:
639
+ self.paths = [
640
+ osp.join(self.gt_folder,
641
+ line.split(' ')[0]) for line in fin
642
+ ]
643
+ else:
644
+ #self.paths = sorted(list(scandir(self.gt_folder, full_path=True)))
645
+ #self.paths = sorted(list(scandir(self.gt_folder)))
646
+ self.paths = list(scandir(self.gt_folder))
647
+ # self.paths = (list(scandir(self.gt_folder, full_path=True)))
648
+
649
+ if self.opt['phase'] == 'train':
650
+            self.geometric_augs = self.opt['geometric_augs']
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        index = index % len(self.paths)
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        # gt_path = self.paths[index]['gt_path']
+        gt_path = self.paths[index].path
+        # gt_path = os.path.join(self.gt_folder, gt_path)
+        img_bytes = self.file_client.get(gt_path, 'gt')
+
+        if self.in_ch == 3:
+            try:
+                img_gt = imfrombytes(img_bytes, float32=True)
+            except Exception:
+                raise Exception(f'gt path {gt_path} not working')
+
+            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
+        else:
+            try:
+                img_gt = imfrombytes(img_bytes, flag='grayscale', float32=True)
+            except Exception:
+                raise Exception(f'gt path {gt_path} not working')
+
+            img_gt = np.expand_dims(img_gt, axis=2)
+        img_lq = img_gt.copy()
+
+        # augmentation for training
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # padding
+            img_gt, img_lq = padding(img_gt, img_lq, gt_size)
+
+            # random crop
+            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size,
+                                                scale, gt_path)
+            # flip, rotation
+            if self.geometric_augs:
+                img_gt, img_lq = random_augmentation(img_gt, img_lq)
+
+            img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                        bgr2rgb=False,
+                                        float32=True)
+
+            if self.sigma_type == 'constant':
+                sigma_value = self.sigma_range
+            elif self.sigma_type == 'random':
+                sigma_value = random.uniform(self.sigma_range[0],
+                                             self.sigma_range[1])
+            elif self.sigma_type == 'choice':
+                sigma_value = random.choice(self.sigma_range)
+
+            noise_level = torch.FloatTensor([sigma_value]) / 255.0
+            # noise_level_map = torch.ones((1, img_lq.size(1), img_lq.size(2))).mul_(noise_level).float()
+            noise = torch.randn(img_lq.size()).mul_(noise_level).float()
+            img_lq.add_(noise)
+
+        else:
+            # change here to update center
+            gt_size = self.opt['gt_size']
+            img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size,
+                                                scale, gt_path)
+
+            np.random.seed(seed=0)
+            img_lq += np.random.normal(0, self.sigma_test / 255.0,
+                                       img_lq.shape)
+            # noise_level_map = torch.ones((1, img_lq.shape[0], img_lq.shape[1])).mul_(self.sigma_test/255.0).float()
+
+            img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                        bgr2rgb=False,
+                                        float32=True)
+
+        return {
+            'lq': img_lq,
+            'gt': img_gt,
+            'lq_path': gt_path,
+            'gt_path': gt_path
+        }
+
+    def __len__(self):
+        return len(self.paths)
+
+
+class Dataset_DefocusDeblur_DualPixel_16bit(data.Dataset):
+    def __init__(self, opt):
+        super(Dataset_DefocusDeblur_DualPixel_16bit, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+
+        self.gt_folder, self.lqL_folder, self.lqR_folder = opt[
+            'dataroot_gt'], opt['dataroot_lqL'], opt['dataroot_lqR']
+        if 'filename_tmpl' in opt:
+            self.filename_tmpl = opt['filename_tmpl']
+        else:
+            self.filename_tmpl = '{}'
+
+        self.paths = paired_DP_paths_from_folder(
+            [self.lqL_folder, self.lqR_folder, self.gt_folder],
+            ['lqL', 'lqR', 'gt'], self.filename_tmpl)
+
+        if self.opt['phase'] == 'train':
+            self.geometric_augs = self.opt['geometric_augs']
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        index = index % len(self.paths)
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        try:
+            img_gt = imfrombytesDP(img_bytes, float32=True)
+        except Exception:
+            raise Exception(f'gt path {gt_path} not working')
+
+        lqL_path = self.paths[index]['lqL_path']
+        img_bytes = self.file_client.get(lqL_path, 'lqL')
+        try:
+            img_lqL = imfrombytesDP(img_bytes, float32=True)
+        except Exception:
+            raise Exception(f'lqL path {lqL_path} not working')
+
+        lqR_path = self.paths[index]['lqR_path']
+        img_bytes = self.file_client.get(lqR_path, 'lqR')
+        try:
+            img_lqR = imfrombytesDP(img_bytes, float32=True)
+        except Exception:
+            raise Exception(f'lqR path {lqR_path} not working')
+
+        # augmentation for training
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # padding
+            img_lqL, img_lqR, img_gt = padding_DP(img_lqL, img_lqR, img_gt,
+                                                  gt_size)
+
+            # random crop
+            img_lqL, img_lqR, img_gt = paired_random_crop_DP(
+                img_lqL, img_lqR, img_gt, gt_size, scale, gt_path)
+
+            # flip, rotation
+            if self.geometric_augs:
+                img_lqL, img_lqR, img_gt = random_augmentation(
+                    img_lqL, img_lqR, img_gt)
+        # TODO: color space transform
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_lqL, img_lqR, img_gt = img2tensor([img_lqL, img_lqR, img_gt],
+                                              bgr2rgb=True,
+                                              float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lqL, self.mean, self.std, inplace=True)
+            normalize(img_lqR, self.mean, self.std, inplace=True)
+            normalize(img_gt, self.mean, self.std, inplace=True)
+
+        img_lq = torch.cat([img_lqL, img_lqR], 0)
+
+        return {
+            'lq': img_lq,
+            'gt': img_gt,
+            'lq_path': lqL_path,
+            'gt_path': gt_path
+        }
+
+    def __len__(self):
+        return len(self.paths)
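The training branch of the denoising dataset above synthesizes its low-quality input by adding white Gaussian noise to the clean image, with the level picked according to `sigma_type`. A minimal, self-contained sketch of that sampling logic; the `sigma_range` values are illustrative, not taken from any shipped config:

import random
import torch

def sample_sigma(sigma_type, sigma_range):
    # Mirrors the three branches above: a fixed level, a uniform draw
    # from an interval, or a random pick from a discrete set.
    if sigma_type == 'constant':
        return sigma_range
    if sigma_type == 'random':
        return random.uniform(sigma_range[0], sigma_range[1])
    if sigma_type == 'choice':
        return random.choice(sigma_range)
    raise ValueError(f'Unknown sigma_type: {sigma_type}')

img_lq = torch.rand(3, 128, 128)               # stand-in for a clean patch
sigma = sample_sigma('choice', [15, 25, 50])   # illustrative levels
noise = torch.randn_like(img_lq) * (sigma / 255.0)
img_lq = img_lq + noise                        # noisy training input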
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,126 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+    """A general prefetch generator.
+
+    Ref:
+    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+    Args:
+        generator: Python generator.
+        num_prefetch_queue (int): Number of prefetch queue.
+    """
+
+    def __init__(self, generator, num_prefetch_queue):
+        threading.Thread.__init__(self)
+        self.queue = Queue.Queue(num_prefetch_queue)
+        self.generator = generator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class PrefetchDataLoader(DataLoader):
+    """Prefetch version of dataloader.
+
+    Ref:
+    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+    TODO:
+    Need to test on single gpu and ddp (multi-gpu). There is a known issue
+    in ddp.
+
+    Args:
+        num_prefetch_queue (int): Number of prefetch queue.
+        kwargs (dict): Other arguments for dataloader.
+    """
+
+    def __init__(self, num_prefetch_queue, **kwargs):
+        self.num_prefetch_queue = num_prefetch_queue
+        super(PrefetchDataLoader, self).__init__(**kwargs)
+
+    def __iter__(self):
+        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+    """CPU prefetcher.
+
+    Args:
+        loader: Dataloader.
+    """
+
+    def __init__(self, loader):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+
+    def next(self):
+        try:
+            return next(self.loader)
+        except StopIteration:
+            return None
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+    """CUDA prefetcher.
+
+    Ref:
+    https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+    Args:
+        loader: Dataloader.
+        opt (dict): Options.
+    """
+
+    def __init__(self, loader, opt):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+        self.opt = opt
+        self.stream = torch.cuda.Stream()
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.preload()
+
+    def preload(self):
+        try:
+            self.batch = next(self.loader)  # self.batch is a dict
+        except StopIteration:
+            self.batch = None
+            return None
+        # put tensors to gpu
+        with torch.cuda.stream(self.stream):
+            for k, v in self.batch.items():
+                if torch.is_tensor(v):
+                    self.batch[k] = self.batch[k].to(
+                        device=self.device, non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        self.preload()
+        return batch
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+        self.preload()
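The three prefetcher classes above share a `next()`/`reset()` interface. A minimal usage sketch, assuming the classes defined in this file are importable; the toy DataLoader is illustrative, and `CUDAPrefetcher` additionally expects dict batches and a CUDA device:

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.rand(8, 3, 16, 16)), batch_size=4)
prefetcher = CPUPrefetcher(loader)   # or CUDAPrefetcher(loader, opt) on GPU

batch = prefetcher.next()            # returns None when the epoch ends
while batch is not None:
    # ... run the training step on `batch` here ...
    batch = prefetcher.next()
prefetcher.reset()                   # rewind for the next epoch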
basicsr/data/reds_dataset.py ADDED
@@ -0,0 +1,237 @@
+import numpy as np
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.flow_util import dequantize_flow
+
+
+class REDSDataset(data.Dataset):
+    """REDS dataset for training.
+
+    The keys are generated from a meta info txt file.
+    basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+    Each line contains:
+    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+    a white space.
+    Examples:
+    000 100 (720,1280,3)
+    001 100 (720,1280,3)
+    ...
+
+    Key examples: "000/00000000"
+    GT (gt): Ground-Truth;
+    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            dataroot_flow (str, optional): Data root path for flow.
+            meta_info_file (str): Path for meta information file.
+            val_partition (str): Validation partition types. 'REDS4' or
+                'official'.
+            io_backend (dict): IO backend type and other kwarg.
+
+            num_frame (int): Window size for input frames.
+            gt_size (int): Cropped patch size for gt patches.
+            interval_list (list): Interval list for temporal augmentation.
+            random_reverse (bool): Random reverse input frames.
+            use_flip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h
+                and w for implementation).
+
+            scale (bool): Scale, which will be added automatically.
+    """
+
+    def __init__(self, opt):
+        super(REDSDataset, self).__init__()
+        self.opt = opt
+        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
+            opt['dataroot_lq'])
+        self.flow_root = Path(
+            opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
+        assert opt['num_frame'] % 2 == 1, (
+            f'num_frame should be odd number, but got {opt["num_frame"]}')
+        self.num_frame = opt['num_frame']
+        self.num_half_frames = opt['num_frame'] // 2
+
+        self.keys = []
+        with open(opt['meta_info_file'], 'r') as fin:
+            for line in fin:
+                folder, frame_num, _ = line.split(' ')
+                self.keys.extend(
+                    [f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+        # remove the video clips used in validation
+        if opt['val_partition'] == 'REDS4':
+            val_partition = ['000', '011', '015', '020']
+        elif opt['val_partition'] == 'official':
+            val_partition = [f'{v:03d}' for v in range(240, 270)]
+        else:
+            raise ValueError(
+                f'Wrong validation partition {opt["val_partition"]}. '
+                f"Supported ones are ['official', 'REDS4'].")
+        self.keys = [
+            v for v in self.keys if v.split('/')[0] not in val_partition
+        ]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.is_lmdb = False
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.is_lmdb = True
+            if self.flow_root is not None:
+                self.io_backend_opt['db_paths'] = [
+                    self.lq_root, self.gt_root, self.flow_root
+                ]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+            else:
+                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+        # temporal augmentation configs
+        self.interval_list = opt['interval_list']
+        self.random_reverse = opt['random_reverse']
+        interval_str = ','.join(str(x) for x in opt['interval_list'])
+        logger = get_root_logger()
+        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+                    f'random reverse is {self.random_reverse}.')
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip_name, frame_name = key.split('/')  # key example: 000/00000000
+        center_frame_idx = int(frame_name)
+
+        # determine the neighboring frames
+        interval = random.choice(self.interval_list)
+
+        # ensure not exceeding the borders
+        start_frame_idx = center_frame_idx - self.num_half_frames * interval
+        end_frame_idx = center_frame_idx + self.num_half_frames * interval
+        # each clip has 100 frames starting from 0 to 99
+        while (start_frame_idx < 0) or (end_frame_idx > 99):
+            center_frame_idx = random.randint(0, 99)
+            start_frame_idx = (
+                center_frame_idx - self.num_half_frames * interval)
+            end_frame_idx = center_frame_idx + self.num_half_frames * interval
+        frame_name = f'{center_frame_idx:08d}'
+        neighbor_list = list(
+            range(center_frame_idx - self.num_half_frames * interval,
+                  center_frame_idx + self.num_half_frames * interval + 1,
+                  interval))
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            neighbor_list.reverse()
+
+        assert len(neighbor_list) == self.num_frame, (
+            f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+        # get the GT frame (as the center frame)
+        if self.is_lmdb:
+            img_gt_path = f'{clip_name}/{frame_name}'
+        else:
+            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
+        img_bytes = self.file_client.get(img_gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # get the neighboring LQ frames
+        img_lqs = []
+        for neighbor in neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip_name}/{neighbor:08d}'
+            else:
+                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+        # get flows
+        if self.flow_root is not None:
+            img_flows = []
+            # read previous flows
+            for i in range(self.num_half_frames, 0, -1):
+                if self.is_lmdb:
+                    flow_path = f'{clip_name}/{frame_name}_p{i}'
+                else:
+                    flow_path = (
+                        self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+                img_bytes = self.file_client.get(flow_path, 'flow')
+                cat_flow = imfrombytes(
+                    img_bytes, flag='grayscale',
+                    float32=False)  # uint8, [0, 255]
+                dx, dy = np.split(cat_flow, 2, axis=0)
+                flow = dequantize_flow(
+                    dx, dy, max_val=20,
+                    denorm=False)  # we use max_val 20 here.
+                img_flows.append(flow)
+            # read next flows
+            for i in range(1, self.num_half_frames + 1):
+                if self.is_lmdb:
+                    flow_path = f'{clip_name}/{frame_name}_n{i}'
+                else:
+                    flow_path = (
+                        self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+                img_bytes = self.file_client.get(flow_path, 'flow')
+                cat_flow = imfrombytes(
+                    img_bytes, flag='grayscale',
+                    float32=False)  # uint8, [0, 255]
+                dx, dy = np.split(cat_flow, 2, axis=0)
+                flow = dequantize_flow(
+                    dx, dy, max_val=20,
+                    denorm=False)  # we use max_val 20 here.
+                img_flows.append(flow)
+
+            # for random crop, here, img_flows and img_lqs have the same
+            # spatial size
+            img_lqs.extend(img_flows)
+
+        # randomly crop
+        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
+                                             img_gt_path)
+        if self.flow_root is not None:
+            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[
+                self.num_frame:]
+
+        # augmentation - flip, rotate
+        img_lqs.append(img_gt)
+        if self.flow_root is not None:
+            img_results, img_flows = augment(img_lqs, self.opt['use_flip'],
+                                             self.opt['use_rot'], img_flows)
+        else:
+            img_results = augment(img_lqs, self.opt['use_flip'],
+                                  self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_lqs = torch.stack(img_results[0:-1], dim=0)
+        img_gt = img_results[-1]
+
+        if self.flow_root is not None:
+            img_flows = img2tensor(img_flows)
+            # add the zero center flow
+            img_flows.insert(self.num_half_frames,
+                             torch.zeros_like(img_flows[0]))
+            img_flows = torch.stack(img_flows, dim=0)
+
+        # img_lqs: (t, c, h, w)
+        # img_flows: (t, 2, h, w)
+        # img_gt: (c, h, w)
+        # key: str
+        if self.flow_root is not None:
+            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt,
+                    'key': key}
+        else:
+            return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
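The index arithmetic in `__getitem__` above is the heart of the temporal augmentation: pick a sampling interval, then resample the center frame until the whole window fits inside the 100-frame clip. A standalone sketch of that logic with illustrative values:

import random

num_frame, interval = 5, 2   # illustrative window size and interval
num_half = num_frame // 2
center = 3                   # deliberately too close to the clip start
# Resample exactly as the dataset does until the window fits in 0..99.
while center - num_half * interval < 0 or center + num_half * interval > 99:
    center = random.randint(0, 99)
neighbors = list(range(center - num_half * interval,
                       center + num_half * interval + 1, interval))
assert len(neighbors) == num_frame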
basicsr/data/single_image_dataset.py ADDED
@@ -0,0 +1,67 @@
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paths_from_lmdb
+from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
+
+
+class SingleImageDataset(data.Dataset):
+    """Read only lq images in the test phase.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
+
+    There are two modes:
+    1. 'meta_info_file': Use meta information file to generate paths.
+    2. 'folder': Scan folders to generate paths.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_lq (str): Data root path for lq.
+            meta_info_file (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+    """
+
+    def __init__(self, opt):
+        super(SingleImageDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+        self.lq_folder = opt['dataroot_lq']
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = [self.lq_folder]
+            self.io_backend_opt['client_keys'] = ['lq']
+            self.paths = paths_from_lmdb(self.lq_folder)
+        elif 'meta_info_file' in self.opt:
+            with open(self.opt['meta_info_file'], 'r') as fin:
+                self.paths = [
+                    osp.join(self.lq_folder,
+                             line.split(' ')[0]) for line in fin
+                ]
+        else:
+            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load lq image
+        lq_path = self.paths[index]
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        # TODO: color space transform
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+        return {'lq': img_lq, 'lq_path': lq_path}
+
+    def __len__(self):
+        return len(self.paths)
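A hedged usage sketch for `SingleImageDataset` in 'folder' mode, assuming the class above and the basicsr utilities are importable; the `dataroot_lq` path is a placeholder and must point at an existing folder of images:

opt = {
    'io_backend': {'type': 'disk'},
    'dataroot_lq': '/path/to/lq_images',  # hypothetical location
}
dataset = SingleImageDataset(opt)
sample = dataset[0]                 # {'lq': CHW float tensor, 'lq_path': str}
print(sample['lq'].shape, sample['lq_path'])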
basicsr/data/transforms.py ADDED
@@ -0,0 +1,480 @@
+import cv2
+import random
+import numpy as np
+from PIL import Image
+
+
+def mod_crop(img, scale):
+    """Mod crop images, used during testing.
+
+    Args:
+        img (ndarray): Input image.
+        scale (int): Scale factor.
+
+    Returns:
+        ndarray: Result image.
+    """
+    img = img.copy()
+    if img.ndim in (2, 3):
+        h, w = img.shape[0], img.shape[1]
+        h_remainder, w_remainder = h % scale, w % scale
+        img = img[:h - h_remainder, :w - w_remainder, ...]
+    else:
+        raise ValueError(f'Wrong img ndim: {img.ndim}.')
+    return img
+
+
+def paired_random_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path):
+    """Paired random crop.
+
+    It crops lists of lq and gt images with corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        lq_patch_size (int): LQ patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+        only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    h_lq, w_lq, _ = img_lqs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    gt_patch_size = int(lq_patch_size * scale)
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(
+            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+            f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqs = [
+        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+        for v in img_lqs
+    ]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [
+        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
+        for v in img_gts
+    ]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+
+def paired_center_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path):
+    """Paired center crop.
+
+    It crops lists of lq and gt images at the image center, with
+    corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        lq_patch_size (int): LQ patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+        only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    h_lq, w_lq, _ = img_lqs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    gt_patch_size = int(lq_patch_size * scale)
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(
+            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+            f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # choose the center top and left coordinates for the lq patch
+    top = (h_lq - lq_patch_size) // 2    # was: random.randint(0, h_lq - lq_patch_size)
+    left = (w_lq - lq_patch_size) // 2   # was: random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqs = [
+        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+        for v in img_lqs
+    ]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [
+        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
+        for v in img_gts
+    ]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+
+def paired_random_crop_DP(img_lqLs, img_lqRs, img_gts, gt_patch_size, scale,
+                          gt_path):
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqLs, list):
+        img_lqLs = [img_lqLs]
+    if not isinstance(img_lqRs, list):
+        img_lqRs = [img_lqRs]
+
+    h_lq, w_lq, _ = img_lqLs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(
+            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+            f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqLs = [
+        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+        for v in img_lqLs
+    ]
+
+    img_lqRs = [
+        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+        for v in img_lqRs
+    ]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [
+        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
+        for v in img_gts
+    ]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqLs) == 1:
+        img_lqLs = img_lqLs[0]
+    if len(img_lqRs) == 1:
+        img_lqRs = img_lqRs[0]
+    return img_lqLs, img_lqRs, img_gts
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+    We use vertical flip and transpose for rotation implementation.
+    All the images in the list use the same augmentation.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+        hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+            ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+        return_status (bool): Return the status of flip and rotation.
+            Default: False.
+
+    Returns:
+        list[ndarray] | ndarray: Augmented images and flows. If returned
+        results only have one element, just return ndarray.
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        if return_status:
+            return imgs, (hflip, vflip, rot90)
+        else:
+            return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+    """Rotate image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees. Positive values mean
+            counter-clockwise rotation.
+        center (tuple[int]): Rotation center. If the center is None,
+            initialize it as the center of the image. Default: None.
+        scale (float): Isotropic scale factor. Default: 1.0.
+    """
+    (h, w) = img.shape[:2]
+
+    if center is None:
+        center = (w // 2, h // 2)
+
+    matrix = cv2.getRotationMatrix2D(center, angle, scale)
+    rotated_img = cv2.warpAffine(img, matrix, (w, h))
+    return rotated_img
+
+
+def data_augmentation(image, mode):
+    """
+    Performs data augmentation of the input image.
+    Input:
+        image: a cv2 (OpenCV) image
+        mode: int. Choice of transformation to apply to the image
+            0 - no transformation
+            1 - flip up and down
+            2 - rotate counterclockwise 90 degrees
+            3 - rotate 90 degrees and flip up and down
+            4 - rotate 180 degrees
+            5 - rotate 180 degrees and flip
+            6 - rotate 270 degrees
+            7 - rotate 270 degrees and flip
+    """
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        # flip up and down
+        out = np.flipud(image)
+    elif mode == 2:
+        # rotate counterclockwise 90 degrees
+        out = np.rot90(image)
+    elif mode == 3:
+        # rotate 90 degrees and flip up and down
+        out = np.rot90(image)
+        out = np.flipud(out)
+    elif mode == 4:
+        # rotate 180 degrees
+        out = np.rot90(image, k=2)
+    elif mode == 5:
+        # rotate 180 degrees and flip
+        out = np.rot90(image, k=2)
+        out = np.flipud(out)
+    elif mode == 6:
+        # rotate 270 degrees
+        out = np.rot90(image, k=3)
+    elif mode == 7:
+        # rotate 270 degrees and flip
+        out = np.rot90(image, k=3)
+        out = np.flipud(out)
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out
+
+
+def random_augmentation(*args):
+    out = []
+    flag_aug = random.randint(0, 7)
+    for data in args:
+        out.append(data_augmentation(data, flag_aug).copy())
+    return out
+
+
+# The two `*_tip18` variants below were left commented out in the source;
+# they are kept for reference. Note that the list comprehension in
+# paired_random_crop_tip18 was never finished and is not valid Python.
+
+# def paired_random_crop_tip18(img_gts, img_lqs, lq_patch_size, scale,
+#                              gt_path):
+#     """Paired random crop.
+#
+#     It crops lists of lq and gt images with corresponding locations.
+#     (Args and Returns as in paired_random_crop above.)
+#     """
+#
+#     if not isinstance(img_gts, list):
+#         img_gts = [img_gts]
+#     if not isinstance(img_lqs, list):
+#         img_lqs = [img_lqs]
+#
+#     h_lq, w_lq, _ = img_lqs[0].shape
+#     h_gt, w_gt, _ = img_gts[0].shape
+#     gt_patch_size = int(lq_patch_size * scale)
+#
+#     if h_gt != h_lq * scale or w_gt != w_lq * scale:
+#         raise ValueError(
+#             f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+#             f'multiplication of LQ ({h_lq}, {w_lq}).')
+#     if h_lq < lq_patch_size or w_lq < lq_patch_size:
+#         raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+#                          f'({lq_patch_size}, {lq_patch_size}). '
+#                          f'Please remove {gt_path}.')
+#
+#     # pre process (unfinished)
+#     # w, h = img.size
+#     # region = img.crop((1 + int(0.15 * w), 1 + int(0.15 * h),
+#     #                    int(0.85 * w), int(0.85 * h)))
+#     # region = region.resize((286, 286), Image.BILINEAR)
+#     # crop lq patch
+#     w = w_lq, h = h_lq
+#     img_lqs = [
+#         # v[(1 + int(0.15 * h)):int(0.85 * h), (1 + int(0.15 * w)):int(0.85 * w), ...]
+#         for v in img_lqs:
+#             img = Image.fromarray(v[(1 + int(0.15 * h)):int(0.85 * h),
+#                                     (1 + int(0.15 * w)):int(0.85 * w), ...])
+#             img = img.resize((286, 286), Image.BILINEAR)
+#     ]
+#     img_gts = [
+#         v[(1 + int(0.15 * h)):int(0.85 * h),
+#           (1 + int(0.15 * w)):int(0.85 * w), ...]
+#         for v in img_gts
+#     ]
+#
+#     # randomly choose top and left coordinates for lq patch
+#     top = random.randint(0, h_lq - lq_patch_size)
+#     left = random.randint(0, w_lq - lq_patch_size)
+#
+#     # crop lq patch
+#     img_lqs = [
+#         v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+#         for v in img_lqs
+#     ]
+#
+#     # crop corresponding gt patch
+#     top_gt, left_gt = int(top * scale), int(left * scale)
+#     img_gts = [
+#         v[top_gt:top_gt + gt_patch_size,
+#           left_gt:left_gt + gt_patch_size, ...]
+#         for v in img_gts
+#     ]
+#     if len(img_gts) == 1:
+#         img_gts = img_gts[0]
+#     if len(img_lqs) == 1:
+#         img_lqs = img_lqs[0]
+#     return img_gts, img_lqs
+
+# def paired_center_crop_tip18(img_gts, img_lqs, lq_patch_size, scale,
+#                              gt_path):
+#     """Paired center crop (commented-out duplicate of paired_center_crop).
+#
+#     (Args and Returns as in paired_center_crop above.)
+#     """
+#
+#     if not isinstance(img_gts, list):
+#         img_gts = [img_gts]
+#     if not isinstance(img_lqs, list):
+#         img_lqs = [img_lqs]
+#
+#     h_lq, w_lq, _ = img_lqs[0].shape
+#     h_gt, w_gt, _ = img_gts[0].shape
+#     gt_patch_size = int(lq_patch_size * scale)
+#
+#     if h_gt != h_lq * scale or w_gt != w_lq * scale:
+#         raise ValueError(
+#             f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+#             f'multiplication of LQ ({h_lq}, {w_lq}).')
+#     if h_lq < lq_patch_size or w_lq < lq_patch_size:
+#         raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+#                          f'({lq_patch_size}, {lq_patch_size}). '
+#                          f'Please remove {gt_path}.')
+#
+#     # choose the center top and left coordinates for the lq patch
+#     top = (h_lq - lq_patch_size) // 2
+#     left = (w_lq - lq_patch_size) // 2
+#
+#     # crop lq patch
+#     img_lqs = [
+#         v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+#         for v in img_lqs
+#     ]
+#
+#     # crop corresponding gt patch
+#     top_gt, left_gt = int(top * scale), int(left * scale)
+#     img_gts = [
+#         v[top_gt:top_gt + gt_patch_size,
+#           left_gt:left_gt + gt_patch_size, ...]
+#         for v in img_gts
+#     ]
+#     if len(img_gts) == 1:
+#         img_gts = img_gts[0]
+#     if len(img_lqs) == 1:
+#         img_lqs = img_lqs[0]
+#     return img_gts, img_lqs
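The eight modes of `data_augmentation` above are exactly the eight symmetries of a square (identity, flips, and 90/180/270-degree rotations), and `random_augmentation` applies one shared mode to all of its arguments. A quick standalone check, assuming the two functions above are in scope:

import numpy as np

img = np.arange(12, dtype=np.float32).reshape(3, 4)
for mode in range(8):
    out = data_augmentation(img, mode)
    print(mode, out.shape)   # shapes alternate between (3, 4) and (4, 3)

# random_augmentation picks one mode and applies it to every input,
# so paired arrays stay geometrically aligned:
a_aug, b_aug = random_augmentation(img, img * 2)
assert a_aug.shape == b_aug.shape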
basicsr/data/video_test_dataset.py ADDED
@@ -0,0 +1,325 @@
+import glob
+import torch
+from os import path as osp
+from torch.utils import data as data
+
+from basicsr.data.data_util import (duf_downsample, generate_frame_indices,
+                                    read_img_seq)
+from basicsr.utils import get_root_logger, scandir
+
+
+class VideoTestDataset(data.Dataset):
+    """Video test dataset.
+
+    Supported datasets: Vid4, REDS4, REDSofficial.
+    More generally, it supports testing datasets with the following structure:
+
+    dataroot
+    ├── subfolder1
+        ├── frame000
+        ├── frame001
+        ├── ...
+    ├── subfolder2
+        ├── frame000
+        ├── frame001
+        ├── ...
+    ├── ...
+
+    For testing datasets, there is no need to prepare LMDB files.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            io_backend (dict): IO backend type and other kwarg.
+            cache_data (bool): Whether to cache testing datasets.
+            name (str): Dataset name.
+            meta_info_file (str): The path to the file storing the list of
+                test folders. If not provided, all the folders in the
+                dataroot will be used.
+            num_frame (int): Window size for input frames.
+            padding (str): Padding mode.
+    """
+
+    def __init__(self, opt):
+        super(VideoTestDataset, self).__init__()
+        self.opt = opt
+        self.cache_data = opt['cache_data']
+        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+        self.data_info = {
+            'lq_path': [],
+            'gt_path': [],
+            'folder': [],
+            'idx': [],
+            'border': []
+        }
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        assert self.io_backend_opt[
+            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+        logger = get_root_logger()
+        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+        self.imgs_lq, self.imgs_gt = {}, {}
+        if 'meta_info_file' in opt:
+            with open(opt['meta_info_file'], 'r') as fin:
+                subfolders = [line.split(' ')[0] for line in fin]
+                subfolders_lq = [
+                    osp.join(self.lq_root, key) for key in subfolders
+                ]
+                subfolders_gt = [
+                    osp.join(self.gt_root, key) for key in subfolders
+                ]
+        else:
+            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
+            for subfolder_lq, subfolder_gt in zip(subfolders_lq,
+                                                  subfolders_gt):
+                # get frame list for lq and gt
+                subfolder_name = osp.basename(subfolder_lq)
+                img_paths_lq = sorted(
+                    list(scandir(subfolder_lq, full_path=True)))
+                img_paths_gt = sorted(
+                    list(scandir(subfolder_gt, full_path=True)))
+
+                max_idx = len(img_paths_lq)
+                assert max_idx == len(img_paths_gt), (
+                    f'Different number of images in lq ({max_idx})'
+                    f' and gt folders ({len(img_paths_gt)})')
+
+                self.data_info['lq_path'].extend(img_paths_lq)
+                self.data_info['gt_path'].extend(img_paths_gt)
+                self.data_info['folder'].extend([subfolder_name] * max_idx)
+                for i in range(max_idx):
+                    self.data_info['idx'].append(f'{i}/{max_idx}')
+                border_l = [0] * max_idx
+                for i in range(self.opt['num_frame'] // 2):
+                    border_l[i] = 1
+                    border_l[max_idx - i - 1] = 1
+                self.data_info['border'].extend(border_l)
+
+                # cache data or save the frame list
+                if self.cache_data:
+                    logger.info(
+                        f'Cache {subfolder_name} for VideoTestDataset...')
+                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+                else:
+                    self.imgs_lq[subfolder_name] = img_paths_lq
+                    self.imgs_gt[subfolder_name] = img_paths_gt
+        else:
+            raise ValueError(
+                f'Non-supported video test dataset: {type(opt["name"])}')
+
+    def __getitem__(self, index):
+        folder = self.data_info['folder'][index]
+        idx, max_idx = self.data_info['idx'][index].split('/')
+        idx, max_idx = int(idx), int(max_idx)
+        border = self.data_info['border'][index]
+        lq_path = self.data_info['lq_path'][index]
+
+        select_idx = generate_frame_indices(
+            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+        if self.cache_data:
+            imgs_lq = self.imgs_lq[folder].index_select(
+                0, torch.LongTensor(select_idx))
+            img_gt = self.imgs_gt[folder][idx]
+        else:
+            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+            imgs_lq = read_img_seq(img_paths_lq)
+            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+            img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': folder,  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/99
+            'border': border,  # 1 for border, 0 for non-border
+            'lq_path': lq_path  # center frame
+        }
+
+    def __len__(self):
+        return len(self.data_info['gt_path'])
+
+
+class VideoTestVimeo90KDataset(data.Dataset):
+    """Video test dataset for Vimeo90k-Test dataset.
+
+    It only keeps the center frame for testing.
+    For testing datasets, there is no need to prepare LMDB files.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            io_backend (dict): IO backend type and other kwarg.
+            cache_data (bool): Whether to cache testing datasets.
+            name (str): Dataset name.
+            meta_info_file (str): The path to the file storing the list of
+                test folders. If not provided, all the folders in the
+                dataroot will be used.
+            num_frame (int): Window size for input frames.
+            padding (str): Padding mode.
+    """
+
+    def __init__(self, opt):
+        super(VideoTestVimeo90KDataset, self).__init__()
+        self.opt = opt
+        self.cache_data = opt['cache_data']
+        if self.cache_data:
+            raise NotImplementedError(
+                'cache_data in Vimeo90K-Test dataset is not implemented.')
+        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+        self.data_info = {
+            'lq_path': [],
+            'gt_path': [],
+            'folder': [],
+            'idx': [],
+            'border': []
+        }
+        neighbor_list = [
+            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
+        ]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        assert self.io_backend_opt[
+            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+        logger = get_root_logger()
+        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+        with open(opt['meta_info_file'], 'r') as fin:
+            subfolders = [line.split(' ')[0] for line in fin]
+        for idx, subfolder in enumerate(subfolders):
+            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
+            self.data_info['gt_path'].append(gt_path)
+            lq_paths = [
+                osp.join(self.lq_root, subfolder, f'im{i}.png')
+                for i in neighbor_list
+            ]
+            self.data_info['lq_path'].append(lq_paths)
+            self.data_info['folder'].append('vimeo90k')
+            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
+            self.data_info['border'].append(0)
+
+    def __getitem__(self, index):
+        lq_path = self.data_info['lq_path'][index]
+        gt_path = self.data_info['gt_path'][index]
+        imgs_lq = read_img_seq(lq_path)
+        img_gt = read_img_seq([gt_path])
+        img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': self.data_info['folder'][index],  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/843
+            'border': self.data_info['border'][index],  # 0 for non-border
+            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
+        }
+
+    def __len__(self):
+        return len(self.data_info['gt_path'])
+
+
+class VideoTestDUFDataset(VideoTestDataset):
+    """Video test dataset for DUF dataset.
+
+    Args:
+        opt (dict): Config for train dataset.
+            Most of keys are the same as VideoTestDataset.
+            It has the following extra keys:
+
+            use_duf_downsampling (bool): Whether to use duf downsampling to
+                generate low-resolution frames.
+            scale (bool): Scale, which will be added automatically.
+    """
+
+    def __getitem__(self, index):
+        folder = self.data_info['folder'][index]
+        idx, max_idx = self.data_info['idx'][index].split('/')
+        idx, max_idx = int(idx), int(max_idx)
+        border = self.data_info['border'][index]
+        lq_path = self.data_info['lq_path'][index]
+
+        select_idx = generate_frame_indices(
+            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+        if self.cache_data:
+            if self.opt['use_duf_downsampling']:
+                # read imgs_gt to generate low-resolution frames
+                imgs_lq = self.imgs_gt[folder].index_select(
+                    0, torch.LongTensor(select_idx))
+                imgs_lq = duf_downsample(
+                    imgs_lq, kernel_size=13, scale=self.opt['scale'])
+            else:
+                imgs_lq = self.imgs_lq[folder].index_select(
+                    0, torch.LongTensor(select_idx))
+            img_gt = self.imgs_gt[folder][idx]
+        else:
+            if self.opt['use_duf_downsampling']:
+                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
+                # read imgs_gt to generate low-resolution frames
+                imgs_lq = read_img_seq(
+                    img_paths_lq,
+                    require_mod_crop=True,
+                    scale=self.opt['scale'])
+                imgs_lq = duf_downsample(
+                    imgs_lq, kernel_size=13, scale=self.opt['scale'])
+            else:
+                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+                imgs_lq = read_img_seq(img_paths_lq)
+            img_gt = read_img_seq([self.imgs_gt[folder][idx]],
+                                  require_mod_crop=True,
+                                  scale=self.opt['scale'])
+            img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': folder,  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/99
+            'border': border,  # 1 for border, 0 for non-border
+            'lq_path': lq_path  # center frame
+        }
+
+
+class VideoRecurrentTestDataset(VideoTestDataset):
+    """Video test dataset for recurrent architectures, which takes LR video
+    frames as input and outputs corresponding HR video frames.
+
+    Args:
+        Same as VideoTestDataset.
+        Unused opt:
+            padding (str): Padding mode.
+    """
+
+    def __init__(self, opt):
+        super(VideoRecurrentTestDataset, self).__init__(opt)
+        # Find unique folder strings
+        self.folders = sorted(list(set(self.data_info['folder'])))
+
+    def __getitem__(self, index):
+        folder = self.folders[index]
+
+        if self.cache_data:
+            imgs_lq = self.imgs_lq[folder]
+            imgs_gt = self.imgs_gt[folder]
+        else:
+            raise NotImplementedError('Without cache_data is not implemented.')
+
+        return {
+            'lq': imgs_lq,
+            'gt': imgs_gt,
+            'folder': folder,
+        }
+
+    def __len__(self):
+        return len(self.folders)
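The `border` bookkeeping in VideoTestDataset above flags frames whose temporal window would run past either end of a clip, so evaluation code can treat them separately. A standalone sketch of that loop with illustrative sizes:

# Frames within num_frame // 2 of either clip end are flagged 1.
num_frame, max_idx = 7, 10   # illustrative window and clip length
border = [0] * max_idx
for i in range(num_frame // 2):
    border[i] = 1
    border[max_idx - i - 1] = 1
print(border)  # [1, 1, 1, 0, 0, 0, 0, 1, 1, 1]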
basicsr/data/vimeo90k_dataset.py ADDED
@@ -0,0 +1,130 @@
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+
+
+class Vimeo90KDataset(data.Dataset):
+    """Vimeo90K dataset for training.
+
+    The keys are generated from a meta info txt file.
+    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
+
+    Each line contains:
+    1. clip name; 2. frame number; 3. image shape, separated by a white space.
+    Examples:
+        00001/0001 7 (256,448,3)
+        00001/0002 7 (256,448,3)
+
+    Key examples: "00001/0001"
+    GT (gt): Ground-Truth;
+    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+    The neighboring frame list for different num_frame:
+        num_frame | frame list
+            1     | 4
+            3     | 3,4,5
+            5     | 2,3,4,5,6
+            7     | 1,2,3,4,5,6,7
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            meta_info_file (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+
+            num_frame (int): Window size for input frames.
+            gt_size (int): Cropped patch size for gt patches.
+            random_reverse (bool): Random reverse input frames.
+            use_flip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h
+                and w for implementation).
+
+            scale (bool): Scale, which will be added automatically.
+    """
+
+    def __init__(self, opt):
+        super(Vimeo90KDataset, self).__init__()
+        self.opt = opt
+        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
+            opt['dataroot_lq'])
+
+        with open(opt['meta_info_file'], 'r') as fin:
+            self.keys = [line.split(' ')[0] for line in fin]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.is_lmdb = False
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.is_lmdb = True
+            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+            self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+        # indices of input images
+        self.neighbor_list = [
+            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
+        ]
+
+        # temporal augmentation configs
+        self.random_reverse = opt['random_reverse']
+        logger = get_root_logger()
+        logger.info(f'Random reverse is {self.random_reverse}.')
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            self.neighbor_list.reverse()
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip, seq = key.split('/')  # key example: 00001/0001
+
+        # get the GT frame (im4.png)
+        if self.is_lmdb:
+            img_gt_path = f'{key}/im4'
+        else:
+            img_gt_path = self.gt_root / clip / seq / 'im4.png'
+        img_bytes = self.file_client.get(img_gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # get the neighboring LQ frames
+        img_lqs = []
+        for neighbor in self.neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip}/{seq}/im{neighbor}'
+            else:
+                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+        # randomly crop
+        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
+                                             img_gt_path)
+
+        # augmentation - flip, rotate
+        img_lqs.append(img_gt)
+        img_results = augment(img_lqs, self.opt['use_flip'],
+                              self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_lqs = torch.stack(img_results[0:-1], dim=0)
+        img_gt = img_results[-1]
+
+        # img_lqs: (t, c, h, w)
+        # img_gt: (c, h, w)
+        # key: str
+        return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
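The `neighbor_list` formula above, centered on `im4`, reproduces the table in the docstring; a one-liner check:

for num_frame in (1, 3, 5, 7):
    neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)]
    print(num_frame, neighbor_list)
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]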
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .niqe import calculate_niqe
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
basicsr/metrics/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (311 Bytes). View file
 
basicsr/metrics/__pycache__/metric_util.cpython-37.pyc ADDED
Binary file (1.5 kB). View file
 
basicsr/metrics/__pycache__/niqe.cpython-37.pyc ADDED
Binary file (6.46 kB). View file
 
basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc ADDED
Binary file (7.67 kB). View file
 
basicsr/metrics/fid.py ADDED
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy import linalg
+from tqdm import tqdm
+
+from basicsr.models.archs.inception import InceptionV3
+
+
+def load_patched_inception_v3(device='cuda',
+                              resize_input=True,
+                              normalize_input=False):
+    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
+    # does resize the input.
+    inception = InceptionV3([3],
+                            resize_input=resize_input,
+                            normalize_input=normalize_input)
+    inception = nn.DataParallel(inception).eval().to(device)
+    return inception
+
+
+@torch.no_grad()
+def extract_inception_features(data_generator,
+                               inception,
+                               len_generator=None,
+                               device='cuda'):
+    """Extract inception features.
+
+    Args:
+        data_generator (generator): A data generator.
+        inception (nn.Module): Inception model.
+        len_generator (int): Length of the data_generator to show the
+            progressbar. Default: None.
+        device (str): Device. Default: cuda.
+
+    Returns:
+        Tensor: Extracted features.
+    """
+    if len_generator is not None:
+        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
+    else:
+        pbar = None
+    features = []
+
+    for data in data_generator:
+        if pbar:
+            pbar.update(1)
+        data = data.to(device)
+        feature = inception(data)[0].view(data.shape[0], -1)
+        features.append(feature.to('cpu'))
+    if pbar:
+        pbar.close()
+    features = torch.cat(features, 0)
+    return features
+
+
+def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+    Stable version by Dougal J. Sutherland.
+
+    Args:
+        mu1 (np.array): The sample mean over activations.
+        sigma1 (np.array): The covariance matrix over activations for
+            generated samples.
+        mu2 (np.array): The sample mean over activations, precalculated on a
+            representative data set.
+        sigma2 (np.array): The covariance matrix over activations,
+            precalculated on a representative data set.
+
+    Returns:
+        float: The Frechet Distance.
+    """
+    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, (
+        'Two covariances have different dimensions')
+
+    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
+
+    # Product might be almost singular
+    if not np.isfinite(cov_sqrt).all():
+        print(f'Product of cov matrices is singular. Adding {eps} to the '
+              'diagonal of cov estimates')
+        offset = np.eye(sigma1.shape[0]) * eps
+        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(cov_sqrt):
+        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
+            m = np.max(np.abs(cov_sqrt.imag))
+            raise ValueError(f'Imaginary component {m}')
+        cov_sqrt = cov_sqrt.real
+
+    mean_diff = mu1 - mu2
+    mean_norm = mean_diff @ mean_diff
+    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
+    fid = mean_norm + trace
+
+    return fid
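A sketch of the full FID pipeline around `calculate_fid`: reduce two feature sets to Gaussian statistics with `np.mean`/`np.cov` and compare them. The features below are synthetic stand-ins, not real inception activations:

import numpy as np

rng = np.random.default_rng(0)
feats_a = rng.normal(size=(500, 64))            # stand-in "real" features
feats_b = rng.normal(loc=0.1, size=(500, 64))   # stand-in "generated" ones

mu1, sigma1 = feats_a.mean(axis=0), np.cov(feats_a, rowvar=False)
mu2, sigma2 = feats_b.mean(axis=0), np.cov(feats_b, rowvar=False)
print(calculate_fid(mu1, sigma1, mu2, sigma2))  # small but nonzero distance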
basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,47 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as it is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order has no
+            effect. Default: 'HWC'.
+
+    Returns:
+        ndarray: Reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            "'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
+
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without round.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
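A quick sanity check (not part of the commit) of the reordering helper on synthetic arrays:

import numpy as np

chw = np.zeros((3, 8, 8), dtype=np.float32)
assert reorder_image(chw, input_order='CHW').shape == (8, 8, 3)

gray = np.zeros((8, 8), dtype=np.float32)  # 2D inputs gain a channel axis
assert reorder_image(gray).shape == (8, 8, 1)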
basicsr/metrics/niqe.py ADDED
@@ -0,0 +1,205 @@
+import cv2
+import math
+import numpy as np
+from scipy.ndimage.filters import convolve
+from scipy.special import gamma
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+
+
+def estimate_aggd_param(block):
+    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
+
+    Args:
+        block (ndarray): 2D Image block.
+
+    Returns:
+        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
+            distribution (estimating the params in Equation 7 of the paper).
+    """
+    block = block.flatten()
+    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
+    gam_reciprocal = np.reciprocal(gam)
+    r_gam = np.square(gamma(gam_reciprocal * 2)) / (
+        gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
+
+    left_std = np.sqrt(np.mean(block[block < 0]**2))
+    right_std = np.sqrt(np.mean(block[block > 0]**2))
+    gammahat = left_std / right_std
+    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
+    rhatnorm = (rhat * (gammahat**3 + 1) *
+                (gammahat + 1)) / ((gammahat**2 + 1)**2)
+    array_position = np.argmin((r_gam - rhatnorm)**2)
+
+    alpha = gam[array_position]
+    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+    return (alpha, beta_l, beta_r)
+
+
+def compute_feature(block):
+    """Compute features.
+
+    Args:
+        block (ndarray): 2D Image block.
+
+    Returns:
+        list: Features with length of 18.
+    """
+    feat = []
+    alpha, beta_l, beta_r = estimate_aggd_param(block)
+    feat.extend([alpha, (beta_l + beta_r) / 2])
+
+    # distortions disturb the fairly regular structure of natural images.
+    # This deviation can be captured by analyzing the sample distribution of
+    # the products of pairs of adjacent coefficients computed along
+    # horizontal, vertical and diagonal orientations.
+    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
+    for i in range(len(shifts)):
+        shifted_block = np.roll(block, shifts[i], axis=(0, 1))
+        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
+        # Eq. 8
+        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
+        feat.extend([alpha, mean, beta_l, beta_r])
+    return feat
+
+
+def niqe(img,
+         mu_pris_param,
+         cov_pris_param,
+         gaussian_window,
+         block_size_h=96,
+         block_size_w=96):
+    """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+    Ref: Making a "Completely Blind" Image Quality Analyzer.
+    This implementation could produce almost the same results as the official
+    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+    Note that we do not include block overlap height and width, since they are
+    always 0 in the official implementation.
+
+    For good performance, the official implementation advises dividing the
+    distorted image into the same size patches as used for the construction
+    of the multivariate Gaussian model.
+
+    Args:
+        img (ndarray): Input image whose quality needs to be computed. The
+            image must be a gray or Y (of YCbCr) image with shape (h, w).
+            Range [0, 255] with float type.
+        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
+            model calculated on the pristine dataset.
+        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
+            Gaussian model calculated on the pristine dataset.
+        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
+            image.
+        block_size_h (int): Height of the blocks into which the image is
+            divided. Default: 96 (the official recommended value).
+        block_size_w (int): Width of the blocks into which the image is
+            divided. Default: 96 (the official recommended value).
+    """
+    assert img.ndim == 2, (
+        'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+    # crop image
+    h, w = img.shape
+    num_block_h = math.floor(h / block_size_h)
+    num_block_w = math.floor(w / block_size_w)
+    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+    distparam = []  # dist param is actually the multiscale features
+    for scale in (1, 2):  # perform on two scales (1, 2)
+        mu = convolve(img, gaussian_window, mode='nearest')
+        sigma = np.sqrt(
+            np.abs(
+                convolve(np.square(img), gaussian_window, mode='nearest') -
+                np.square(mu)))
+        # normalize, as in Eq. 1 in the paper
+        img_normalized = (img - mu) / (sigma + 1)
+
+        feat = []
+        for idx_w in range(num_block_w):
+            for idx_h in range(num_block_h):
+                # process each block
+                block = img_normalized[idx_h * block_size_h //
+                                       scale:(idx_h + 1) * block_size_h //
+                                       scale, idx_w * block_size_w //
+                                       scale:(idx_w + 1) * block_size_w //
+                                       scale]
+                feat.append(compute_feature(block))
+
+        distparam.append(np.array(feat))
+        # TODO: matlab bicubic downsample with anti-aliasing
+        # for simplicity, now we use opencv instead, which will result in
+        # a slight difference.
+        if scale == 1:
+            h, w = img.shape
+            img = cv2.resize(
+                img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
+            img = img * 255.
+
+    distparam = np.concatenate(distparam, axis=1)
+
+    # fit a MVG (multivariate Gaussian) model to distorted patch features
+    mu_distparam = np.nanmean(distparam, axis=0)
+    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+    cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+    # compute niqe quality, Eq. 10 in the paper
+    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+    quality = np.matmul(
+        np.matmul((mu_pris_param - mu_distparam), invcov_param),
+        np.transpose((mu_pris_param - mu_distparam)))
+    quality = np.sqrt(quality)
+
+    return quality
+
+
+def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
+    """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+    Ref: Making a "Completely Blind" Image Quality Analyzer.
+    This implementation could produce almost the same results as the official
+    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+    We use the official params estimated from the pristine dataset.
+    We use the recommended block size (96, 96) without overlaps.
+
+    Args:
+        img (ndarray): Input image whose quality needs to be computed.
+            The input image must be in range [0, 255] with float/int type.
+            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
+            If the input order is 'HWC' or 'CHW', it will be converted to gray
+            or Y (of YCbCr) image according to the ``convert_to`` argument.
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the metric calculation.
+        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
+            Default: 'y'.
+
+    Returns:
+        float: NIQE result.
+    """
+
+    # we use the official params estimated from the pristine dataset.
+    niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
+    mu_pris_param = niqe_pris_params['mu_pris_param']
+    cov_pris_param = niqe_pris_params['cov_pris_param']
+    gaussian_window = niqe_pris_params['gaussian_window']
+
+    img = img.astype(np.float32)
+    if input_order != 'HW':
+        img = reorder_image(img, input_order=input_order)
+        if convert_to == 'y':
+            img = to_y_channel(img)
+        elif convert_to == 'gray':
+            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+        img = np.squeeze(img)
+
+    if crop_border != 0:
+        img = img[crop_border:-crop_border, crop_border:-crop_border]
+
+    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+
+    return niqe_result
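A minimal sketch (not part of the commit) of scoring one restored image with this metric. The path is hypothetical, the image must be at least 96x96 so one full block fits, and the hard-coded relative path to niqe_pris_params.npz means the snippet has to run from the repository root:

import cv2

img = cv2.imread('results/example_dehazed.png')  # hypothetical path, BGR uint8
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(f'NIQE: {score:.4f}')  # lower means more natural-looking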
basicsr/metrics/niqe_pris_params.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
+size 11850
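Because the archive is stored as a Git LFS pointer, only this stub lives in the repository; a quick check (not part of the commit) that the pulled file carries the three arrays niqe.py loads:

import numpy as np

params = np.load('basicsr/metrics/niqe_pris_params.npz')
print(sorted(params.files))
# expected: ['cov_pris_param', 'gaussian_window', 'mu_pris_param']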
basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,303 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+import skimage.metrics
+import torch
+
+
+def calculate_psnr(img1,
+                   img2,
+                   crop_border,
+                   input_order='HWC',
+                   test_y_channel=False):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
+        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the PSNR calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: psnr result.
+    """
+
+    assert img1.shape == img2.shape, (
+        f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+    if type(img1) == torch.Tensor:
+        if len(img1.shape) == 4:
+            img1 = img1.squeeze(0)
+        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
+    if type(img2) == torch.Tensor:
+        if len(img2.shape) == 4:
+            img2 = img2.squeeze(0)
+        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
+
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    max_value = 1. if img1.max() <= 1 else 255.
+    return 20. * np.log10(max_value / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) *
+                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                       (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+def prepare_for_ssim(img, k):
+    import torch
+    with torch.no_grad():
+        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
+        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
+        conv.weight.requires_grad = False
+        conv.weight[:, :, :, :] = 1. / (k * k)
+
+        img = conv(img)
+
+        img = img.squeeze(0).squeeze(0)
+        img = img[0::k, 0::k]
+    return img.detach().cpu().numpy()
+
+
+def prepare_for_ssim_rgb(img, k):
+    import torch
+    with torch.no_grad():
+        img = torch.from_numpy(img).float()  # HxWx3
+
+        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
+        conv.weight.requires_grad = False
+        conv.weight[:, :, :, :] = 1. / (k * k)
+
+        new_img = []
+
+        for i in range(3):
+            new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
+
+    return torch.stack(new_img, dim=2).detach().cpu().numpy()
+
+
+def _3d_gaussian_calculator(img, conv3d):
+    out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+    return out
+
+
+def _generate_3d_gaussian_kernel():
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+    kernel_3 = cv2.getGaussianKernel(11, 1.5)
+    kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
+    conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
+    conv3d.weight.requires_grad = False
+    conv3d.weight[0, 0, :, :, :] = kernel
+    return conv3d
+
+
+def _ssim_3d(img1, img2, max_value):
+    """Calculate SSIM (structural similarity) with a 3D Gaussian on the GPU.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+    assert len(img1.shape) == 3 and len(img2.shape) == 3
+    C1 = (0.01 * max_value)**2
+    C2 = (0.03 * max_value)**2
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    kernel = _generate_3d_gaussian_kernel().cuda()
+
+    img1 = torch.tensor(img1).float().cuda()
+    img2 = torch.tensor(img2).float().cuda()
+
+    mu1 = _3d_gaussian_calculator(img1, kernel)
+    mu2 = _3d_gaussian_calculator(img2, kernel)
+
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
+    sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
+    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) *
+                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                       (sigma1_sq + sigma2_sq + C2))
+    return float(ssim_map.mean())
+
+
+def _ssim_cly(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with shape (h, w).
+        img2 (ndarray): Images with range [0, 255] with shape (h, w).
+
+    Returns:
+        float: ssim result.
+    """
+    assert len(img1.shape) == 2 and len(img2.shape) == 2
+
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    # print(kernel)
+    window = np.outer(kernel, kernel.transpose())
+
+    bt = cv2.BORDER_REPLICATE
+
+    mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
+    mu2 = cv2.filter2D(img2, -1, window, borderType=bt)
+
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) *
+                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                       (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+def calculate_ssim(img1,
+                   img2,
+                   crop_border,
+                   input_order='HWC',
+                   test_y_channel=False):
+    """Calculate SSIM (structural similarity).
+
+    Ref:
+    Image quality assessment: From error visibility to structural similarity
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the SSIM calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: ssim result.
+    """
+
+    assert img1.shape == img2.shape, (
+        f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+
+    if type(img1) == torch.Tensor:
+        if len(img1.shape) == 4:
+            img1 = img1.squeeze(0)
+        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
+    if type(img2) == torch.Tensor:
+        if len(img2.shape) == 4:
+            img2 = img2.squeeze(0)
+        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
+
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+        return _ssim_cly(img1[..., 0], img2[..., 0])
+
+    ssims = []
+    # ssims_before = []
+
+    # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)
+    # print('.._skimage',
+    #       skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True))
+    max_value = 1 if img1.max() <= 1 else 255
+    with torch.no_grad():
+        final_ssim = _ssim_3d(img1, img2, max_value)
+        ssims.append(final_ssim)
+
+    # for i in range(img1.shape[2]):
+    #     ssims_before.append(_ssim(img1, img2))
+
+    # print('..ssim mean , new {:.4f} and before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean()))
+    # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False))
+
+    return np.array(ssims).mean()
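A small self-contained check (not part of the commit) of the two public metrics on synthetic data. test_y_channel=True is chosen deliberately: it routes SSIM through the CPU-only _ssim_cly path, whereas test_y_channel=False calls _ssim_3d, which moves tensors to CUDA and therefore assumes a GPU:

import numpy as np

gt = np.random.randint(0, 256, (256, 256, 3)).astype(np.float64)
restored = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

psnr = calculate_psnr(restored, gt, crop_border=4, test_y_channel=True)
ssim = calculate_ssim(restored, gt, crop_border=4, test_y_channel=True)
print(f'PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}')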
basicsr/models/.DS_Store ADDED
Binary file (8.2 kB). View file
 
basicsr/models/__init__.py ADDED
@@ -0,0 +1,42 @@
+import importlib
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+
+# automatically scan and import model modules
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [
+    osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
+    if v.endswith('_model.py')
+]
+# import all the model modules
+_model_modules = [
+    importlib.import_module(f'basicsr.models.{file_name}')
+    for file_name in model_filenames
+]
+
+
+def create_model(opt):
+    """Create model.
+
+    Args:
+        opt (dict): Configuration. It contains:
+            model_type (str): Model type.
+    """
+    model_type = opt['model_type']
+
+    # dynamic instantiation
+    for module in _model_modules:
+        model_cls = getattr(module, model_type, None)
+        if model_cls is not None:
+            break
+    if model_cls is None:
+        raise ValueError(f'Model {model_type} is not found.')
+
+    model = model_cls(opt)
+
+    logger = get_root_logger()
+    logger.info(f'Model [{model.__class__.__name__}] is created.')
+    return model
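For orientation, a sketch (not part of the commit) of how the registry is used. The dict stands in for a parsed YAML config; 'ImageRestorationModel' is an assumed class name defined in image_restoration_model.py, and a real config carries many more keys (network, path and validation settings) that the concrete model consumes:

opt = {
    'model_type': 'ImageRestorationModel',  # assumed class name in *_model.py
    'is_train': False,
    'num_gpu': 1,
    'dist': False,
    # ... plus the network/path/val blocks a real YAML config provides ...
}
model = create_model(opt)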
basicsr/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.24 kB). View file
 
basicsr/models/__pycache__/base_model.cpython-37.pyc ADDED
Binary file (12.9 kB). View file
 
basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc ADDED
Binary file (9.52 kB). View file
 
basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc ADDED
Binary file (8.91 kB). View file
 
basicsr/models/archs/FPro_arch.py ADDED
@@ -0,0 +1,545 @@
+## Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pdb import set_trace as stx
+import numbers
+
+from einops import rearrange
+
+##########################################################################
+## Layer Norm
+
+def to_3d(x):
+    return rearrange(x, 'b c h w -> b (h w) c')
+
+def to_4d(x, h, w):
+    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+
+class BiasFree_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(BiasFree_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return x / torch.sqrt(sigma + 1e-5) * self.weight
+
+class WithBias_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(WithBias_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        mu = x.mean(-1, keepdim=True)
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim, LayerNorm_type):
+        super(LayerNorm, self).__init__()
+        if LayerNorm_type == 'BiasFree':
+            self.body = BiasFree_LayerNorm(dim)
+        else:
+            self.body = WithBias_LayerNorm(dim)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        return to_4d(self.body(to_3d(x)), h, w)
+
+
+
+##########################################################################
+## Gated-Dconv Feed-Forward Network (GDFN)
+class FeedForward(nn.Module):
+    def __init__(self, dim, ffn_expansion_factor, bias):
+        super(FeedForward, self).__init__()
+
+        hidden_features = int(dim * ffn_expansion_factor)
+
+        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
+
+        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
+
+        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+    def forward(self, x):
+        x = self.project_in(x)
+        x1, x2 = self.dwconv(x).chunk(2, dim=1)
+        x = F.gelu(x1) * x2
+        x = self.project_out(x)
+        return x
+
+
+
+##########################################################################
+## Multi-DConv Head Transposed Self-Attention (MDTA)
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads, bias):
+        super(Attention, self).__init__()
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
+        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
+        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+
+        qkv = self.qkv_dwconv(self.qkv(x))
+        q, k, v = qkv.chunk(3, dim=1)
+
+        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+        q = torch.nn.functional.normalize(q, dim=-1)
+        k = torch.nn.functional.normalize(k, dim=-1)
+
+        attn = (q @ k.transpose(-2, -1).contiguous()) * self.temperature
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v)
+
+        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+        out = self.project_out(out)
+        return out
+
+
+
+##########################################################################
+class TransformerBlock(nn.Module):
+    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, isAtt):
+        super(TransformerBlock, self).__init__()
+        self.isAtt = isAtt
+        if self.isAtt:
+            self.norm1 = LayerNorm(dim, LayerNorm_type)
+            self.attn = Attention(dim, num_heads, bias)
+        self.norm2 = LayerNorm(dim, LayerNorm_type)
+        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+    def forward(self, x):
+        # the FFN runs in every block; attention only when isAtt is set
+        if self.isAtt:
+            x = x + self.attn(self.norm1(x))
+        x = x + self.ffn(self.norm2(x))
+
+        return x
+
+
+
+##########################################################################
+## Overlapped image patch embedding with 3x3 Conv
+class OverlapPatchEmbed(nn.Module):
+    def __init__(self, in_c=3, embed_dim=48, bias=False):
+        super(OverlapPatchEmbed, self).__init__()
+
+        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
+
+    def forward(self, x):
+        x = self.proj(x)
+
+        return x
+
+########### window operation #############
+def window_partition(x, win_size, dilation_rate=1):
+    B, H, W, C = x.shape
+    if dilation_rate != 1:
+        x = x.permute(0, 3, 1, 2)  # B, C, H, W
+        assert type(dilation_rate) is int, 'dilation_rate should be an int'
+        x = F.unfold(x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size)  # B, C*Wh*Ww, H/Wh*W/Ww
+        windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size)  # B', C, Wh, Ww
+        windows = windows.permute(0, 2, 3, 1).contiguous()  # B', Wh, Ww, C
+    else:
+        x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
+        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)  # B', Wh, Ww, C
+    return windows
+
+def window_reverse(windows, win_size, H, W, dilation_rate=1):
+    # B', Wh, Ww, C
+    B = int(windows.shape[0] / (H * W / win_size / win_size))
+    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
+    if dilation_rate != 1:
+        x = windows.permute(0, 5, 3, 4, 1, 2).contiguous()  # B, C*Wh*Ww, H/Wh*W/Ww
+        x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size)
+    else:
+        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+class lowFrequencyPromptFusion(nn.Module):
+    def __init__(self, dim, dim_bak, num_heads, win_size=8, bias=False):
+        super(lowFrequencyPromptFusion, self).__init__()
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+        self.ap_kv = nn.AdaptiveAvgPool2d(1)
+        self.kv = nn.Conv2d(dim_bak, dim * 2, kernel_size=1, bias=bias)
+
+        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+    def forward(self, feature, prompt_feature):
+        b, c1, h, w = feature.shape
+        _, c2, _, _ = prompt_feature.shape
+
+        query = self.q(feature).reshape(b, h * w, self.num_heads, c1 // self.num_heads).permute(0, 2, 1, 3).contiguous()
+
+        prompt_feature = self.ap_kv(prompt_feature)  # .reshape(b, c2, -1).permute(0, 2, 1)
+        key_value = self.kv(prompt_feature).reshape(b, 2 * c1, -1).permute(0, 2, 1).contiguous().reshape(b, -1, 2, self.num_heads, c1 // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
+        key, value = key_value[0], key_value[1]
+
+        attn = (query @ key.transpose(-2, -1).contiguous()) * self.temperature
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ value)
+
+        out = rearrange(out, 'b head (h w) c -> b (head c) h w', head=self.num_heads, h=h, w=w)
+        out = self.project_out(out)
+
+        return out
+
+class LinearProjection(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True, isQuery=True):
+        super().__init__()
+        self.isQuery = isQuery
+        inner_dim = dim_head * heads
+        self.heads = heads
+        if self.isQuery:
+            self.to_q = nn.Linear(dim, inner_dim, bias=bias)
+        else:
+            self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=bias)
+        self.dim = dim
+        self.inner_dim = inner_dim
+
+    def forward(self, x, attn_kv=None):
+        B_, N, C = x.shape
+        if attn_kv is not None:
+            attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
+        else:
+            attn_kv = x
+        N_kv = attn_kv.size(1)
+        if self.isQuery:
+            q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
+            q = q[0]
+            return q
+        else:
+            C = self.inner_dim
+            kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
+            k, v = kv[0], kv[1]
+            return k, v
+
+class highFrequencyPromptFusion(nn.Module):
+    def __init__(self, dim, dim_bak, win_size, num_heads, qkv_bias=True, qk_scale=None, bias=False):
+        super(highFrequencyPromptFusion, self).__init__()
+        self.num_heads = num_heads
+        self.win_size = win_size  # Wh, Ww
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.to_q = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias, isQuery=True)
+        self.to_kv = LinearProjection(dim_bak, num_heads, dim // num_heads, bias=qkv_bias, isQuery=False)
+
+        self.kv_dwconv = nn.Conv2d(dim_bak, dim_bak, kernel_size=3, stride=1, padding=1, groups=dim_bak, bias=bias)
+
+        self.softmax = nn.Softmax(dim=-1)
+
+        self.project_out = nn.Linear(dim, dim)
+
+    def forward(self, query_feature, key_value_feature):
+
+        b, c, h, w = query_feature.shape
+        _, c_2, _, _ = key_value_feature.shape
+
+        key_value_feature = self.kv_dwconv(key_value_feature)
+
+        # partition windows
+        query_feature = rearrange(query_feature, 'b c1 h w -> b h w c1', h=h, w=w)
+        query_feature_windows = window_partition(query_feature, self.win_size)  # nW*B, win_size, win_size, C
+        query_feature_windows = query_feature_windows.view(-1, self.win_size * self.win_size, c)  # nW*B, win_size*win_size, C
+
+        key_value_feature = rearrange(key_value_feature, 'b c2 h w -> b h w c2', h=h, w=w)
+        key_value_feature_windows = window_partition(key_value_feature, self.win_size)  # nW*B, win_size, win_size, C
+        key_value_feature_windows = key_value_feature_windows.view(-1, self.win_size * self.win_size, c_2)  # nW*B, win_size*win_size, C
+
+        B_, N, C = query_feature_windows.shape
+
+        query = self.to_q(query_feature_windows)
+        query = query * self.scale
+
+        key, value = self.to_kv(key_value_feature_windows)
+        attn = (query @ key.transpose(-2, -1).contiguous())
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ value).transpose(1, 2).contiguous().reshape(B_, N, C)
+
+        out = self.project_out(out)
+
+        # merge windows
+        attn_windows = out.view(-1, self.win_size, self.win_size, C)
+        attn_windows = window_reverse(attn_windows, self.win_size, h, w)  # B H' W' C
+        return rearrange(attn_windows, 'b h w c -> b c h w', h=h, w=w)
+
+##########################################################################
+## channel dynamic filters
+class dynamic_filter_channel(nn.Module):
+    def __init__(self, inchannels, kernel_size=3, stride=1, group=8):
+        super(dynamic_filter_channel, self).__init__()
+        self.stride = stride
+        self.kernel_size = kernel_size
+        self.group = group
+
+        self.conv = nn.Conv2d(inchannels, group * kernel_size**2, kernel_size=1, stride=1, bias=False)
+        self.conv_gate = nn.Conv2d(group * kernel_size**2, group * kernel_size**2, kernel_size=1, stride=1, bias=False)
+        self.act_gate = nn.Sigmoid()
+        self.bn = nn.BatchNorm2d(group * kernel_size**2)
+        self.act = nn.Softmax(dim=-2)
+        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
+
+        self.pad = nn.ReflectionPad2d(kernel_size // 2)
+
+        self.ap_1 = nn.AdaptiveAvgPool2d((1, 1))
+        # self.ap_2 = nn.AdaptiveMaxPool2d((1, 1))
+
+    def forward(self, x):
+        identity_input = x
+        low_filter1 = self.ap_1(x)
+        # low_filter2 = self.ap_2(x)
+        low_filter = self.conv(low_filter1)
+        low_filter = low_filter * self.act_gate(self.conv_gate(low_filter))
+        low_filter = self.bn(low_filter)
+
+        n, c, h, w = x.shape
+        x = F.unfold(self.pad(x), kernel_size=self.kernel_size).reshape(n, self.group, c // self.group, self.kernel_size**2, h * w)
+
+        n, c1, p, q = low_filter.shape
+        low_filter = low_filter.reshape(n, c1 // self.kernel_size**2, self.kernel_size**2, p * q).unsqueeze(2)
+
+        low_filter = self.act(low_filter)
+        # print('low_filter size', low_filter.shape)
+        # print('low_filter n,c1,p,q', n, c1, p, q)
+
+        low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)
+
+        out_high = identity_input - low_part
+        return low_part, out_high
+
+
+class frequenctSpecificPromptGenetator(nn.Module):
+    def __init__(self, dim=3, h=128, w=65, flag_highF=True):
+        super().__init__()
+        self.flag_highF = flag_highF
+        k_size = 3
+        if flag_highF:
+            w = (w - 1) * 2
+            self.w = w
+            self.h = h
+            self.weight = nn.Parameter(torch.randn(1, dim, h, w, dtype=torch.float32) * 0.02)
+            self.body = nn.Sequential(nn.Conv2d(dim, dim, (1, k_size), padding=(0, k_size // 2), groups=dim),
+                                      nn.Conv2d(dim, dim, (k_size, 1), padding=(k_size // 2, 0), groups=dim),
+                                      nn.GELU())
+        else:
+            self.complex_weight = nn.Parameter(torch.randn(1, dim, h, w, 2, dtype=torch.float32) * 0.02)
+            self.body = nn.Sequential(nn.Conv2d(2 * dim, 2 * dim, kernel_size=1, stride=1),
+                                      nn.GELU(),
+                                      )
+
+    def forward(self, ffm, H, W):
+        if self.flag_highF:
+            ffm = F.interpolate(ffm, size=(H, W), mode='bilinear')
+            y_att = self.body(ffm)
+
+            y_f = y_att * ffm
+            y = y_f * self.weight
+
+        else:
+            ffm = F.interpolate(ffm, size=(H, W), mode='bicubic')
+            y = torch.fft.rfft2(ffm.to(torch.float32).cuda())
+            y_imag = y.imag
+            y_real = y.real
+            y_f = torch.cat([y_real, y_imag], dim=1)
+            weight = torch.complex(self.complex_weight[..., 0], self.complex_weight[..., 1])
+            y_att = self.body(y_f)
+            y_f = y_f * y_att
+            y_real, y_imag = torch.chunk(y_f, 2, dim=1)
+            y = torch.complex(y_real, y_imag)
+            y = y * weight
+            y = torch.fft.irfft2(y, s=(H, W))
+
+        return y
+
+##########################################################################
+## PromptModule
+class PromptModule(nn.Module):
+    def __init__(self, basic_dim=32, dim=32, input_resolution=128):
+        super().__init__()
+        h = input_resolution
+        w = input_resolution // 2 + 1
+        self.simple_Fusion = nn.Conv2d(2 * dim, dim, kernel_size=1, stride=1)
+
+        self.FSPG_high = frequenctSpecificPromptGenetator(basic_dim, h, w, flag_highF=True)
+        self.FSPG_low = frequenctSpecificPromptGenetator(basic_dim, h, w, flag_highF=False)
+
+
+        self.modulator_hi = highFrequencyPromptFusion(dim, basic_dim, win_size=8, num_heads=2, bias=False)
+        self.modulator_lo = lowFrequencyPromptFusion(dim, basic_dim, win_size=8, num_heads=2, bias=False)
+
+    def forward(self, low_part, out_high, x):
+        b, c, h, w = x.shape
+
+        y_h = self.FSPG_high(out_high, h, w)
+        y_l = self.FSPG_low(low_part, h, w)
+
+        y_h = self.modulator_hi(x, y_h)
+        y_l = self.modulator_lo(x, y_l)
+
+        x = self.simple_Fusion(torch.cat([y_h, y_l], dim=1))
+
+        return x
+
+## splitFrequencyModule
+class splitFrequencyModule(nn.Module):
+    def __init__(self, basic_dim=32, dim=32, input_resolution=128):
+        super().__init__()
+
+        self.dyna_channel = dynamic_filter_channel(inchannels=basic_dim)
+
+    def forward(self, F_low):
+        _, c_basic, h_ori, w_ori = F_low.shape
+
+        low_part, out_high = self.dyna_channel(F_low)
+
+        return low_part, out_high
+
+
+##########################################################################
+## Resizing modules
+class Downsample(nn.Module):
+    def __init__(self, n_feat):
+        super(Downsample, self).__init__()
+
+        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
+                                  nn.PixelUnshuffle(2))
+
+    def forward(self, x):
+        return self.body(x)
+
+class Upsample(nn.Module):
+    def __init__(self, n_feat):
+        super(Upsample, self).__init__()
+
+        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
+                                  nn.PixelShuffle(2))
+
+    def forward(self, x):
+        return self.body(x)
+
+##########################################################################
+##---------- FPro -----------------------
+class FPro(nn.Module):
+    def __init__(self,
+                 inp_channels=3,
+                 out_channels=3,
+                 dim=48,
+                 num_blocks=[4, 6, 6, 8],
+                 num_refinement_blocks=4,
+                 heads=[1, 2, 4, 8],
+                 ffn_expansion_factor=2.66,
+                 bias=False,
+                 LayerNorm_type='WithBias',  ## Other option 'BiasFree'
+                 dual_pixel_task=False  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+                 ):
+
+        super(FPro, self).__init__()
+
+        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
+
+        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[0])])
+
+        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
+        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[1])])
+
+        self.down2_3 = Downsample(int(dim*2**1))  ## From Level 2 to Level 3
+        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=False) for i in range(num_blocks[2])])
+
+        self.splitFre = splitFrequencyModule(basic_dim=dim, dim=int(dim*2**2), input_resolution=32)
+        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[2])])
+        self.prompt_d3 = PromptModule(basic_dim=dim, dim=int(dim*2**2), input_resolution=64)
+
+        self.up3_2 = Upsample(int(dim*2**2))  ## From Level 3 to Level 2
+        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
+        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[1])])
+        self.prompt_d2 = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=128)
+
+        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)
+
+        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_blocks[0])])
+        self.prompt_d1 = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=256)
+
+        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, isAtt=True) for i in range(num_refinement_blocks)])
+        self.prompt_r = PromptModule(basic_dim=dim, dim=int(dim*2**1), input_resolution=256)
+
+        #### For Dual-Pixel Defocus Deblurring Task ####
+        self.dual_pixel_task = dual_pixel_task
+        if self.dual_pixel_task:
+            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
+        ###########################
+
+        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
+
+    def forward(self, inp_img):
+
+        inp_enc_level1 = self.patch_embed(inp_img)
+        out_enc_level1 = self.encoder_level1(inp_enc_level1)
+
+        inp_enc_level2 = self.down1_2(out_enc_level1)
+        out_enc_level2 = self.encoder_level2(inp_enc_level2)
+
+        inp_enc_level3 = self.down2_3(out_enc_level2)
+        out_enc_level3 = self.encoder_level3(inp_enc_level3)
+
+
+        out_dec_level3 = self.decoder_level3(out_enc_level3)
+        low_part, out_high = self.splitFre(inp_enc_level1)
+        out_dec_level3 = self.prompt_d3(low_part, out_high, out_dec_level3) + out_dec_level3
+
+        inp_dec_level2 = self.up3_2(out_dec_level3)
+        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
+        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
+        out_dec_level2 = self.decoder_level2(inp_dec_level2)
+        out_dec_level2 = self.prompt_d2(low_part, out_high, out_dec_level2) + out_dec_level2
+
+        inp_dec_level1 = self.up2_1(out_dec_level2)
+        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
+        out_dec_level1 = self.decoder_level1(inp_dec_level1)
+        out_dec_level1 = self.prompt_d1(low_part, out_high, out_dec_level1) + out_dec_level1
+
+        out_dec_level1 = self.refinement(out_dec_level1)
+        out_dec_level1 = self.prompt_r(low_part, out_high, out_dec_level1) + out_dec_level1
+
+        #### For Dual-Pixel Defocus Deblurring Task ####
+        if self.dual_pixel_task:
+            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
+            out_dec_level1 = self.output(out_dec_level1)
+        ###########################
+        else:
+            out_dec_level1 = self.output(out_dec_level1) + inp_img
+
+
+        return out_dec_level1
+
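A shape-level smoke test (not part of the commit). The spatial size must be divisible by 8 for the two PixelUnshuffle stages and the 8x8 attention windows, and the prompt weights are built for fixed base resolutions (256/128/64 across the decoder levels with the defaults above), so 256x256 crops are the natural test size; CUDA is assumed because the low-frequency prompt generator moves its FFT input to the GPU:

import torch

model = FPro(inp_channels=3, out_channels=3, dim=48).cuda().eval()
x = torch.randn(1, 3, 256, 256).cuda()  # e.g. a hazy 256x256 crop
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]); the output is residual + input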
basicsr/models/archs/__init__.py ADDED
@@ -0,0 +1,46 @@
+import importlib
+from os import path as osp
+
+from basicsr.utils import scandir
+
+# automatically scan and import arch modules
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [
+    osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
+    if v.endswith('_arch.py')
+]
+# import all the arch modules
+_arch_modules = [
+    importlib.import_module(f'basicsr.models.archs.{file_name}')
+    for file_name in arch_filenames
+]
+
+
+def dynamic_instantiation(modules, cls_type, opt):
+    """Dynamically instantiate class.
+
+    Args:
+        modules (list[importlib modules]): List of modules from importlib
+            files.
+        cls_type (str): Class type.
+        opt (dict): Class initialization kwargs.
+
+    Returns:
+        class: Instantiated class.
+    """
+
+    for module in modules:
+        cls_ = getattr(module, cls_type, None)
+        if cls_ is not None:
+            break
+    if cls_ is None:
+        raise ValueError(f'{cls_type} is not found.')
+    return cls_(**opt)
+
+
+def define_network(opt):
+    network_type = opt.pop('type')
+    net = dynamic_instantiation(_arch_modules, network_type, opt)
+    return net
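The lookup above is exercised like this sketch (not part of the commit); 'type' names any class found in a *_arch.py file, here the FPro architecture from this commit, and note that define_network pops 'type' out of the dict in place:

net_opt = {
    'type': 'FPro',
    'inp_channels': 3,
    'out_channels': 3,
    'dim': 48,
}
net_g = define_network(net_opt)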
basicsr/models/archs/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.43 kB). View file
 
basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc ADDED
Binary file (7.17 kB). View file
 
basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc ADDED
Binary file (6.01 kB). View file
 
basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc ADDED
Binary file (6.42 kB). View file
 
basicsr/models/archs/arch_util.py ADDED
@@ -0,0 +1,255 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.utils import get_root_logger
+
+# try:
+#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
+#                                         modulated_deform_conv)
+# except ImportError:
+#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
+#     #       'Otherwise install BasicSR with compiling dcn.')
+#
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. '
+                             'Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x,
+              flow,
+              interp_mode='bilinear',
+              padding_mode='zeros',
+              align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(0, h).type_as(x),
+        torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(
+        x,
+        vgrid_scaled,
+        mode=interp_mode,
+        padding_mode=padding_mode,
+        align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow,
+                size_type,
+                sizes,
+                interp_mode='bilinear',
+                align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+            ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(
+            f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow,
+        size=(output_h, output_w),
+        mode=interp_mode,
+        align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+# class DCNv2Pack(ModulatedDeformConvPack):
+#     """Modulated deformable conv for deformable alignment.
+#
+#     Different from the official DCNv2Pack, which generates offsets and masks
+#     from the preceding features, this DCNv2Pack takes another different
+#     features to generate offsets and masks.
+#
+#     Ref:
+#         Delving Deep into Deformable Alignment in Video Super-Resolution.
+#     """
+#
+#     def forward(self, x, feat):
+#         out = self.conv_offset(feat)
+#         o1, o2, mask = torch.chunk(out, 3, dim=1)
+#         offset = torch.cat((o1, o2), dim=1)
+#         mask = torch.sigmoid(mask)
+#
+#         offset_absmean = torch.mean(torch.abs(offset))
+#         if offset_absmean > 50:
+#             logger = get_root_logger()
+#             logger.warning(
+#                 f'Offset abs mean is {offset_absmean}, larger than 50.')
+#
+#         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+#                                      self.stride, self.padding, self.dilation,
+#                                      self.groups, self.deformable_groups)
basicsr/models/base_model.py ADDED
@@ -0,0 +1,378 @@
+ import logging
+ import os
+ import torch
+ from collections import OrderedDict
+ from copy import deepcopy
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+ from basicsr.models import lr_scheduler as lr_scheduler
+ from basicsr.utils.dist_util import master_only
+
+ logger = logging.getLogger('basicsr')
+
+
+ class BaseModel():
+     """Base model."""
+
+     def __init__(self, opt):
+         self.opt = opt
+         self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+         self.is_train = opt['is_train']
+         self.schedulers = []
+         self.optimizers = []
+
+     def feed_data(self, data):
+         pass
+
+     def optimize_parameters(self):
+         pass
+
+     def get_current_visuals(self):
+         pass
+
+     def save(self, epoch, current_iter):
+         """Save networks and training state."""
+         pass
+
+     def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True):
+         """Validation function.
+
+         Args:
+             dataloader (torch.utils.data.DataLoader): Validation dataloader.
+             current_iter (int): Current iteration.
+             tb_logger (tensorboard logger): Tensorboard logger.
+             save_img (bool): Whether to save images. Default: False.
+             rgb2bgr (bool): Whether to convert images from RGB to BGR before
+                 saving. Default: True.
+             use_image (bool): Whether to compute metrics (PSNR, SSIM) on the
+                 saved images; if False, the network outputs are used
+                 directly. Default: True.
+         """
+         if self.opt['dist']:
+             return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
+         else:
+             return self.nondist_validation(dataloader, current_iter, tb_logger,
+                                            save_img, rgb2bgr, use_image)
+
+     def model_ema(self, decay=0.999):
+         net_g = self.get_bare_model(self.net_g)
+
+         net_g_params = dict(net_g.named_parameters())
+         net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+         for k in net_g_ema_params.keys():
+             net_g_ema_params[k].data.mul_(decay).add_(
+                 net_g_params[k].data, alpha=1 - decay)
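model_ema above is the standard exponential moving average update, ema_k <- decay * ema_k + (1 - decay) * param_k, applied to every parameter k. A one-tensor illustration (not part of the committed file):

    import torch
    ema, param, decay = torch.tensor(1.0), torch.tensor(0.0), 0.999
    ema.mul_(decay).add_(param, alpha=1 - decay)
    print(ema)   # tensor(0.9990): the EMA copy drifts slowly toward the live weights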
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def model_to_device(self, net):
+         """Model to device. It also wraps models with DistributedDataParallel
+         or DataParallel.
+
+         Args:
+             net (nn.Module)
+         """
+
+         net = net.to(self.device)
+         if self.opt['dist']:
+             find_unused_parameters = self.opt.get('find_unused_parameters',
+                                                   False)
+             net = DistributedDataParallel(
+                 net,
+                 device_ids=[torch.cuda.current_device()],
+                 find_unused_parameters=find_unused_parameters)
+         elif self.opt['num_gpu'] > 1:
+             net = DataParallel(net)
+         return net
+
+     def setup_schedulers(self):
+         """Set up schedulers."""
+         train_opt = self.opt['train']
+         scheduler_type = train_opt['scheduler'].pop('type')
+         if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.MultiStepRestartLR(optimizer,
+                                                     **train_opt['scheduler']))
+         elif scheduler_type == 'CosineAnnealingRestartLR':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.CosineAnnealingRestartLR(
+                         optimizer, **train_opt['scheduler']))
+         elif scheduler_type == 'CosineAnnealingWarmupRestarts':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.CosineAnnealingWarmupRestarts(
+                         optimizer, **train_opt['scheduler']))
+         elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.CosineAnnealingRestartCyclicLR(
+                         optimizer, **train_opt['scheduler']))
+         elif scheduler_type == 'TrueCosineAnnealingLR':
+             print('..', 'cosineannealingLR')
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler']))
+         elif scheduler_type == 'CosineAnnealingLRWithRestart':
+             print('..', 'CosineAnnealingLR_With_Restart')
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.CosineAnnealingLRWithRestart(optimizer, **train_opt['scheduler']))
+         elif scheduler_type == 'LinearLR':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.LinearLR(
+                         optimizer, train_opt['total_iter']))
+         elif scheduler_type == 'VibrateLR':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.VibrateLR(
+                         optimizer, train_opt['total_iter']))
+         else:
+             raise NotImplementedError(
+                 f'Scheduler {scheduler_type} is not implemented yet.')
+
+     def get_bare_model(self, net):
+         """Get the bare model, especially when the model is wrapped with
+         DistributedDataParallel or DataParallel.
+         """
+         if isinstance(net, (DataParallel, DistributedDataParallel)):
+             net = net.module
+         return net
+
+     @master_only
+     def print_network(self, net):
+         """Print the network string and the number of parameters.
+
+         Args:
+             net (nn.Module)
+         """
+         if isinstance(net, (DataParallel, DistributedDataParallel)):
+             net_cls_str = (f'{net.__class__.__name__} - '
+                            f'{net.module.__class__.__name__}')
+         else:
+             net_cls_str = f'{net.__class__.__name__}'
+
+         net = self.get_bare_model(net)
+         net_str = str(net)
+         net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+         logger.info(
+             f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+         logger.info(net_str)
+
+     def _set_lr(self, lr_groups_l):
+         """Set learning rate for warmup.
+
+         Args:
+             lr_groups_l (list): List of lr groups, one entry per optimizer.
+         """
+         for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+             for param_group, lr in zip(optimizer.param_groups, lr_groups):
+                 param_group['lr'] = lr
+
+     def _get_init_lr(self):
+         """Get the initial lr, which is set by the scheduler."""
+         init_lr_groups_l = []
+         for optimizer in self.optimizers:
+             init_lr_groups_l.append(
+                 [v['initial_lr'] for v in optimizer.param_groups])
+         return init_lr_groups_l
+
+     def update_learning_rate(self, current_iter, warmup_iter=-1):
+         """Update learning rate.
+
+         Args:
+             current_iter (int): Current iteration.
+             warmup_iter (int): Number of warmup iterations. -1 means no
+                 warmup. Default: -1.
+         """
+         if current_iter > 1:
+             for scheduler in self.schedulers:
+                 scheduler.step()
+         # set up warm-up learning rate
+         if current_iter < warmup_iter:
+             # get the initial lr for each group
+             init_lr_g_l = self._get_init_lr()
+             # modify the warm-up learning rates;
+             # currently only linear warm-up is supported
+             warm_up_lr_l = []
+             for init_lr_g in init_lr_g_l:
+                 warm_up_lr_l.append(
+                     [v / warmup_iter * current_iter for v in init_lr_g])
+             # set learning rate
+             self._set_lr(warm_up_lr_l)
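The warm-up branch above is a plain linear ramp of each group's initial lr; with illustrative numbers:

    # initial_lr = 2e-4, warmup_iter = 1000
    #   iter  250 -> 2e-4 * 250 / 1000 = 5e-5
    #   iter  999 -> 2e-4 * 999 / 1000 ~= 2e-4
    # from current_iter >= warmup_iter on, the scheduler value is used unchanged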
+
+     def get_current_learning_rate(self):
+         return [
+             param_group['lr']
+             for param_group in self.optimizers[0].param_groups
+         ]
+
+     @master_only
+     def save_network(self, net, net_label, current_iter, param_key='params'):
+         """Save networks.
+
+         Args:
+             net (nn.Module | list[nn.Module]): Network(s) to be saved.
+             net_label (str): Network label.
+             current_iter (int): Current iter number.
+             param_key (str | list[str]): The parameter key(s) to save network.
+                 Default: 'params'.
+         """
+         if current_iter == -1:
+             current_iter = 'latest'
+         save_filename = f'{net_label}_{current_iter}.pth'
+         save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+         net = net if isinstance(net, list) else [net]
+         param_key = param_key if isinstance(param_key, list) else [param_key]
+         assert len(net) == len(
+             param_key), 'The lengths of net and param_key should be the same.'
+
+         save_dict = {}
+         for net_, param_key_ in zip(net, param_key):
+             net_ = self.get_bare_model(net_)
+             state_dict = net_.state_dict()
+             for key, param in state_dict.items():
+                 if key.startswith('module.'):  # remove unnecessary 'module.'
+                     key = key[7:]
+                 state_dict[key] = param.cpu()
+             save_dict[param_key_] = state_dict
+
+         torch.save(save_dict, save_path)
+
+     def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+         """Print keys with different names or different sizes when loading
+         models.
+
+         1. Print keys with different names.
+         2. If strict=False, print the same keys that have different tensor
+            sizes. Keys with different sizes are also ignored (not loaded).
+
+         Args:
+             crt_net (torch model): Current network.
+             load_net (dict): Loaded network.
+             strict (bool): Whether the loading is strict. Default: True.
+         """
+         crt_net = self.get_bare_model(crt_net)
+         crt_net = crt_net.state_dict()
+         crt_net_keys = set(crt_net.keys())
+         load_net_keys = set(load_net.keys())
+
+         if crt_net_keys != load_net_keys:
+             logger.warning('Current net - loaded net:')
+             for v in sorted(list(crt_net_keys - load_net_keys)):
+                 logger.warning(f'  {v}')
+             logger.warning('Loaded net - current net:')
+             for v in sorted(list(load_net_keys - crt_net_keys)):
+                 logger.warning(f'  {v}')
+
+         # check the size for the same keys
+         if not strict:
+             common_keys = crt_net_keys & load_net_keys
+             for k in common_keys:
+                 if crt_net[k].size() != load_net[k].size():
+                     logger.warning(
+                         f'Size different, ignore [{k}]: crt_net: '
+                         f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+                     load_net[k + '.ignore'] = load_net.pop(k)
+
+     def load_network(self, net, load_path, strict=True, param_key='params'):
+         """Load network.
+
+         Args:
+             load_path (str): The path of the network to be loaded.
+             net (nn.Module): Network.
+             strict (bool): Whether the loading is strict.
+             param_key (str): The parameter key of the loaded network. If set
+                 to None, the root of the loaded dict is used. Default:
+                 'params'.
+         """
+         net = self.get_bare_model(net)
+         logger.info(
+             f'Loading {net.__class__.__name__} model from {load_path}.')
+         load_net = torch.load(
+             load_path, map_location=lambda storage, loc: storage)
+         if param_key is not None:
+             if param_key not in load_net and 'params' in load_net:
+                 param_key = 'params'
+                 logger.info('Loading: params_ema does not exist, use params.')
+             load_net = load_net[param_key]
+         print('load net keys:', list(load_net.keys()))
+         # remove unnecessary 'module.'
+         for k, v in deepcopy(load_net).items():
+             if k.startswith('module.'):
+                 load_net[k[7:]] = v
+                 load_net.pop(k)
+         self._print_different_keys_loading(net, load_net, strict)
+         net.load_state_dict(load_net, strict=strict)
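Checkpoints written by save_network are plain dicts keyed by param_key, so they can also be inspected outside the framework; a sketch (the path below is hypothetical):

    import torch
    ckpt = torch.load('experiments/FPro/models/net_g_latest.pth', map_location='cpu')
    print(ckpt.keys())                    # dict_keys(['params']) or ['params', 'params_ema'] when EMA is enabled
    # net.load_state_dict(ckpt['params'])  # net: an instantiated FPro network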
+
+     @master_only
+     def save_training_state(self, epoch, current_iter):
+         """Save training states during training, which will be used for
+         resuming.
+
+         Args:
+             epoch (int): Current epoch.
+             current_iter (int): Current iteration.
+         """
+         if current_iter != -1:
+             state = {
+                 'epoch': epoch,
+                 'iter': current_iter,
+                 'optimizers': [],
+                 'schedulers': []
+             }
+             for o in self.optimizers:
+                 state['optimizers'].append(o.state_dict())
+             for s in self.schedulers:
+                 state['schedulers'].append(s.state_dict())
+             save_filename = f'{current_iter}.state'
+             save_path = os.path.join(self.opt['path']['training_states'],
+                                      save_filename)
+             torch.save(state, save_path)
+
+     def resume_training(self, resume_state):
+         """Reload the optimizers and schedulers for resumed training.
+
+         Args:
+             resume_state (dict): Resume state.
+         """
+         resume_optimizers = resume_state['optimizers']
+         resume_schedulers = resume_state['schedulers']
+         assert len(resume_optimizers) == len(
+             self.optimizers), 'Wrong lengths of optimizers'
+         assert len(resume_schedulers) == len(
+             self.schedulers), 'Wrong lengths of schedulers'
+         for i, o in enumerate(resume_optimizers):
+             self.optimizers[i].load_state_dict(o)
+         for i, s in enumerate(resume_schedulers):
+             self.schedulers[i].load_state_dict(s)
+
+     def reduce_loss_dict(self, loss_dict):
+         """Reduce loss dict.
+
+         In distributed training, it averages the losses among different GPUs.
+
+         Args:
+             loss_dict (OrderedDict): Loss dict.
+         """
+         with torch.no_grad():
+             if self.opt['dist']:
+                 keys = []
+                 losses = []
+                 for name, value in loss_dict.items():
+                     keys.append(name)
+                     losses.append(value)
+                 losses = torch.stack(losses, 0)
+                 torch.distributed.reduce(losses, dst=0)
+                 if self.opt['rank'] == 0:
+                     losses /= self.opt['world_size']
+                 loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+             log_dict = OrderedDict()
+             for name, value in loss_dict.items():
+                 log_dict[name] = value.mean().item()
+
+             return log_dict
basicsr/models/image_restoration_model.py ADDED
@@ -0,0 +1,361 @@
+ import importlib
+ import torch
+ from collections import OrderedDict
+ from copy import deepcopy
+ from os import path as osp
+ from tqdm import tqdm
+
+ from basicsr.models.archs import define_network
+ from basicsr.models.base_model import BaseModel
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
+
+ loss_module = importlib.import_module('basicsr.models.losses')
+ metric_module = importlib.import_module('basicsr.metrics')
+
+ import os
+ import random
+ import numpy as np
+ import cv2
+ import torch.nn.functional as F
+ from functools import partial
+
+ class Mixing_Augment:
+     def __init__(self, mixup_beta, use_identity, device):
+         self.dist = torch.distributions.beta.Beta(torch.tensor([mixup_beta]), torch.tensor([mixup_beta]))
+         self.device = device
+
+         self.use_identity = use_identity
+
+         self.augments = [self.mixup]
+
+     def mixup(self, target, input_):
+         lam = self.dist.rsample((1, 1)).item()
+
+         r_index = torch.randperm(target.size(0)).to(self.device)
+
+         target = lam * target + (1 - lam) * target[r_index, :]
+         input_ = lam * input_ + (1 - lam) * input_[r_index, :]
+
+         return target, input_
+
+     def __call__(self, target, input_):
+         if self.use_identity:
+             augment = random.randint(0, len(self.augments))
+             if augment < len(self.augments):
+                 target, input_ = self.augments[augment](target, input_)
+         else:
+             augment = random.randint(0, len(self.augments) - 1)
+             target, input_ = self.augments[augment](target, input_)
+         return target, input_
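A minimal run of the mixup augmentation above (an illustrative sketch, not part of the committed file): one lam ~ Beta(1.2, 1.2) and one batch permutation are shared between target and input, so degraded/clean pairs stay aligned. With use_identity, randint(0, len(self.augments)) can land on the out-of-range index, which acts as the no-op branch.

    import torch
    aug = Mixing_Augment(mixup_beta=1.2, use_identity=False, device='cpu')
    gt = torch.rand(4, 3, 16, 16)    # clean batch
    lq = torch.rand(4, 3, 16, 16)    # hazy batch
    gt_mix, lq_mix = aug(gt, lq)     # same convex combination applied to both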
+
+ class ImageCleanModel(BaseModel):
+     """Base restoration model for single-image dehazing."""
+
+     def __init__(self, opt):
+         super(ImageCleanModel, self).__init__(opt)
+
+         # define network
+
+         self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False)
+         if self.mixing_flag:
+             mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2)
+             use_identity = self.opt['train']['mixing_augs'].get('use_identity', False)
+             self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device)
+
+         self.net_g = define_network(deepcopy(opt['network_g']))
+         self.net_g = self.model_to_device(self.net_g)
+         self.print_network(self.net_g)
+
+         # load pretrained models
+         load_path = self.opt['path'].get('pretrain_network_g', None)
+         if load_path is not None:
+             self.load_network(self.net_g, load_path,
+                               self.opt['path'].get('strict_load_g', True),
+                               param_key=self.opt['path'].get('param_key', 'params'))
+
+         if self.is_train:
+             self.init_training_settings()
+
+     def init_training_settings(self):
+         self.net_g.train()
+         train_opt = self.opt['train']
+
+         self.ema_decay = train_opt.get('ema_decay', 0)
+         if self.ema_decay > 0:
+             logger = get_root_logger()
+             logger.info(
+                 f'Use Exponential Moving Average with decay: {self.ema_decay}')
+             # define network net_g with Exponential Moving Average (EMA).
+             # net_g_ema is used only for testing on one GPU and for saving;
+             # there is no need to wrap it with DistributedDataParallel.
+             self.net_g_ema = define_network(self.opt['network_g']).to(
+                 self.device)
+             # load pretrained model
+             load_path = self.opt['path'].get('pretrain_network_g', None)
+             if load_path is not None:
+                 self.load_network(self.net_g_ema, load_path,
+                                   self.opt['path'].get('strict_load_g',
+                                                        True), 'params_ema')
+             else:
+                 self.model_ema(0)  # copy net_g weights
+             self.net_g_ema.eval()
+
+         # define losses
+         if train_opt.get('pixel_opt'):
+             pixel_type = train_opt['pixel_opt'].pop('type')
+             cri_pix_cls = getattr(loss_module, pixel_type)
+             self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
+                 self.device)
+         else:
+             raise ValueError('pixel loss is None.')
+
+         if train_opt.get('fft_loss_opt'):
+             fft_type = train_opt['fft_loss_opt'].pop('type')
+             cri_fft_cls = getattr(loss_module, fft_type)
+             self.cri_fft = cri_fft_cls(**train_opt['fft_loss_opt']).to(
+                 self.device)
+         else:
+             self.cri_fft = None
+
+         # set up optimizers and schedulers
+         self.setup_optimizers()
+         self.setup_schedulers()
+
+     def setup_optimizers(self):
+         train_opt = self.opt['train']
+         optim_params = []
+
+         for k, v in self.net_g.named_parameters():
+             if v.requires_grad:
+                 optim_params.append(v)
+             else:
+                 logger = get_root_logger()
+                 logger.warning(f'Params {k} will not be optimized.')
+
+         optim_type = train_opt['optim_g'].pop('type')
+         if optim_type == 'Adam':
+             self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
+         elif optim_type == 'AdamW':
+             self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
+         else:
+             raise NotImplementedError(
+                 f'optimizer {optim_type} is not supported yet.')
+         self.optimizers.append(self.optimizer_g)
+
+     def feed_train_data(self, data):
+         self.lq = data['lq'].to(self.device)
+         if 'gt' in data:
+             self.gt = data['gt'].to(self.device)
+
+         if self.mixing_flag:
+             self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)
+
+     def feed_data(self, data):
+         self.lq = data['lq'].to(self.device)
+         if 'gt' in data:
+             self.gt = data['gt'].to(self.device)
+
+     def optimize_parameters(self, current_iter):
+         self.optimizer_g.zero_grad()
+         preds = self.net_g(self.lq)
+         if not isinstance(preds, list):
+             preds = [preds]
+
+         self.output = preds[-1]
+
+         l_total = 0
+         loss_dict = OrderedDict()
+         # pixel loss
+         if self.cri_pix:
+             l_pix = 0.
+             for pred in preds:
+                 l_pix += self.cri_pix(pred, self.gt)
+
+             l_total += l_pix
+             loss_dict['l_pix'] = l_pix
+
+         # fft loss
+         if self.cri_fft:
+             l_fft = self.cri_fft(preds[-1], self.gt)
+             l_total += l_fft
+             loss_dict['l_fft'] = l_fft
+
+         l_total = l_total + 0. * sum(p.sum() for p in self.net_g.parameters())
+
+         l_total.backward()
+
+         if self.opt['train']['use_grad_clip']:
+             torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01)
+         self.optimizer_g.step()
+
+         self.log_dict = self.reduce_loss_dict(loss_dict)
+
+         if self.ema_decay > 0:
+             self.model_ema(decay=self.ema_decay)
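The `0. * sum(p.sum() ...)` term in optimize_parameters is numerically a no-op; it is a common trick that touches every parameter in the autograd graph so DistributedDataParallel does not raise errors about parameters that received no gradient.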
+
+     def pad_test(self, window_size):
+         scale = self.opt.get('scale', 1)
+         mod_pad_h, mod_pad_w = 0, 0
+         _, _, h, w = self.lq.size()
+         if h % window_size != 0:
+             mod_pad_h = window_size - h % window_size
+         if w % window_size != 0:
+             mod_pad_w = window_size - w % window_size
+         img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+         self.nonpad_test(img)
+         _, _, h, w = self.output.size()
+         self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
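Worked numbers for the padding above (illustrative): with a 250x333 input and window_size = 8, mod_pad_h = 8 - 250 % 8 = 6 and mod_pad_w = 8 - 333 % 8 = 3, so the network sees a 256x336 image and the output is cropped back to 250x333 (scale is 1 for dehazing).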
+
+     def nonpad_test(self, img=None):
+         if img is None:
+             img = self.lq
+         if hasattr(self, 'net_g_ema'):
+             self.net_g_ema.eval()
+             with torch.no_grad():
+                 pred = self.net_g_ema(img)
+             if isinstance(pred, list):
+                 pred = pred[-1]
+             self.output = pred
+         else:
+             self.net_g.eval()
+             with torch.no_grad():
+                 pred = self.net_g(img)
+             if isinstance(pred, list):
+                 pred = pred[-1]
+             self.output = pred
+             self.net_g.train()
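Putting the test-time pieces together (a sketch with illustrative names; `model` is a constructed ImageCleanModel and `hazy` a (1, 3, H, W) tensor in [0, 1]; the window size of 8 is an assumed example value):

    model.feed_data({'lq': hazy})      # gt is optional at test time
    model.pad_test(window_size=8)      # pad to a multiple of 8, run the net, crop back
    restored = model.get_current_visuals()['result']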
+
+     def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image):
+         if os.environ['LOCAL_RANK'] == '0':
+             return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
+         else:
+             return 0.
+
+     def nondist_validation(self, dataloader, current_iter, tb_logger,
+                            save_img, rgb2bgr, use_image):
+         dataset_name = dataloader.dataset.opt['name']
+         with_metrics = self.opt['val'].get('metrics') is not None
+         if with_metrics:
+             self.metric_results = {
+                 metric: 0
+                 for metric in self.opt['val']['metrics'].keys()
+             }
+         # pbar = tqdm(total=len(dataloader), unit='image')
+
+         window_size = self.opt['val'].get('window_size', 0)
+
+         if window_size:
+             test = partial(self.pad_test, window_size)
+         else:
+             test = self.nonpad_test
+
+         cnt = 0
+
+         for idx, val_data in enumerate(dataloader):
+             img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+
+             self.feed_data(val_data)
+             test()
+
+             visuals = self.get_current_visuals()
+             sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
+             if 'gt' in visuals:
+                 gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
+                 del self.gt
+
+             # tentative fix for running out of GPU memory
+             del self.lq
+             del self.output
+             torch.cuda.empty_cache()
+
+             if save_img:
+                 if self.opt['is_train']:
+                     save_img_path = osp.join(self.opt['path']['visualization'],
+                                              img_name,
+                                              f'{img_name}_{current_iter}.png')
+                     save_gt_img_path = osp.join(self.opt['path']['visualization'],
+                                                 img_name,
+                                                 f'{img_name}_{current_iter}_gt.png')
+                 else:
+                     save_img_path = osp.join(
+                         self.opt['path']['visualization'], dataset_name,
+                         f'{img_name}.png')
+                     save_gt_img_path = osp.join(
+                         self.opt['path']['visualization'], dataset_name,
+                         f'{img_name}_gt.png')
+
+                 imwrite(sr_img, save_img_path)
+                 imwrite(gt_img, save_gt_img_path)
+
+             if with_metrics:
+                 # calculate metrics
+                 opt_metric = deepcopy(self.opt['val']['metrics'])
+                 if use_image:
+                     for name, opt_ in opt_metric.items():
+                         metric_type = opt_.pop('type')
+                         self.metric_results[name] += getattr(
+                             metric_module, metric_type)(sr_img, gt_img, **opt_)
+                 else:
+                     for name, opt_ in opt_metric.items():
+                         metric_type = opt_.pop('type')
+                         self.metric_results[name] += getattr(
+                             metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)
+
+             cnt += 1
+
+         current_metric = 0.
+         if with_metrics:
+             for metric in self.metric_results.keys():
+                 self.metric_results[metric] /= cnt
+                 current_metric = self.metric_results[metric]
+
+             self._log_validation_metric_values(current_iter, dataset_name,
+                                                tb_logger)
+         return current_metric
+
+     def _log_validation_metric_values(self, current_iter, dataset_name,
+                                       tb_logger):
+         log_str = f'Validation {dataset_name},\t'
+         for metric, value in self.metric_results.items():
+             log_str += f'\t # {metric}: {value:.4f}'
+         logger = get_root_logger()
+         logger.info(log_str)
+         if tb_logger:
+             for metric, value in self.metric_results.items():
+                 tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+     def get_current_visuals(self):
+         out_dict = OrderedDict()
+         out_dict['lq'] = self.lq.detach().cpu()
+         out_dict['result'] = self.output.detach().cpu()
+         if hasattr(self, 'gt'):
+             out_dict['gt'] = self.gt.detach().cpu()
+         return out_dict
+
+     def save(self, epoch, current_iter):
+         if self.ema_decay > 0:
+             self.save_network([self.net_g, self.net_g_ema],
+                               'net_g',
+                               current_iter,
+                               param_key=['params', 'params_ema'])
+         else:
+             self.save_network(self.net_g, 'net_g', current_iter)
+         self.save_training_state(epoch, current_iter)