dxcanh committed on
Commit
ce460c0
·
verified ·
1 Parent(s): 0eeaa0a

Upload 14 files

Browse files
basicsr/data/__init__.py CHANGED
@@ -1,101 +1,101 @@
1
- import importlib
2
- import numpy as np
3
- import random
4
- import torch
5
- import torch.utils.data
6
- from copy import deepcopy
7
- from functools import partial
8
- from os import path as osp
9
-
10
- from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
- from basicsr.utils import get_root_logger, scandir
12
- from basicsr.utils.dist_util import get_dist_info
13
- from basicsr.utils.registry import DATASET_REGISTRY
14
-
15
- __all__ = ['build_dataset', 'build_dataloader']
16
-
17
- # automatically scan and import dataset modules for registry
18
- # scan all the files under the data folder with '_dataset' in file names
19
- data_folder = osp.dirname(osp.abspath(__file__))
20
- dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
- # import all the dataset modules
22
- _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
-
24
-
25
- def build_dataset(dataset_opt):
26
- """Build dataset from options.
27
-
28
- Args:
29
- dataset_opt (dict): Configuration for dataset. It must contain:
30
- name (str): Dataset name.
31
- type (str): Dataset type.
32
- """
33
- dataset_opt = deepcopy(dataset_opt)
34
- dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
- logger = get_root_logger()
36
- logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37
- return dataset
38
-
39
-
40
- def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
- """Build dataloader.
42
-
43
- Args:
44
- dataset (torch.utils.data.Dataset): Dataset.
45
- dataset_opt (dict): Dataset options. It contains the following keys:
46
- phase (str): 'train' or 'val'.
47
- num_worker_per_gpu (int): Number of workers for each GPU.
48
- batch_size_per_gpu (int): Training batch size for each GPU.
49
- num_gpu (int): Number of GPUs. Used only in the train phase.
50
- Default: 1.
51
- dist (bool): Whether in distributed training. Used only in the train
52
- phase. Default: False.
53
- sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
- seed (int | None): Seed. Default: None
55
- """
56
- phase = dataset_opt['phase']
57
- rank, _ = get_dist_info()
58
- if phase == 'train':
59
- if dist: # distributed training
60
- batch_size = dataset_opt['batch_size_per_gpu']
61
- num_workers = dataset_opt['num_worker_per_gpu']
62
- else: # non-distributed training
63
- multiplier = 1 if num_gpu == 0 else num_gpu
64
- batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
- num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
- dataloader_args = dict(
67
- dataset=dataset,
68
- batch_size=batch_size,
69
- shuffle=False,
70
- num_workers=num_workers,
71
- sampler=sampler,
72
- drop_last=True)
73
- if sampler is None:
74
- dataloader_args['shuffle'] = True
75
- dataloader_args['worker_init_fn'] = partial(
76
- worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
- elif phase in ['val', 'test']: # validation
78
- dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
- else:
80
- raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
81
-
82
- dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
- dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
-
85
- prefetch_mode = dataset_opt.get('prefetch_mode')
86
- if prefetch_mode == 'cpu': # CPUPrefetcher
87
- num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
- logger = get_root_logger()
89
- logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
- return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
- else:
92
- # prefetch_mode=None: Normal dataloader
93
- # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
- return torch.utils.data.DataLoader(**dataloader_args)
95
-
96
-
97
- def worker_init_fn(worker_id, num_workers, rank, seed):
98
- # Set the worker seed to num_workers * rank + worker_id + seed
99
- worker_seed = num_workers * rank + worker_id + seed
100
- np.random.seed(worker_seed)
101
- random.seed(worker_seed)
 
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
def build_dataset(dataset_opt):
    """Instantiate a registered dataset class from its option dict.

    The options are deep-copied first, so the dataset receives its own copy
    and the caller's dict is not shared with it.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name (used here only in the log message).
            type (str): Key used to look the dataset class up in
                ``DATASET_REGISTRY``.

    Returns:
        The object obtained by calling the registered class with the copied
        options.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(dataset_opt['type'])
    dataset = dataset_cls(dataset_opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset
38
+
39
+
40
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build a dataloader for the given dataset and phase.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. The following keys are read:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
                Read only in the train phase.
            batch_size_per_gpu (int): Training batch size for each GPU.
                Read only in the train phase.
            pin_memory (bool, optional): Forwarded to the loader.
                Default: False.
            persistent_workers (bool, optional): Forwarded to the loader,
                but forced to False when the loader has no worker
                processes. Default: False.
            prefetch_mode (str, optional): 'cpu' selects
                ``PrefetchDataLoader``; any other value (or a missing key)
                selects ``torch.utils.data.DataLoader``.
            num_prefetch_queue (int, optional): Passed to
                ``PrefetchDataLoader`` in 'cpu' prefetch mode. Default: 1.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Used only in the
            train phase. Default: None.
        seed (int | None): Seed handed to ``worker_init_fn`` in the train
            phase. When None, no worker init function is installed.
            Default: None

    Returns:
        PrefetchDataLoader | torch.utils.data.DataLoader: The built loader.

    Raises:
        ValueError: If ``dataset_opt['phase']`` is not 'train', 'val' or
            'test'.
    """
    phase = dataset_opt['phase']
    if phase == 'train':
        # The rank is only consumed by the per-worker seeding below, so it is
        # looked up inside the train branch.
        rank, _ = get_dist_info()
        if dist:  # distributed training: one loader per process/GPU
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training: one loader feeds every GPU
            # num_gpu == 0 is treated like a single device.
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            # shuffle and sampler are mutually exclusive in DataLoader, so
            # shuffling is only requested when no sampler was supplied.
            shuffle=sampler is None,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    # DataLoader raises ValueError for persistent_workers=True combined with
    # num_workers == 0 (always the case for val/test here). Persistent
    # workers are meaningless without worker processes, so the flag is
    # dropped in that case instead of crashing on a shared config option.
    persistent_workers = dataset_opt.get('persistent_workers', False)
    if persistent_workers and dataloader_args['num_workers'] == 0:
        persistent_workers = False
    dataloader_args['persistent_workers'] = persistent_workers

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)

    # prefetch_mode=None: Normal dataloader
    # prefetch_mode='cuda': dataloader for CUDAPrefetcher
    return torch.utils.data.DataLoader(**dataloader_args)
95
+
96
+
97
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs for one dataloader worker.

    The seed used is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers of the loader.
        rank (int): Rank of the current process.
        seed (int): Base seed.
    """
    offset = rank * num_workers + worker_id
    derived_seed = offset + seed
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
class EnlargedSampler(Sampler):
    """Per-rank sampler over a virtually enlarged dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    The index space is ``ratio`` times the dataset length (rounded up so it
    splits evenly across replicas); indices are wrapped back into the real
    dataset with a modulo. This supports iteration-based training by saving
    the time spent restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling. Only
            its length is used.
        num_replicas (int): Number of processes participating in the
            training. It is usually the world_size.
        rank (int): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Samples served by each replica; rounded up so every replica gets
        # the same count.
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_replicas * self.num_samples

    def __iter__(self):
        # The epoch seeds the permutation, so the order is deterministic for
        # a given epoch.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=generator).tolist()

        # Take this rank's strided share, then wrap the enlarged indices
        # back into the real dataset range.
        share = permutation[self.rank:self.total_size:self.num_replicas]
        dataset_size = len(self.dataset)
        indices = [idx % dataset_size for idx in share]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next permutation."""
        self.epoch = epoch
basicsr/data/data_util.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.data.transforms import mod_crop
8
+ from basicsr.utils import img2tensor, scandir
9
+
10
+
11
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is scanned with ``scandir`` and its entries are read in
            sorted order; a list is read in the given order.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname(bool): Whether return image names. Default False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list (file names without extension).
            Only returned when ``return_imgname`` is True.

    Raises:
        OSError: If an image cannot be read by ``cv2.imread``.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))

    imgs = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of an AttributeError on
        # ``.astype`` below.
        if img is None:
            raise OSError(f'Failed to read image: {img_path}')
        imgs.append(img.astype(np.float32) / 255.)

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(img_path))[0] for img_path in img_paths]
        return imgs, imgnames
    return imgs
41
+
42
+
43
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Build the index window of `num_frames` frames centred on `crt_idx`.

    Indices that fall outside ``[0, max_frame_num - 1]`` are replaced
    according to the padding mode.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames. Must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # start from 0
    half = num_frames // 2

    # Replacement rule for indices before the first frame.
    pad_before = {
        'replicate': lambda i: 0,
        'reflection': lambda i: -i,
        'reflection_circle': lambda i: crt_idx + half - i,
        'circle': lambda i: num_frames + i,
    }[padding]
    # Replacement rule for indices after the last frame.
    pad_after = {
        'replicate': lambda i: last_idx,
        'reflection': lambda i: last_idx * 2 - i,
        'reflection_circle': lambda i: (crt_idx - half) - (i - last_idx),
        'circle': lambda i: i - num_frames,
    }[padding]

    def resolve(i):
        if i < 0:
            return pad_before(i)
        if i > last_idx:
            return pad_after(i)
        return i

    return [resolve(i) for i in range(crt_idx - half, crt_idx + half + 1)]
93
+
94
+
95
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    The lmdb key is taken as the part of each line before its first '.',
    i.e. the image name without extension.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key (sorted), mapping
            ``'{input_key}_path'`` and ``'{gt_key}_path'`` to that key.

    Raises:
        ValueError: If either folder name does not end with '.lmdb', or the
            two meta_info files list different key sets.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    both_lmdb = input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')
    if not both_lmdb:
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_lmdb_keys(folder):
        # Everything before the first '.' on each line is the lmdb key.
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            return [line.split('.')[0] for line in fin]

    # ensure that the two meta_info files are the same
    input_lmdb_keys = _read_lmdb_keys(input_folder)
    gt_lmdb_keys = _read_lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [{f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key} for lmdb_key in sorted(input_lmdb_keys)]
154
+
155
+
156
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from an meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space. Only the first
    space-separated field (the image name) is used. Blank lines are ignored.

    Example of an meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per listed image, mapping ``'{input_key}_path'``
            and ``'{gt_key}_path'`` to the joined input and gt paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        # Skip blank lines (e.g. a trailing empty line): they would yield an
        # empty image name and a bogus pair pointing at the bare folders.
        gt_names = [line.strip().split(' ')[0] for line in fin if line.strip()]

    paths = []
    for gt_name in gt_names:
        # The input file name is the template applied to the gt base name,
        # keeping the gt extension.
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append({f'{input_key}_path': input_path, f'{gt_key}_path': gt_path})
    return paths
198
+
199
+
200
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Every entry of the gt folder is paired with the input-folder entry whose
    name is ``filename_tmpl`` applied to the gt base name plus the gt
    extension.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt entry, mapping ``'{input_key}_path'`` and
            ``'{gt_key}_path'`` to the joined input and gt paths.

    Raises:
        AssertionError: If the folders hold a different number of entries, or
            an expected input name is missing from the input folder.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # Built once so each membership test below is O(1), not a list scan
    # (which made the loop quadratic in the number of images).
    input_names = set(input_paths)

    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_names, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append({f'{input_key}_path': input_path, f'{gt_key}_path': gt_path})
    return paths
234
+
235
+
236
def paths_from_folder(folder):
    """List the entries found by ``scandir`` in a folder, joined to it.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: ``folder`` joined with each scanned entry, in scan order.
    """
    entries = list(scandir(folder))
    return [osp.join(folder, entry) for entry in entries]
249
+
250
+
251
def paths_from_lmdb(folder):
    """Generate paths (lmdb keys) from an lmdb folder.

    The keys are read from ``meta_info.txt`` inside the folder: each key is
    the part of a line before its first '.'.

    Args:
        folder (str): Folder path. Must end with '.lmdb'.

    Returns:
        list[str]: Returned path list, in file order.

    Raises:
        ValueError: If ``folder`` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        # The previous message ran the path and the word "folder" together
        # and was ungrammatical; keep the path clearly delimited.
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
265
+
266
+
267
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    The kernel is obtained by Gaussian-smoothing a unit impulse placed at
    the centre of a ``kernel_size x kernel_size`` array.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    # Import from the public ``scipy.ndimage`` namespace; the
    # ``scipy.ndimage.filters`` submodule is deprecated.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
283
+
284
+
285
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsamping with Gaussian kernel used in the DUF official code.

    Each (h, w) plane is reflect-padded, convolved with a Gaussian kernel of
    sigma ``0.4 * scale`` at stride ``scale``, and then trimmed by two
    pixels on every side.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
            A 4-dim input gets a leading dimension added for processing,
            which is removed again from the result.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    added_leading_dim = x.ndim == 4
    if added_leading_dim:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()

    # Fold batch, time and channel into one axis so every plane is filtered
    # independently by a single-channel convolution.
    planes = x.view(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    planes = F.pad(planes, (pad, pad, pad, pad), 'reflect')

    kernel = torch.from_numpy(generate_gaussian_kernel(kernel_size, 0.4 * scale))
    kernel = kernel.type_as(planes).unsqueeze(0).unsqueeze(0)
    planes = F.conv2d(planes, kernel, stride=scale)
    planes = planes[:, :, 2:-2, 2:-2]

    out = planes.view(b, t, c, planes.size(2), planes.size(3))
    return out.squeeze(0) if added_leading_dim else out
basicsr/data/degradations.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from scipy import special
7
+ from scipy.stats import multivariate_normal
8
+ from torchvision.transforms.functional import rgb_to_grayscale
9
+
10
+ # -------------------------------------------------------------------- #
11
+ # --------------------------- blur kernels --------------------------- #
12
+ # -------------------------------------------------------------------- #
13
+
14
+
15
+ # --------------------------- util functions --------------------------- #
16
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Computes ``U @ D @ U.T`` where ``D = diag(sig_x**2, sig_y**2)`` and ``U``
    is the 2D rotation matrix for angle ``theta``.

    Args:
        sig_x (float): Value squared into the first diagonal entry of D.
        sig_y (float): Value squared into the second diagonal entry of D.
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    variances = np.array([[sig_x**2, 0], [0, sig_y**2]])
    return rotation.dot(variances.dot(rotation.T))
30
+
31
+
32
def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int): Number of grid points along each axis.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2); the last
            axis holds the (x, y) coordinate of each grid point.
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    lower = -kernel_size // 2 + 1.
    upper = kernel_size // 2 + 1.
    axis = np.arange(lower, upper)
    xx, yy = np.meshgrid(axis, axis)
    # Pair up the coordinates along a new trailing axis.
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy
48
+
49
+
50
def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Evaluates ``exp(-0.5 * p^T inv(sigma_matrix) p)`` for every grid point
    ``p``; no normalization constant is applied.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarrray): un-normalized kernel.
    """
    precision = np.linalg.inv(sigma_matrix)
    # Quadratic form p^T * precision * p, reduced over the coordinate axis.
    quadratic = (np.dot(grid, precision) * grid).sum(axis=2)
    return np.exp(-0.5 * quadratic)
64
+
65
+
66
def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.

    The grid points are first multiplied by ``d_matrix``; the standard
    (zero mean, identity covariance) bivariate normal CDF is then evaluated
    at the transformed points.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    standard_normal = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    skewed_points = np.dot(grid, d_matrix)
    return standard_normal.cdf(skewed_points)
82
+
83
+
84
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): Size used to build the grid when `grid` is None.
        sig_x (float): Sigma along x (both axes in the isotropic mode).
        sig_y (float): Sigma along y. Used only when `isotropic` is falsy.
        theta (float): Radian measurement. Used only when `isotropic` is
            falsy.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to use the isotropic covariance
            ``diag(sig_x**2, sig_x**2)``. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    covariance = (
        np.array([[sig_x**2, 0], [0, sig_x**2]]) if isotropic else sigma_matrix2(sig_x, sig_y, theta))
    unnormalized = pdf2(covariance, grid)
    # Normalize so the kernel sums to one.
    return unnormalized / np.sum(unnormalized)
110
+
111
+
112
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

    The un-normalized value at grid point ``p`` is
    ``exp(-0.5 * (p^T inv(Sigma) p) ** beta)``.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): Size used to build the grid when `grid` is None.
        sig_x (float): Sigma along x (both axes in the isotropic mode).
        sig_y (float): Sigma along y. Used only when `isotropic` is falsy.
        theta (float): Radian measurement. Used only when `isotropic` is
            falsy.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to use the isotropic covariance
            ``diag(sig_x**2, sig_x**2)``. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    covariance = (
        np.array([[sig_x**2, 0], [0, sig_x**2]]) if isotropic else sigma_matrix2(sig_x, sig_y, theta))
    precision = np.linalg.inv(covariance)
    quadratic = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.exp(-0.5 * np.power(quadratic, beta))
    # Normalize so the kernel sums to one.
    return unnormalized / np.sum(unnormalized)
141
+
142
+
143
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.

    1 / (1+x^(beta))

    Here ``x`` is the quadratic form ``p^T inv(Sigma) p`` of each grid
    point ``p``.

    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): Size used to build the grid when `grid` is None.
        sig_x (float): Sigma along x (both axes in the isotropic mode).
        sig_y (float): Sigma along y. Used only when `isotropic` is falsy.
        theta (float): Radian measurement. Used only when `isotropic` is
            falsy.
        beta (float): shape parameter; the exponent applied to the
            quadratic form.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to use the isotropic covariance
            ``diag(sig_x**2, sig_x**2)``. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    covariance = (
        np.array([[sig_x**2, 0], [0, sig_x**2]]) if isotropic else sigma_matrix2(sig_x, sig_y, theta))
    precision = np.linalg.inv(covariance)
    quadratic = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quadratic, beta) + 1)
    # Normalize so the kernel sums to one.
    return unnormalized / np.sum(unnormalized)
174
+
175
+
176
def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.

    Parameters are drawn uniformly with ``np.random.uniform`` in the order
    sigma_x, sigma_y, rotation, and finally the kernel noise.

    Args:
        kernel_size (int): Must be odd.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation range (tuple): [-math.pi, math.pi]
        noise_range(tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Passed on to :func:`bivariate_Gaussian`. sigma_y
            and rotation are only sampled when this is exactly ``False``.
            Default: True.

    Returns:
        kernel (ndarray): Kernel normalized to sum to one.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])

    # Isotropic defaults; overwritten below only for an explicit
    # ``isotropic=False`` (identity check, as in the sampling condition).
    sigma_y, rotation = sigma_x, 0
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)

    # Re-normalize, since the noise changes the kernel sum.
    return kernel / np.sum(kernel)
218
+
219
+
220
def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Sample a random bivariate generalized Gaussian kernel.

    `sigma_x` is always drawn uniformly from `sigma_x_range`. Only when `isotropic` is exactly ``False`` are
    `sigma_y` and the rotation drawn from their ranges; otherwise `sigma_y` is set to `sigma_x`, the rotation
    is 0, and `sigma_y_range` / `rotation_range` are ignored.

    Args:
        kernel_size (int): Kernel size, must be odd.
        sigma_x_range (tuple): (min, max) of sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (min, max) of sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (min, max) of the rotation, e.g. [-math.pi, math.pi].
        beta_range (tuple): (min, max) of the shape parameter, e.g. [0.5, 8]. The sampling assumes
            beta_range[0] < 1 < beta_range[1].
        noise_range (tuple, optional): (min, max) of the multiplicative kernel noise, e.g. [0.75, 1.25].
            Default: None (no noise).
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): Kernel rescaled so that its entries sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])

    if isotropic is not False:
        # isotropic: one sigma for both axes, no rotation
        sigma_y, rotation = sigma_x, 0
    else:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # assume beta_range[0] < 1 < beta_range[1]: pick the side of 1 with equal probability,
    # then sample uniformly within that side
    pick_lower_side = np.random.uniform() < 0.5
    if pick_lower_side:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # optional element-wise multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        perturbation = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * perturbation
    return kernel / np.sum(kernel)
270
+
271
+
272
def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Sample a random bivariate plateau kernel.

    `sigma_x` is always drawn uniformly from `sigma_x_range`. Only when `isotropic` is exactly ``False`` are
    `sigma_y` and the rotation drawn from their ranges; otherwise `sigma_y` is set to `sigma_x`, the rotation
    is 0, and `sigma_y_range` / `rotation_range` are ignored.

    Args:
        kernel_size (int): Kernel size, must be odd.
        sigma_x_range (tuple): (min, max) of sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (min, max) of sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (min, max) of the rotation, e.g. [-math.pi/2, math.pi/2].
        beta_range (tuple): (min, max) of the shape parameter, e.g. [1, 4]. With probability 0.5 beta is drawn
            from [beta_range[0], 1], otherwise from [1, beta_range[1]].
        noise_range (tuple, optional): (min, max) of the multiplicative kernel noise, e.g. [0.75, 1.25].
            Default: None (no noise).
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): Kernel rescaled so that its entries sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])

    if isotropic is not False:
        # isotropic: one sigma for both axes, no rotation
        sigma_y, rotation = sigma_x, 0
    else:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # TODO: this may be not proper
    pick_lower_side = np.random.uniform() < 0.5
    if pick_lower_side:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # optional element-wise multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        perturbation = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * perturbation

    return kernel / np.sum(kernel)
322
+
323
+
324
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate a kernel of a type drawn from `kernel_list`.

    Args:
        kernel_list (tuple): Names of the candidate kernel types. Supported names are
            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'].
        kernel_prob (tuple): Relative weight of each entry of `kernel_list` (passed to `random.choices`).
        kernel_size (int): Kernel size, must be odd. Default: 21.
        sigma_x_range (tuple): (min, max) of sigma_x. Default: (0.6, 5).
        sigma_y_range (tuple): (min, max) of sigma_y. Default: (0.6, 5).
        rotation_range (tuple): (min, max) of the rotation. Default: (-math.pi, math.pi).
        betag_range (tuple): (min, max) of beta for the generalized Gaussian kernels. Default: (0.5, 8).
        betap_range (tuple): (min, max) of beta for the plateau kernels. Default: (0.5, 8).
        noise_range (tuple, optional): (min, max) of the multiplicative kernel noise, e.g. [0.75, 1.25].
            Only applied to the Gaussian and generalized Gaussian kernels. Default: None.

    Returns:
        kernel (ndarray): The sampled kernel.

    Raises:
        ValueError: If the drawn kernel type is not one of the supported names.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        # NOTE(review): plateau kernels are generated without multiplicative noise even when
        # `noise_range` is given; kept as-is - confirm whether this is intentional.
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    else:
        # Previously an unknown name surfaced as an UnboundLocalError on `return kernel`.
        raise ValueError(f'Unsupported kernel type: {kernel_type!r}. Supported types are: iso, aniso, '
                         'generalized_iso, generalized_aniso, plateau_iso, plateau_aniso.')
    return kernel
384
+
385
+
386
+ np.seterr(divide='ignore', invalid='ignore')
387
+
388
+
389
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """Build a 2D circularly symmetric low-pass (sinc-like) filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): Cutoff frequency in radians (pi is max).
        kernel_size (int): Horizontal and vertical size, must be odd.
        pad_to (int): Pad the kernel with zeros towards this size; must be odd or zero. Padding is only applied
            when `pad_to` is larger than `kernel_size`. Default: 0.

    Returns:
        kernel (ndarray): Kernel normalized to sum to 1 (before zero padding).
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    half = (kernel_size - 1) / 2
    rows, cols = np.indices((kernel_size, kernel_size), dtype=float)
    radius = np.sqrt((rows - half)**2 + (cols - half)**2)

    # The centre entry evaluates to 0 / 0 here; it is replaced with the analytic limit right below.
    kernel = cutoff * special.j1(cutoff * radius) / (2 * np.pi * radius)
    mid = (kernel_size - 1) // 2
    kernel[mid, mid] = cutoff**2 / (4 * np.pi)

    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        margin = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((margin, margin), (margin, margin)))
    return kernel
410
+
411
+
412
+ # ------------------------------------------------------------- #
413
+ # --------------------------- noise --------------------------- #
414
+ # ------------------------------------------------------------- #
415
+
416
+ # ----------------------- Gaussian Noise ----------------------- #
417
+
418
+
419
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise shaped like `img`.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. Only its shape is used.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): If True, a single (h, w) noise map is drawn and repeated over 3 channels.
            Default: False.

    Returns:
        (Numpy array): The noise only (not the noisy image), scaled by sigma / 255.
    """
    if not gray_noise:
        # independent noise for every element
        return np.float32(np.random.randn(*(img.shape))) * sigma / 255.

    # one noise map shared by all three channels
    plane = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
    return np.expand_dims(plane, axis=2).repeat(3, axis=2)
436
+
437
+
438
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise to an image.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        clip (bool): Whether to clip the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.
        gray_noise (bool): Whether to use the same noise for all channels. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noisy = img + generate_gaussian_noise(img, sigma, gray_noise)
    if rounds:
        # quantize on the 0-255 scale, clipping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = np.clip(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return np.clip(noisy, 0, 1)
    return noisy
458
+
459
+
460
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise for a batch of images (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32. Only its size, dtype and device are used.
        sigma (float | int | Tensor): Noise scale (the noise is multiplied by sigma / 255). Number or Tensor
            with b elements (one scale per sample). Default: 10.
        gray_noise (float | int | Tensor): 0-1 number or Tensor with b elements. 0 selects per-channel (color)
            noise, 1 selects noise shared by all channels. Default: 0.

    Returns:
        (Tensor): The noise only (not the noisy image), shape (b, c, h, w).
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        # Draw an independent single-channel map per sample. Drawing one (h, w) map and viewing it as
        # (b, 1, h, w) raised for a scalar sigma with b > 1 and reused one pattern across the batch otherwise.
        noise_gray = torch.randn(b, 1, h, w, dtype=img.dtype, device=img.device) * sigma / 255.

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        # per-sample blend: gray_noise == 1 keeps the shared map, 0 keeps the color noise
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise
490
+
491
+
492
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
        gray_noise (float | Tensor): 0 for color noise, 1 for noise shared by all channels. Default: 0.
        clip (bool): Whether to clamp the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noisy = img + generate_gaussian_noise_pt(img, sigma, gray_noise)
    if rounds:
        # quantize on the 0-255 scale, clamping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = torch.clamp(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return torch.clamp(noisy, 0, 1)
    return noisy
512
+
513
+
514
+ # ----------------------- Random Gaussian Noise ----------------------- #
515
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    """Generate Gaussian noise with a random scale and random gray/color mode.

    Args:
        img (Numpy array): Input image; forwarded to `generate_gaussian_noise`.
        sigma_range (tuple): (min, max) from which sigma is drawn uniformly. Default: (0, 10).
        gray_prob (float): Probability of generating gray noise. Default: 0.

    Returns:
        (Numpy array): The generated noise.
    """
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    use_gray = bool(np.random.uniform() < gray_prob)
    return generate_gaussian_noise(img, sigma, use_gray)
522
+
523
+
524
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Gaussian noise with a random scale and random gray/color mode.

    Args:
        img (Numpy array): Input image.
        sigma_range (tuple): (min, max) from which sigma is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Probability of using gray noise. Default: 0.
        clip (bool): Whether to clip the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.

    Returns:
        (Numpy array): The noisy image.
    """
    noisy = img + random_generate_gaussian_noise(img, sigma_range, gray_prob)
    if rounds:
        # quantize on the 0-255 scale, clipping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = np.clip(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return np.clip(noisy, 0, 1)
    return noisy
534
+
535
+
536
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    """Generate Gaussian noise with per-sample random scale and gray/color mode (PyTorch version).

    Args:
        img (Tensor): Batch of images; the first dimension is the batch size.
        sigma_range (tuple): (min, max) from which one sigma per sample is drawn uniformly. Default: (0, 10).
        gray_prob (float): Per-sample probability of generating gray noise. Default: 0.

    Returns:
        (Tensor): The generated noise.
    """
    batch = img.size(0)
    span = sigma_range[1] - sigma_range[0]
    sigma = torch.rand(batch, dtype=img.dtype, device=img.device) * span + sigma_range[0]
    # 1.0 where the sample gets gray noise, 0.0 otherwise
    gray_flags = (torch.rand(batch, dtype=img.dtype, device=img.device) < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_flags)
542
+
543
+
544
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add Gaussian noise with per-sample random scale and gray/color mode (PyTorch version).

    Args:
        img (Tensor): Batch of images.
        sigma_range (tuple): (min, max) from which one sigma per sample is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Per-sample probability of using gray noise. Default: 0.
        clip (bool): Whether to clamp the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.

    Returns:
        (Tensor): The noisy images.
    """
    noisy = img + random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    if rounds:
        # quantize on the 0-255 scale, clamping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = torch.clamp(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return torch.clamp(noisy, 0, 1)
    return noisy
554
+
555
+
556
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
557
+
558
+
559
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate poisson (shot) noise for an image.

    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether generate gray noise. If True, the image is converted with
            `cv2.COLOR_BGR2GRAY` first and the resulting noise is repeated over 3 channels. Default: False.

    Returns:
        (Numpy array): The noise only (not the noisy image), multiplied by `scale`.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # quantize to 8-bit levels so that the number of distinct values is counted correctly
    quantized = np.clip((img * 255.0).round(), 0, 255) / 255.
    # number of distinct values, rounded up to the next power of two
    levels = len(np.unique(quantized))
    levels = 2**np.ceil(np.log2(levels))

    sampled = np.float32(np.random.poisson(quantized * levels) / float(levels))
    noise = sampled - quantized
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale
584
+
585
+
586
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add poisson (shot) noise to an image.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        clip (bool): Whether to clip the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.
        gray_noise (bool): Whether generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noisy = img + generate_poisson_noise(img, scale, gray_noise)
    if rounds:
        # quantize on the 0-255 scale, clipping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = np.clip(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return np.clip(noisy, 0, 1)
    return noisy
607
+
608
+
609
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version)

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): The noise only (not the noisy image), multiplied by `scale`.
    """
    b, _, h, w = img.size()

    def _sample_noise(x):
        """Quantize `x` to 8-bit levels and return (poisson sample - quantized x)."""
        # round and clip image for counting vals correctly
        x = torch.clamp((x * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        counts = [len(torch.unique(x[i, :, :, :])) for i in range(b)]
        # per-sample number of distinct values, rounded up to the next power of two
        levels = x.new_tensor([2**np.ceil(np.log2(count)) for count in counts]).view(b, 1, 1, 1)
        return torch.poisson(x * levels) / levels - x

    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        # noise computed on the grayscale image and shared by 3 output channels
        noise_gray = _sample_noise(rgb_to_grayscale(img, num_output_channels=1)).expand(b, 3, h, w)

    # always calculate color noise
    noise = _sample_noise(img)
    if cal_gray_noise:
        # per-sample blend: gray_noise == 1 keeps the gray noise, 0 keeps the color noise
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise

    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale
655
+
656
+
657
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        clip (bool): Whether to clamp the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noisy = img + generate_poisson_noise_pt(img, scale, gray_noise)
    if rounds:
        # quantize on the 0-255 scale, clamping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = torch.clamp(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return torch.clamp(noisy, 0, 1)
    return noisy
680
+
681
+
682
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
683
+
684
+
685
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    """Generate poisson noise with a random scale and random gray/color mode.

    Args:
        img (Numpy array): Input image; forwarded to `generate_poisson_noise`.
        scale_range (tuple): (min, max) from which the scale is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Probability of generating gray noise. Default: 0.

    Returns:
        (Numpy array): The generated noise.
    """
    scale = np.random.uniform(scale_range[0], scale_range[1])
    use_gray = bool(np.random.uniform() < gray_prob)
    return generate_poisson_noise(img, scale, use_gray)
692
+
693
+
694
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add poisson noise with a random scale and random gray/color mode.

    Args:
        img (Numpy array): Input image.
        scale_range (tuple): (min, max) from which the scale is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Probability of using gray noise. Default: 0.
        clip (bool): Whether to clip the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.

    Returns:
        (Numpy array): The noisy image.
    """
    noisy = img + random_generate_poisson_noise(img, scale_range, gray_prob)
    if rounds:
        # quantize on the 0-255 scale, clipping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = np.clip(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return np.clip(noisy, 0, 1)
    return noisy
704
+
705
+
706
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    """Generate poisson noise with per-sample random scale and gray/color mode (PyTorch version).

    Args:
        img (Tensor): Batch of images; the first dimension is the batch size.
        scale_range (tuple): (min, max) from which one scale per sample is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Per-sample probability of generating gray noise. Default: 0.

    Returns:
        (Tensor): The generated noise.
    """
    batch = img.size(0)
    span = scale_range[1] - scale_range[0]
    scale = torch.rand(batch, dtype=img.dtype, device=img.device) * span + scale_range[0]
    # 1.0 where the sample gets gray noise, 0.0 otherwise
    gray_flags = (torch.rand(batch, dtype=img.dtype, device=img.device) < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_flags)
712
+
713
+
714
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    """Add poisson noise with per-sample random scale and gray/color mode (PyTorch version).

    Args:
        img (Tensor): Batch of images.
        scale_range (tuple): (min, max) from which one scale per sample is drawn uniformly. Default: (0, 1.0).
        gray_prob (float): Per-sample probability of using gray noise. Default: 0.
        clip (bool): Whether to clamp the result to [0, 1]. Default: True.
        rounds (bool): Whether to round the result to multiples of 1/255. Default: False.

    Returns:
        (Tensor): The noisy images.
    """
    noisy = img + random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    if rounds:
        # quantize on the 0-255 scale, clamping there when requested
        noisy = (noisy * 255.0).round()
        if clip:
            noisy = torch.clamp(noisy, 0, 255)
        return noisy / 255.
    if clip:
        return torch.clamp(noisy, 0, 1)
    return noisy
724
+
725
+
726
+ # ------------------------------------------------------------------------ #
727
+ # --------------------------- JPEG compression --------------------------- #
728
+ # ------------------------------------------------------------------------ #
729
+
730
+
731
def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Non-integer values are rounded to the nearest integer. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    # cv2.imencode takes its parameters as a flat list of integers; random_add_jpg_compression passes
    # a float quality, so convert it explicitly instead of handing a float to OpenCV.
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(round(quality))]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    # flag 1 decodes to a 3-channel color image
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img
748
+
749
+
750
def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Add JPG compression artifacts with a randomly drawn quality.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): (min, max) from which the JPG quality is drawn uniformly.
            0 for lowest quality, 100 for best quality. Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    lowest, highest = quality_range[0], quality_range[1]
    return add_jpg_compression(img, np.random.uniform(lowest, highest))
basicsr/data/ffhq_dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+ from os import path as osp
4
+ from torch.utils import data as data
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from basicsr.data.transforms import augment
8
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9
+ from basicsr.utils.registry import DATASET_REGISTRY
10
+
11
+
12
@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.

    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first __getitem__ call
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            # keys are the part of each meta_info line before the first '.'
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        """Load one image and return ``{'gt': tensor, 'gt_path': path}``.

        If reading fails, another randomly chosen image is tried (up to 3 attempts in total); the error of
        the last failed attempt is re-raised.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        max_retry = 3
        for attempt in range(max_retry):
            try:
                img_bytes = self.file_client.get(gt_path)
                break
            except Exception as e:
                remaining = max_retry - attempt - 1
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {remaining}')
                if remaining == 0:
                    # out of retries: surface the real error instead of an unbound `img_bytes`
                    raise
                # change another file to read; randint is inclusive on both ends
                index = random.randint(0, len(self) - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
+ from basicsr.data.transforms import augment, paired_random_crop
6
+ from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
7
+ from basicsr.utils.registry import DATASET_REGISTRY
8
+
9
+
10
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Reads pairs of LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT images.

    The list of pairs is built in one of three ways:

    1. **lmdb**: from lmdb files, when opt['io_backend']['type'] is 'lmdb'.
    2. **meta_info_file**: from a meta information file, when the backend is not lmdb and
        opt['meta_info_file'] is not None.
    3. **folder**: by scanning the folders, in all other cases.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale between gt and lq sizes, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first __getitem__ call
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        folders = [self.lq_folder, self.gt_folder]
        keys = ['lq', 'gt']
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(folders, keys)
        elif self.opt.get('meta_info_file') is not None:
            self.paths = paired_paths_from_meta_info_file(folders, keys, self.opt['meta_info_file'],
                                                          self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(folders, keys, self.filename_tmpl)

    def __getitem__(self, index):
        """Load, augment and convert one pair; returns a dict with 'lq', 'gt', 'lq_path' and 'gt_path'."""
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        phase = self.opt['phase']
        entry = self.paths[index]

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = entry['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = self.paths[index]['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        # augmentation for training
        if phase == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform: keep only the Y channel
        if self.opt.get('color') == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if phase != 'train':
            lq_h, lq_w = img_lq.shape[0], img_lq.shape[1]
            img_gt = img_gt[0:lq_h * scale, 0:lq_w * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if not (self.mean is None and self.std is None):
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread consumes `generator` and pushes its items into a bounded queue; iterating this object
    pops them from the queue. ``None`` is used as the end marker, so a ``None`` item ends the iteration.

    An exception raised by `generator` inside the thread is re-raised from ``__next__`` once the items
    produced before it have been consumed, instead of leaving the consumer blocked forever.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        # state must exist before the thread starts running
        self._prefetch_error = None
        self._exhausted = False
        self.daemon = True
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as error:
            # hand the failure over to the consuming thread
            self._prefetch_error = error
        finally:
            # always signal the end so the consumer never waits on an empty queue forever
            self.queue.put(None)

    def __next__(self):
        if self._exhausted:
            # the end marker was already consumed; the queue will stay empty
            raise StopIteration
        next_item = self.queue.get()
        if next_item is None:
            self._exhausted = True
            if self._prefetch_error is not None:
                error, self._prefetch_error = self._prefetch_error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
37
+
38
+
39
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is wrapped in a background-thread prefetcher.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
        Need to test on single gpu and ddp (multi-gpu). There is a known issue in
        ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # wrap the regular DataLoader iterator so that batches are fetched ahead of time
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
59
+
60
+
61
class CPUPrefetcher():
    """Thin wrapper giving a dataloader a `next()` / `reset()` interface.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
80
+
81
+
82
class CUDAPrefetcher():
    """CUDA prefetcher.

    Loads the next batch and moves its tensors to the target device on a separate CUDA stream while the
    current batch is in use.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. `opt['num_gpu']` selects the device: 'cpu' when it is 0, otherwise 'cuda'.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cpu' if opt['num_gpu'] == 0 else 'cuda')
        self.preload()

    def preload(self):
        """Fetch the next batch into `self.batch` (None when exhausted) and start moving it to the device."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None

        # put tensors to gpu, on the side stream
        with torch.cuda.stream(self.stream):
            for key, value in self.batch.items():
                if not torch.is_tensor(value):
                    continue
                self.batch[key] = value.to(device=self.device, non_blocking=True)

    def next(self):
        """Return the preloaded batch (None when exhausted) and start preloading the following one."""
        # make the current stream wait for the transfers issued on the side stream
        torch.cuda.current_stream().wait_stream(self.stream)
        ready = self.batch
        self.preload()
        return ready

    def reset(self):
        """Restart iteration from the beginning of the original loader and preload its first batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
basicsr/data/realesrgan_dataset.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ import torch
9
+ from torch.utils import data as data
10
+
11
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
12
+ from basicsr.data.transforms import augment
13
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
14
+ from basicsr.utils.registry import DATASET_REGISTRY
15
+
16
+
17
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None  # created lazily on the first __getitem__ call
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                # keep the part of each line before the first '.'
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def _read_gt_bytes(self, gt_path, retry=3):
        """Read the raw bytes of a gt image, falling back to random other files on IO errors.

        Args:
            gt_path: Path (or lmdb key) of the image to try first.
            retry (int): Maximum number of read attempts. Default: 3.

        Returns:
            tuple: ``(img_bytes, gt_path)`` where ``gt_path`` is the path that was actually read.

        Raises:
            IOError: If every attempt failed.
        """
        attempts = retry
        last_error = None
        while retry > 0:
            try:
                return self.file_client.get(gt_path, 'gt'), gt_path
            except (IOError, OSError) as e:
                last_error = e
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read; randint is inclusive on both ends,
                # so the upper bound must be the last valid index
                gt_path = self.paths[random.randint(0, self.__len__() - 1)]
                time.sleep(1)  # sleep 1s for occasional server congestion
            retry -= 1
        raise IOError(f'Failed to read a gt image after {attempts} attempts.') from last_error

    def _random_blur_kernel(self, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range):
        """Sample one blur kernel (sinc or mixed) and zero-pad it to 21x21.

        Args:
            sinc_prob (float): Probability of generating a sinc (circular low-pass) kernel.
            kernel_list, kernel_prob, blur_sigma, betag_range, betap_range: Forwarded to
                ``random_mixed_kernels`` when a sinc kernel is not chosen.

        Returns:
            ndarray: The padded kernel.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, index):
        """Load one gt image and sample the degradation kernels that go with it.

        Returns:
            dict: Keys ``gt`` (image tensor), ``kernel1``, ``kernel2``, ``sinc_kernel``
            (float tensors) and ``gt_path`` (the path that was actually read).
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        # avoid errors caused by high latency in reading files
        img_bytes, gt_path = self._read_gt_bytes(self.paths[index])
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._random_blur_kernel(self.opt['sinc_prob'], self.kernel_list, self.kernel_prob, self.blur_sigma,
                                          self.betag_range, self.betap_range)

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._random_blur_kernel(self.opt['sinc_prob2'], self.kernel_list2, self.kernel_prob2,
                                           self.blur_sigma2, self.betag_range2, self.betap_range2)

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
basicsr/data/realesrgan_paired_dataset.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils import data as data
3
+ from torchvision.transforms.functional import normalize
4
+
5
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
6
+ from basicsr.data.transforms import augment, paired_random_crop
7
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
8
+ from basicsr.utils.registry import DATASET_REGISTRY
9
+
10
+
11
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    The list of image pairs is built in one of three ways:

    1. **lmdb**: used when opt['io_backend']['type'] == 'lmdb'.
    2. **meta_info**: used when the backend is not lmdb and opt['meta_info'] is present and not None.
        Every line of that file holds a gt path and an lq path separated by ', ',
        each relative to its data root.
    3. **folder**: otherwise the lq and gt folders are scanned.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
            mean, std (optional): Normalization statistics applied to both lq and gt.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.io_backend_opt = opt['io_backend']
        self.file_client = None  # built on first access in __getitem__

        # optional statistics for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            # lmdb io backend
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend driven by a meta info file:
            # each line gives the relative gt and lq paths of one pair
            with open(self.opt['meta_info']) as fin:
                entries = [line.strip() for line in fin]
            self.paths = []
            for entry in entries:
                gt_rel, lq_rel = entry.split(', ')
                self.paths.append({
                    'gt_path': os.path.join(self.gt_folder, gt_rel),
                    'lq_path': os.path.join(self.lq_folder, lq_rel),
                })
        else:
            # disk backend without meta info: scan the whole folders.
            # This is slow for folders with many files; a meta txt file is recommended.
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        """Return one lq/gt pair as tensors together with the two source paths."""
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        gt_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(gt_bytes, float32=True)

        lq_path = self.paths[index]['lq_path']
        lq_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(lq_bytes, float32=True)

        # training-only augmentation: paired random crop, then flip / rotation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # in-place normalization, performed as soon as either statistic is configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
basicsr/data/reds_dataset.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from pathlib import Path
5
+ from torch.utils import data as data
6
+
7
+ from basicsr.data.transforms import augment, paired_random_crop
8
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9
+ from basicsr.utils.flow_util import dequantize_flow
10
+ from basicsr.utils.registry import DATASET_REGISTRY
11
+
12
+
13
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow. Flows are only
                loaded when this key is present and not None.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames. Must be odd.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        # 'dataroot_flow' is optional: a missing key is treated like an explicit None
        flow_dir = opt.get('dataroot_flow')
        self.flow_root = Path(flow_dir) if flow_dir is not None else None
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        # one key per frame: "<clip>/<8-digit frame index>"
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def _read_flow(self, clip_name, frame_name, suffix):
        """Read and dequantize one stored flow map.

        Args:
            clip_name (str): Clip (subfolder) name.
            frame_name (str): 8-digit name of the center frame.
            suffix (str): 'p<i>' for the i-th previous flow, 'n<i>' for the i-th next flow.

        Returns:
            The result of ``dequantize_flow`` for the two halves of the stored image.
        """
        if self.is_lmdb:
            flow_path = f'{clip_name}/{frame_name}_{suffix}'
        else:
            flow_path = (self.flow_root / clip_name / f'{frame_name}_{suffix}.png')
        img_bytes = self.file_client.get(flow_path, 'flow')
        cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
        # the two flow components are stored stacked along axis 0
        dx, dy = np.split(cat_flow, 2, axis=0)
        return dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.

    def __getitem__(self, index):
        """Return a window of lq frames, the center gt frame and, optionally, flows.

        Returns:
            dict: ``lq`` (t, c, h, w), ``gt`` (c, h, w) and ``key``; plus ``flow``
            (t, 2, h, w) when a flow root is configured.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)
        half_span = self.num_half_frames * interval

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - half_span
        end_frame_idx = center_frame_idx + half_span
        # each clip has 100 frames starting from 0 to 99;
        # re-draw the center frame until the whole window fits
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = center_frame_idx - half_span
            end_frame_idx = center_frame_idx + half_span
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            img_flows = []
            # read previous flows (farthest first)
            for i in range(self.num_half_frames, 0, -1):
                img_flows.append(self._read_flow(clip_name, frame_name, f'p{i}'))
            # read next flows (nearest first)
            for i in range(1, self.num_half_frames + 1):
                img_flows.append(self._read_flow(clip_name, frame_name, f'n{i}'))

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            # split the jointly cropped list back into frames and flows
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate (gt is appended so it gets the same transform)
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
206
+
207
+
208
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            test_mode (bool): Required. If true, only the validation-partition clips are kept;
                otherwise they are removed.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Number of consecutive frames in each sample.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list, optional): Interval list for temporal augmentation. Default: [1].
            random_reverse (bool, optional): Random reverse input frames. Default: False.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])
        self.lq_root = Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        # one key per frame: "<clip>/<8-digit frame index>"
        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # clips reserved for validation
        partition_name = opt['val_partition']
        if partition_name == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif partition_name == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        # test mode keeps only the validation clips, training drops them
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = self.io_backend_opt['type'] == 'lmdb'
        if self.is_lmdb:
            # NOTE(review): this class never assigns ``flow_root`` itself, so the flow
            # branch is only reachable if something else sets the attribute beforehand.
            if getattr(self, 'flow_root', None) is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        """Return ``num_frame`` aligned lq/gt frames of one clip.

        Returns:
            dict: ``lq`` (t, c, h, w), ``gt`` (t, c, h, w) and ``key`` (str).
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # temporal stride between the sampled frames
        interval = random.choice(self.interval_list)
        span = self.num_frame * interval

        # the sequence starts at the keyed frame unless it would run past the
        # clip end (100 frames assumed); in that case a random valid start is drawn
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - span:
            start_frame_idx = random.randint(0, 100 - span)
        neighbor_list = list(range(start_frame_idx, start_frame_idx + span, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # read every LQ frame together with its GT counterpart
        img_lqs, img_gts = [], []
        for neighbor in neighbor_list:
            frame_id = f'{neighbor:08d}'
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{frame_id}'
                img_gt_path = f'{clip_name}/{frame_id}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{frame_id}.png'
                img_gt_path = self.gt_root / clip_name / f'{frame_id}.png'

            lq_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(lq_bytes, float32=True))

            gt_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gts.append(imfrombytes(gt_bytes, float32=True))

        # randomly crop (the gt path of the last frame read is passed along)
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate; lq and gt frames are transformed as one list
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # first half of the joint list holds the lq frames, second half the gt frames
        half = len(img_lqs) // 2
        img_gts = torch.stack(img_results[half:], dim=0)
        img_lqs = torch.stack(img_results[:half], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
basicsr/data/single_image_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path as osp
2
+ from torch.utils import data as data
3
+ from torchvision.transforms.functional import normalize
4
+
5
+ from basicsr.data.data_util import paths_from_lmdb
6
+ from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
7
+ from basicsr.utils.registry import DATASET_REGISTRY
8
+
9
+
10
@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    The list of image paths is built in one of three ways:
    1. 'lmdb': used when opt['io_backend']['type'] == 'lmdb'.
    2. 'meta_info_file': used when opt['meta_info_file'] is present and not None.
    3. 'folder': otherwise the lq folder is scanned.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str, optional): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            mean, std (optional): Normalization statistics applied to the lq image.
            color (str, optional): If 'y', only the Y channel is kept.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            # An explicit None (e.g. an empty config entry) falls through to the
            # folder scan instead of failing in open().
            with open(self.opt['meta_info_file'], 'r') as fin:
                # the first space-separated field of each line is the relative image path
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        """Return the lq image at ``index`` as a tensor together with its path."""
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform
        # NOTE(review): the image still looks to be in BGR order here (it is converted
        # with bgr2rgb=True below), yet an RGB -> YCbCr helper is applied - confirm
        # whether a BGR-aware conversion should be used instead.
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize (in place), performed as soon as either statistic is configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
basicsr/data/transforms.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import random
3
+ import torch
4
+
5
+
6
def mod_crop(img, scale):
    """Crop an image so its height and width are multiples of ``scale``.

    Used during testing. The input is copied first, so the caller's array is
    never modified.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions (height and width
            are the first two axes).
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.

    Raises:
        ValueError: If ``img`` is neither 2- nor 3-dimensional.
    """
    cropped = img.copy()
    if cropped.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {cropped.ndim}.')
    height, width = cropped.shape[:2]
    # Drop the remainder rows/columns at the bottom/right edge.
    return cropped[:height - height % scale, :width - width % scale, ...]
24
+
25
+
26
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations: the LQ
    patch position is drawn at random and the GT patch is taken at that
    position multiplied by ``scale``.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images.
            Note that all images should have the same shape. If the input is
            not a list, it will be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray | list[Tensor] | Tensor): LQ images.
            Note that all images should have the same shape. If the input is
            not a list, it will be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth, only used in the error message.
            Default: None.

    Returns:
        tuple: ``(img_gts, img_lqs)``. Each is a list, or the single element
        itself when the list has exactly one element.

    Raises:
        ValueError: If GT is not exactly ``scale`` times the LQ size, or if
            the LQ image is smaller than the LQ patch size.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type from the first GT image: Tensor or Numpy array
    is_tensor = torch.is_tensor(img_gts[0])

    if is_tensor:
        # tensors: spatial dims are the last two axes
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        # ndarrays: spatial dims are the first two axes
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # NOTE: the two fragments are concatenated into ONE message; a comma
        # between them would make the exception carry a 2-tuple of strings.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if is_tensor:
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if is_tensor:
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]

    # unwrap single-element lists
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
92
+
93
+
94
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Randomly flip and/or rotate images (and optional flows).

    Rotation is implemented with a vertical flip and a transpose of the first
    two axes. All images (and flows) in one call share the same random
    decisions. Flips are written back into the input arrays.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is not a list, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray] | ndarray): Flows to be augmented. If the input
            is not a list, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation. Only
            honoured when ``flows`` is None. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images; together with the augmented
        flows when ``flows`` is given, or with ``(hflip, vflip, rot90)`` when
        ``return_status`` is set. Single-element lists are unwrapped.
    """
    # Draw the three decisions once, in this order, so every item is treated alike.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _apply_to_image(image):
        if hflip:  # horizontal, written back into the same array
            cv2.flip(image, 1, image)
        if vflip:  # vertical, written back into the same array
            cv2.flip(image, 0, image)
        if rot90:
            image = image.transpose(1, 0, 2)
        return image

    def _apply_to_flow(field):
        if hflip:  # horizontal flip also negates the first flow channel
            cv2.flip(field, 1, field)
            field[:, :, 0] *= -1
        if vflip:  # vertical flip also negates the second flow channel
            cv2.flip(field, 0, field)
            field[:, :, 1] *= -1
        if rot90:
            # transpose the grid and swap the two flow channels
            field = field.transpose(1, 0, 2)
            field = field[:, :, [1, 0]]
        return field

    img_list = imgs if isinstance(imgs, list) else [imgs]
    augmented = [_apply_to_image(item) for item in img_list]
    img_result = augmented[0] if len(augmented) == 1 else augmented

    if flows is None:
        if return_status:
            return img_result, (hflip, vflip, rot90)
        return img_result

    flow_list = flows if isinstance(flows, list) else [flows]
    warped = [_apply_to_flow(item) for item in flow_list]
    flow_result = warped[0] if len(warped) == 1 else warped
    return img_result, flow_result
159
+
160
+
161
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate an image, keeping the output the same size as the input.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            it is set to ``(w // 2, h // 2)``. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: The result of ``cv2.warpAffine`` with output size ``(w, h)``.
    """
    height, width = img.shape[:2]
    pivot = (width // 2, height // 2) if center is None else center
    affine = cv2.getRotationMatrix2D(pivot, angle, scale)
    return cv2.warpAffine(img, affine, (width, height))
basicsr/data/video_test_dataset.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import torch
3
+ from os import path as osp
4
+ from torch.utils import data as data
5
+
6
+ from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
7
+ from basicsr.utils import get_root_logger, scandir
8
+ from basicsr.utils.registry import DATASET_REGISTRY
9
+
10
+
11
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing dataset with following structures:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        io_backend (dict): IO backend type and other kwarg. 'lmdb' is rejected.
        cache_data (bool): Whether to cache testing datasets.
        name (str): Dataset name. Must be one of 'vid4', 'reds4',
            'redsofficial' (case-insensitive).
        meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
            in the dataroot will be used.
        num_frame (int): Window size for input frames.
        padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        # Flat, per-frame bookkeeping lists; position i describes sample i.
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        # Per-folder storage: cached frame data or the frame path lists.
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                # First space-separated token of each line is the subfolder.
                # Strip first so a single-token line does not keep its newline,
                # and skip blank lines instead of producing an empty entry.
                subfolders = [line.strip().split(' ')[0] for line in fin if line.strip()]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() not in ['vid4', 'reds4', 'redsofficial']:
            # Report the offending name itself (not its type).
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

        for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
            # get frame list for lq and gt
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted(scandir(subfolder_lq, full_path=True))
            img_paths_gt = sorted(scandir(subfolder_gt, full_path=True))

            max_idx = len(img_paths_lq)
            assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                  f' and gt folders ({len(img_paths_gt)})')

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['gt_path'].extend(img_paths_gt)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            self.data_info['idx'].extend(f'{i}/{max_idx}' for i in range(max_idx))

            # Mark the first and last num_frame // 2 frames as border frames.
            # Clamp to the clip length so very short clips do not index out of range.
            border_l = [0] * max_idx
            for i in range(min(self.opt['num_frame'] // 2, max_idx)):
                border_l[i] = 1
                border_l[max_idx - i - 1] = 1
            self.data_info['border'].extend(border_l)

            # cache data or save the frame list
            if self.cache_data:
                logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq
                self.imgs_gt[subfolder_name] = img_paths_gt

    def __getitem__(self, index):
        """Return the LQ frame window around frame ``index`` and its GT frame."""
        folder = self.data_info['folder'][index]
        # 'idx' entries are stored as '<frame index>/<frames in clip>'.
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        """Return the total number of frames over all subfolders."""
        return len(self.data_info['gt_path'])
131
+
132
+
133
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing: the GT is ``im4.png`` of each
    listed subfolder and the LQ input is the window of ``num_frame`` frames
    around it.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        io_backend (dict): IO backend type and other kwarg. 'lmdb' is rejected.
        cache_data (bool): Must be falsy; caching is not implemented here.
        name (str): Dataset name.
        meta_info_file (str): The path to the file storing the list of test folders. Required by this class.
        num_frame (int): Window size for input frames.
        padding (str): Padding mode. Not read by this class.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # Frame numbers of the input window, e.g. num_frame=7 -> 1..7, num_frame=5 -> 2..6.
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestVimeo90KDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            # First space-separated token of each line is the subfolder.
            # Strip first so a single-token line does not keep its newline,
            # and skip blank lines instead of producing an empty entry.
            subfolders = [line.strip().split(' ')[0] for line in fin if line.strip()]
        total = len(subfolders)
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{total}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        """Return the LQ frame window and the GT center frame of sample ``index``."""
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        """Return the number of listed subfolders (one sample each)."""
        return len(self.data_info['gt_path'])
199
+
200
+
201
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for the test dataset. Most of keys are the same as VideoTestDataset.
            It has the following extra keys:
        use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames
            (from the GT frames) instead of reading the LQ frames.
        scale: Scale, which will be added automatically. Passed on to ``duf_downsample`` and ``read_img_seq``.
    """

    def __getitem__(self, index):
        info = self.data_info
        folder = info['folder'][index]
        # 'idx' entries are stored as '<frame index>/<frames in clip>'.
        idx_str, max_idx_str = info['idx'][index].split('/')
        idx, max_idx = int(idx_str), int(max_idx_str)
        border = info['border'][index]
        lq_path = info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
        use_duf = self.opt['use_duf_downsampling']

        if self.cache_data:
            # With DUF downsampling the LQ window is derived from the cached GT frames.
            frame_store = self.imgs_gt if use_duf else self.imgs_lq
            imgs_lq = frame_store[folder].index_select(0, torch.LongTensor(select_idx))
            if use_duf:
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            img_gt = self.imgs_gt[folder][idx]
        else:
            if use_duf:
                # read imgs_gt to generate low-resolution frames
                window_paths = [self.imgs_gt[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(window_paths, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                window_paths = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(window_paths)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)

        sample = {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }
        return sample
249
+
250
+
251
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and output corresponding HR video frames.

    One sample is a whole folder (clip) rather than a single frame window.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
        padding (str): Padding mode.

    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Unique folder names, in sorted order; one dataset item per folder.
        unique_folders = set(self.data_info['folder'])
        self.folders = sorted(unique_folders)

    def __getitem__(self, index):
        folder = self.folders[index]

        # Only the cached mode is supported.
        if not self.cache_data:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': self.imgs_lq[folder],
            'gt': self.imgs_gt[folder],
            'folder': folder,
        }

    def __len__(self):
        num_folders = len(self.folders)
        return num_folders
basicsr/data/vimeo90k_dataset.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from pathlib import Path
4
+ from torch.utils import data as data
5
+
6
+ from basicsr.data.transforms import augment, paired_random_crop
7
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
8
+ from basicsr.utils.registry import DATASET_REGISTRY
9
+
10
+
11
@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space.

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
                1 | 4
                3 | 3,4,5
                5 | 2,3,4,5,6
                7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
        dataroot_gt (str): Data root path for gt.
        dataroot_lq (str): Data root path for lq.
        meta_info_file (str): Path for meta information file.
        io_backend (dict): IO backend type and other kwarg.
        num_frame (int): Window size for input frames.
        gt_size (int): Cropped patched size for gt patches.
        random_reverse (bool): Random reverse input frames.
        use_hflip (bool): Use horizontal flips.
        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            # First space-separated token of each line is the key.
            # Strip first so a single-token line does not keep its newline,
            # and skip blank lines instead of producing an empty key.
            self.keys = [line.strip().split(' ')[0] for line in fin if line.strip()]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        """Return a cropped/augmented LQ frame window and its GT center frame."""
        if self.file_client is None:
            # Build the client kwargs without popping 'type': the backend
            # options dict is shared with the caller's config, and removing the
            # key would break any later re-creation of the client.
            backend_kwargs = {k: v for k, v in self.io_backend_opt.items() if k != 'type'}
            self.file_client = FileClient(self.io_backend_opt['type'], **backend_kwargs)

        # random reverse (in place: the reversed order persists on the
        # instance until the next reversal)
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(img_bytes, float32=True))

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        # GT is appended so it receives exactly the same augmentation as the LQ frames.
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        """Return the number of keys read from the meta info file."""
        return len(self.keys)
134
+
135
+
136
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
    """Vimeo90K training dataset returning whole LQ and GT sequences.

    Unlike the parent class, every sample holds all frames ``im1`` .. ``im7``
    for both LQ and GT instead of a window plus a single GT center frame.

    Args:
        opt (dict): Same keys as Vimeo90KDataset, plus:
        flip_sequence (bool): If set, each sequence is concatenated with its
            time-reversed copy (7 frames become 14).
    """

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        # Always use all seven frames, whatever num_frame the parent derived.
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        """Return cropped/augmented LQ and GT sequences for sample ``index``."""
        if self.file_client is None:
            # Build the client kwargs without popping 'type': the backend
            # options dict is shared with the caller's config, and removing the
            # key would break any later re-creation of the client.
            backend_kwargs = {k: v for k, v in self.io_backend_opt.items() if k != 'type'}
            self.file_client = FileClient(self.io_backend_opt['type'], **backend_kwargs)

        # random reverse (in place: the reversed order persists on the
        # instance until the next reversal)
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop (img_gt_path here is the last frame's GT path)
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        # LQ and GT frames go through augment together so they share the same ops.
        num_frames = len(self.neighbor_list)
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        # The first num_frames entries are LQ, the rest are GT.
        img_lqs = torch.stack(img_results[:num_frames], dim=0)
        img_gts = torch.stack(img_results[num_frames:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        """Return the number of keys read from the meta info file."""
        return len(self.keys)