XiangZ commited on
Commit
f3c1de0
·
verified ·
1 Parent(s): b579ad2

Delete utils

Browse files
utils/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- from .file_client import FileClient
2
- from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
3
- from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
4
- from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
5
-
6
- __all__ = [
7
- # file_client.py
8
- 'FileClient',
9
- # img_util.py
10
- 'img2tensor',
11
- 'tensor2img',
12
- 'imfrombytes',
13
- 'imwrite',
14
- 'crop_border',
15
- # logger.py
16
- 'MessageLogger',
17
- 'AvgTimer',
18
- 'init_tb_logger',
19
- 'init_wandb_logger',
20
- 'get_root_logger',
21
- 'get_env_info',
22
- # misc.py
23
- 'set_random_seed',
24
- 'get_time_str',
25
- 'mkdir_and_rename',
26
- 'make_exp_dirs',
27
- 'scandir',
28
- 'check_resume',
29
- 'sizeof_fmt',
30
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (854 Bytes)
 
utils/__pycache__/dist_util.cpython-38.pyc DELETED
Binary file (2.6 kB)
 
utils/__pycache__/file_client.cpython-38.pyc DELETED
Binary file (6.5 kB)
 
utils/__pycache__/img_util.cpython-38.pyc DELETED
Binary file (6.12 kB)
 
utils/__pycache__/logger.cpython-38.pyc DELETED
Binary file (6.94 kB)
 
utils/__pycache__/matlab_functions.cpython-38.pyc DELETED
Binary file (10.6 kB)
 
utils/__pycache__/misc.cpython-38.pyc DELETED
Binary file (4.37 kB)
 
utils/__pycache__/options.cpython-38.pyc DELETED
Binary file (5.11 kB)
 
utils/__pycache__/registry.cpython-38.pyc DELETED
Binary file (2.61 kB)
 
utils/dist_util.py DELETED
@@ -1,82 +0,0 @@
1
- # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
- import functools
3
- import os
4
- import subprocess
5
- import torch
6
- import torch.distributed as dist
7
- import torch.multiprocessing as mp
8
-
9
-
10
- def init_dist(launcher, backend='nccl', **kwargs):
11
- if mp.get_start_method(allow_none=True) is None:
12
- mp.set_start_method('spawn')
13
- if launcher == 'pytorch':
14
- _init_dist_pytorch(backend, **kwargs)
15
- elif launcher == 'slurm':
16
- _init_dist_slurm(backend, **kwargs)
17
- else:
18
- raise ValueError(f'Invalid launcher type: {launcher}')
19
-
20
-
21
- def _init_dist_pytorch(backend, **kwargs):
22
- rank = int(os.environ['RANK'])
23
- num_gpus = torch.cuda.device_count()
24
- torch.cuda.set_device(rank % num_gpus)
25
- dist.init_process_group(backend=backend, **kwargs)
26
-
27
-
28
- def _init_dist_slurm(backend, port=None):
29
- """Initialize slurm distributed training environment.
30
-
31
- If argument ``port`` is not specified, then the master port will be system
32
- environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33
- environment variable, then a default port ``29500`` will be used.
34
-
35
- Args:
36
- backend (str): Backend of torch.distributed.
37
- port (int, optional): Master port. Defaults to None.
38
- """
39
- proc_id = int(os.environ['SLURM_PROCID'])
40
- ntasks = int(os.environ['SLURM_NTASKS'])
41
- node_list = os.environ['SLURM_NODELIST']
42
- num_gpus = torch.cuda.device_count()
43
- torch.cuda.set_device(proc_id % num_gpus)
44
- addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45
- # specify master port
46
- if port is not None:
47
- os.environ['MASTER_PORT'] = str(port)
48
- elif 'MASTER_PORT' in os.environ:
49
- pass # use MASTER_PORT in the environment variable
50
- else:
51
- # 29500 is torch.distributed default port
52
- os.environ['MASTER_PORT'] = '29500'
53
- os.environ['MASTER_ADDR'] = addr
54
- os.environ['WORLD_SIZE'] = str(ntasks)
55
- os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56
- os.environ['RANK'] = str(proc_id)
57
- dist.init_process_group(backend=backend)
58
-
59
-
60
- def get_dist_info():
61
- if dist.is_available():
62
- initialized = dist.is_initialized()
63
- else:
64
- initialized = False
65
- if initialized:
66
- rank = dist.get_rank()
67
- world_size = dist.get_world_size()
68
- else:
69
- rank = 0
70
- world_size = 1
71
- return rank, world_size
72
-
73
-
74
- def master_only(func):
75
-
76
- @functools.wraps(func)
77
- def wrapper(*args, **kwargs):
78
- rank, _ = get_dist_info()
79
- if rank == 0:
80
- return func(*args, **kwargs)
81
-
82
- return wrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/file_client.py DELETED
@@ -1,167 +0,0 @@
1
- # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2
- from abc import ABCMeta, abstractmethod
3
-
4
-
5
- class BaseStorageBackend(metaclass=ABCMeta):
6
- """Abstract class of storage backends.
7
-
8
- All backends need to implement two apis: ``get()`` and ``get_text()``.
9
- ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10
- as texts.
11
- """
12
-
13
- @abstractmethod
14
- def get(self, filepath):
15
- pass
16
-
17
- @abstractmethod
18
- def get_text(self, filepath):
19
- pass
20
-
21
-
22
- class MemcachedBackend(BaseStorageBackend):
23
- """Memcached storage backend.
24
-
25
- Attributes:
26
- server_list_cfg (str): Config file for memcached server list.
27
- client_cfg (str): Config file for memcached client.
28
- sys_path (str | None): Additional path to be appended to `sys.path`.
29
- Default: None.
30
- """
31
-
32
- def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33
- if sys_path is not None:
34
- import sys
35
- sys.path.append(sys_path)
36
- try:
37
- import mc
38
- except ImportError:
39
- raise ImportError('Please install memcached to enable MemcachedBackend.')
40
-
41
- self.server_list_cfg = server_list_cfg
42
- self.client_cfg = client_cfg
43
- self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44
- # mc.pyvector servers as a point which points to a memory cache
45
- self._mc_buffer = mc.pyvector()
46
-
47
- def get(self, filepath):
48
- filepath = str(filepath)
49
- import mc
50
- self._client.Get(filepath, self._mc_buffer)
51
- value_buf = mc.ConvertBuffer(self._mc_buffer)
52
- return value_buf
53
-
54
- def get_text(self, filepath):
55
- raise NotImplementedError
56
-
57
-
58
- class HardDiskBackend(BaseStorageBackend):
59
- """Raw hard disks storage backend."""
60
-
61
- def get(self, filepath):
62
- filepath = str(filepath)
63
- with open(filepath, 'rb') as f:
64
- value_buf = f.read()
65
- return value_buf
66
-
67
- def get_text(self, filepath):
68
- filepath = str(filepath)
69
- with open(filepath, 'r') as f:
70
- value_buf = f.read()
71
- return value_buf
72
-
73
-
74
- class LmdbBackend(BaseStorageBackend):
75
- """Lmdb storage backend.
76
-
77
- Args:
78
- db_paths (str | list[str]): Lmdb database paths.
79
- client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80
- readonly (bool, optional): Lmdb environment parameter. If True,
81
- disallow any write operations. Default: True.
82
- lock (bool, optional): Lmdb environment parameter. If False, when
83
- concurrent access occurs, do not lock the database. Default: False.
84
- readahead (bool, optional): Lmdb environment parameter. If False,
85
- disable the OS filesystem readahead mechanism, which may improve
86
- random read performance when a database is larger than RAM.
87
- Default: False.
88
-
89
- Attributes:
90
- db_paths (list): Lmdb database path.
91
- _client (list): A list of several lmdb envs.
92
- """
93
-
94
- def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95
- try:
96
- import lmdb
97
- except ImportError:
98
- raise ImportError('Please install lmdb to enable LmdbBackend.')
99
-
100
- if isinstance(client_keys, str):
101
- client_keys = [client_keys]
102
-
103
- if isinstance(db_paths, list):
104
- self.db_paths = [str(v) for v in db_paths]
105
- elif isinstance(db_paths, str):
106
- self.db_paths = [str(db_paths)]
107
- assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108
- f'but received {len(client_keys)} and {len(self.db_paths)}.')
109
-
110
- self._client = {}
111
- for client, path in zip(client_keys, self.db_paths):
112
- self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113
-
114
- def get(self, filepath, client_key):
115
- """Get values according to the filepath from one lmdb named client_key.
116
-
117
- Args:
118
- filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119
- client_key (str): Used for distinguishing different lmdb envs.
120
- """
121
- filepath = str(filepath)
122
- assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
123
- client = self._client[client_key]
124
- with client.begin(write=False) as txn:
125
- value_buf = txn.get(filepath.encode('ascii'))
126
- return value_buf
127
-
128
- def get_text(self, filepath):
129
- raise NotImplementedError
130
-
131
-
132
- class FileClient(object):
133
- """A general file client to access files in different backend.
134
-
135
- The client loads a file or text in a specified backend from its path
136
- and return it as a binary file. it can also register other backend
137
- accessor with a given name and backend class.
138
-
139
- Attributes:
140
- backend (str): The storage backend type. Options are "disk",
141
- "memcached" and "lmdb".
142
- client (:obj:`BaseStorageBackend`): The backend object.
143
- """
144
-
145
- _backends = {
146
- 'disk': HardDiskBackend,
147
- 'memcached': MemcachedBackend,
148
- 'lmdb': LmdbBackend,
149
- }
150
-
151
- def __init__(self, backend='disk', **kwargs):
152
- if backend not in self._backends:
153
- raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154
- f' are {list(self._backends.keys())}')
155
- self.backend = backend
156
- self.client = self._backends[backend](**kwargs)
157
-
158
- def get(self, filepath, client_key='default'):
159
- # client_key is used only for lmdb, where different fileclients have
160
- # different lmdb environments.
161
- if self.backend == 'lmdb':
162
- return self.client.get(filepath, client_key)
163
- else:
164
- return self.client.get(filepath)
165
-
166
- def get_text(self, filepath):
167
- return self.client.get_text(filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/img_util.py DELETED
@@ -1,172 +0,0 @@
1
- import cv2
2
- import math
3
- import numpy as np
4
- import os
5
- import torch
6
- from torchvision.utils import make_grid
7
-
8
-
9
- def img2tensor(imgs, bgr2rgb=True, float32=True):
10
- """Numpy array to tensor.
11
-
12
- Args:
13
- imgs (list[ndarray] | ndarray): Input images.
14
- bgr2rgb (bool): Whether to change bgr to rgb.
15
- float32 (bool): Whether to change to float32.
16
-
17
- Returns:
18
- list[tensor] | tensor: Tensor images. If returned results only have
19
- one element, just return tensor.
20
- """
21
-
22
- def _totensor(img, bgr2rgb, float32):
23
- if img.shape[2] == 3 and bgr2rgb:
24
- if img.dtype == 'float64':
25
- img = img.astype('float32')
26
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27
- img = torch.from_numpy(img.transpose(2, 0, 1))
28
- if float32:
29
- img = img.float()
30
- return img
31
-
32
- if isinstance(imgs, list):
33
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
34
- else:
35
- return _totensor(imgs, bgr2rgb, float32)
36
-
37
-
38
- def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39
- """Convert torch Tensors into image numpy arrays.
40
-
41
- After clamping to [min, max], values will be normalized to [0, 1].
42
-
43
- Args:
44
- tensor (Tensor or list[Tensor]): Accept shapes:
45
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46
- 2) 3D Tensor of shape (3/1 x H x W);
47
- 3) 2D Tensor of shape (H x W).
48
- Tensor channel should be in RGB order.
49
- rgb2bgr (bool): Whether to change rgb to bgr.
50
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
51
- to uint8 type with range [0, 255]; otherwise, float type with
52
- range [0, 1]. Default: ``np.uint8``.
53
- min_max (tuple[int]): min and max values for clamp.
54
-
55
- Returns:
56
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
- shape (H x W). The channel order is BGR.
58
- """
59
- if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
- raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
-
62
- if torch.is_tensor(tensor):
63
- tensor = [tensor]
64
- result = []
65
- for _tensor in tensor:
66
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
-
69
- n_dim = _tensor.dim()
70
- if n_dim == 4:
71
- img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
- img_np = img_np.transpose(1, 2, 0)
73
- if rgb2bgr:
74
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
- elif n_dim == 3:
76
- img_np = _tensor.numpy()
77
- img_np = img_np.transpose(1, 2, 0)
78
- if img_np.shape[2] == 1: # gray image
79
- img_np = np.squeeze(img_np, axis=2)
80
- else:
81
- if rgb2bgr:
82
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
- elif n_dim == 2:
84
- img_np = _tensor.numpy()
85
- else:
86
- raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
- if out_type == np.uint8:
88
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
89
- img_np = (img_np * 255.0).round()
90
- img_np = img_np.astype(out_type)
91
- result.append(img_np)
92
- if len(result) == 1:
93
- result = result[0]
94
- return result
95
-
96
-
97
- def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98
- """This implementation is slightly faster than tensor2img.
99
- It now only supports torch tensor with shape (1, c, h, w).
100
-
101
- Args:
102
- tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103
- rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104
- min_max (tuple[int]): min and max values for clamp.
105
- """
106
- output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107
- output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108
- output = output.type(torch.uint8).cpu().numpy()
109
- if rgb2bgr:
110
- output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111
- return output
112
-
113
-
114
- def imfrombytes(content, flag='color', float32=False):
115
- """Read an image from bytes.
116
-
117
- Args:
118
- content (bytes): Image bytes got from files or other streams.
119
- flag (str): Flags specifying the color type of a loaded image,
120
- candidates are `color`, `grayscale` and `unchanged`.
121
- float32 (bool): Whether to change to float32., If True, will also norm
122
- to [0, 1]. Default: False.
123
-
124
- Returns:
125
- ndarray: Loaded image array.
126
- """
127
- img_np = np.frombuffer(content, np.uint8)
128
- imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129
- img = cv2.imdecode(img_np, imread_flags[flag])
130
- if float32:
131
- img = img.astype(np.float32) / 255.
132
- return img
133
-
134
-
135
- def imwrite(img, file_path, params=None, auto_mkdir=True):
136
- """Write image to file.
137
-
138
- Args:
139
- img (ndarray): Image array to be written.
140
- file_path (str): Image file path.
141
- params (None or list): Same as opencv's :func:`imwrite` interface.
142
- auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143
- whether to create it automatically.
144
-
145
- Returns:
146
- bool: Successful or not.
147
- """
148
- if auto_mkdir:
149
- dir_name = os.path.abspath(os.path.dirname(file_path))
150
- os.makedirs(dir_name, exist_ok=True)
151
- ok = cv2.imwrite(file_path, img, params)
152
- if not ok:
153
- raise IOError('Failed in writing images.')
154
-
155
-
156
- def crop_border(imgs, crop_border):
157
- """Crop borders of images.
158
-
159
- Args:
160
- imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161
- crop_border (int): Crop border for each end of height and weight.
162
-
163
- Returns:
164
- list[ndarray]: Cropped images.
165
- """
166
- if crop_border == 0:
167
- return imgs
168
- else:
169
- if isinstance(imgs, list):
170
- return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171
- else:
172
- return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/logger.py DELETED
@@ -1,213 +0,0 @@
1
- import datetime
2
- import logging
3
- import time
4
-
5
- from .dist_util import get_dist_info, master_only
6
-
7
- initialized_logger = {}
8
-
9
-
10
- class AvgTimer():
11
-
12
- def __init__(self, window=200):
13
- self.window = window # average window
14
- self.current_time = 0
15
- self.total_time = 0
16
- self.count = 0
17
- self.avg_time = 0
18
- self.start()
19
-
20
- def start(self):
21
- self.start_time = self.tic = time.time()
22
-
23
- def record(self):
24
- self.count += 1
25
- self.toc = time.time()
26
- self.current_time = self.toc - self.tic
27
- self.total_time += self.current_time
28
- # calculate average time
29
- self.avg_time = self.total_time / self.count
30
-
31
- # reset
32
- if self.count > self.window:
33
- self.count = 0
34
- self.total_time = 0
35
-
36
- self.tic = time.time()
37
-
38
- def get_current_time(self):
39
- return self.current_time
40
-
41
- def get_avg_time(self):
42
- return self.avg_time
43
-
44
-
45
- class MessageLogger():
46
- """Message logger for printing.
47
-
48
- Args:
49
- opt (dict): Config. It contains the following keys:
50
- name (str): Exp name.
51
- logger (dict): Contains 'print_freq' (str) for logger interval.
52
- train (dict): Contains 'total_iter' (int) for total iters.
53
- use_tb_logger (bool): Use tensorboard logger.
54
- start_iter (int): Start iter. Default: 1.
55
- tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
56
- """
57
-
58
- def __init__(self, opt, start_iter=1, tb_logger=None):
59
- self.exp_name = opt['name']
60
- self.interval = opt['logger']['print_freq']
61
- self.start_iter = start_iter
62
- self.max_iters = opt['train']['total_iter']
63
- self.use_tb_logger = opt['logger']['use_tb_logger']
64
- self.tb_logger = tb_logger
65
- self.start_time = time.time()
66
- self.logger = get_root_logger()
67
-
68
- def reset_start_time(self):
69
- self.start_time = time.time()
70
-
71
- @master_only
72
- def __call__(self, log_vars):
73
- """Format logging message.
74
-
75
- Args:
76
- log_vars (dict): It contains the following keys:
77
- epoch (int): Epoch number.
78
- iter (int): Current iter.
79
- lrs (list): List for learning rates.
80
-
81
- time (float): Iter time.
82
- data_time (float): Data time for each iter.
83
- """
84
- # epoch, iter, learning rates
85
- epoch = log_vars.pop('epoch')
86
- current_iter = log_vars.pop('iter')
87
- lrs = log_vars.pop('lrs')
88
-
89
- message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
90
- for v in lrs:
91
- message += f'{v:.3e},'
92
- message += ')] '
93
-
94
- # time and estimated time
95
- if 'time' in log_vars.keys():
96
- iter_time = log_vars.pop('time')
97
- data_time = log_vars.pop('data_time')
98
-
99
- total_time = time.time() - self.start_time
100
- time_sec_avg = total_time / (current_iter - self.start_iter + 1)
101
- eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
102
- eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
103
- message += f'[eta: {eta_str}, '
104
- message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
105
-
106
- # other items, especially losses
107
- for k, v in log_vars.items():
108
- message += f'{k}: {v:.4e} '
109
- # tensorboard logger
110
- if self.use_tb_logger and 'debug' not in self.exp_name:
111
- if k.startswith('l_'):
112
- self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
113
- else:
114
- self.tb_logger.add_scalar(k, v, current_iter)
115
- self.logger.info(message)
116
-
117
-
118
- @master_only
119
- def init_tb_logger(log_dir):
120
- from torch.utils.tensorboard import SummaryWriter
121
- tb_logger = SummaryWriter(log_dir=log_dir)
122
- return tb_logger
123
-
124
-
125
- @master_only
126
- def init_wandb_logger(opt):
127
- """We now only use wandb to sync tensorboard log."""
128
- import wandb
129
- logger = get_root_logger()
130
-
131
- project = opt['logger']['wandb']['project']
132
- resume_id = opt['logger']['wandb'].get('resume_id')
133
- if resume_id:
134
- wandb_id = resume_id
135
- resume = 'allow'
136
- logger.warning(f'Resume wandb logger with id={wandb_id}.')
137
- else:
138
- wandb_id = wandb.util.generate_id()
139
- resume = 'never'
140
-
141
- wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
142
-
143
- logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
144
-
145
-
146
- def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
147
- """Get the root logger.
148
-
149
- The logger will be initialized if it has not been initialized. By default a
150
- StreamHandler will be added. If `log_file` is specified, a FileHandler will
151
- also be added.
152
-
153
- Args:
154
- logger_name (str): root logger name. Default: 'basicsr'.
155
- log_file (str | None): The log filename. If specified, a FileHandler
156
- will be added to the root logger.
157
- log_level (int): The root logger level. Note that only the process of
158
- rank 0 is affected, while other processes will set the level to
159
- "Error" and be silent most of the time.
160
-
161
- Returns:
162
- logging.Logger: The root logger.
163
- """
164
- logger = logging.getLogger(logger_name)
165
- # if the logger has been initialized, just return it
166
- if logger_name in initialized_logger:
167
- return logger
168
-
169
- format_str = '%(asctime)s %(levelname)s: %(message)s'
170
- stream_handler = logging.StreamHandler()
171
- stream_handler.setFormatter(logging.Formatter(format_str))
172
- logger.addHandler(stream_handler)
173
- logger.propagate = False
174
- rank, _ = get_dist_info()
175
- if rank != 0:
176
- logger.setLevel('ERROR')
177
- elif log_file is not None:
178
- logger.setLevel(log_level)
179
- # add file handler
180
- file_handler = logging.FileHandler(log_file, 'w')
181
- file_handler.setFormatter(logging.Formatter(format_str))
182
- file_handler.setLevel(log_level)
183
- logger.addHandler(file_handler)
184
- initialized_logger[logger_name] = True
185
- return logger
186
-
187
-
188
- def get_env_info():
189
- """Get environment information.
190
-
191
- Currently, only log the software version.
192
- """
193
- import torch
194
- import torchvision
195
-
196
- from basicsr.version import __version__
197
- msg = r"""
198
- ____ _ _____ ____
199
- / __ ) ____ _ _____ (_)_____/ ___/ / __ \
200
- / __ |/ __ `// ___// // ___/\__ \ / /_/ /
201
- / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
202
- /_____/ \__,_//____//_/ \___//____//_/ |_|
203
- ______ __ __ __ __
204
- / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
205
- / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
206
- / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
207
- \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
208
- """
209
- msg += ('\nVersion Information: '
210
- f'\n\tBasicSR: {__version__}'
211
- f'\n\tPyTorch: {torch.__version__}'
212
- f'\n\tTorchVision: {torchvision.__version__}')
213
- return msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/matlab_functions.py DELETED
@@ -1,359 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
-
5
-
6
- def cubic(x):
7
- """cubic function used for calculate_weights_indices."""
8
- absx = torch.abs(x)
9
- absx2 = absx**2
10
- absx3 = absx**3
11
- return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12
- (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
13
- (absx <= 2)).type_as(absx))
14
-
15
-
16
- def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
17
- """Calculate weights and indices, used for imresize function.
18
-
19
- Args:
20
- in_length (int): Input length.
21
- out_length (int): Output length.
22
- scale (float): Scale factor.
23
- kernel_width (int): Kernel width.
24
- antialisaing (bool): Whether to apply anti-aliasing when downsampling.
25
- """
26
-
27
- if (scale < 1) and antialiasing:
28
- # Use a modified kernel (larger kernel width) to simultaneously
29
- # interpolate and antialias
30
- kernel_width = kernel_width / scale
31
-
32
- # Output-space coordinates
33
- x = torch.linspace(1, out_length, out_length)
34
-
35
- # Input-space coordinates. Calculate the inverse mapping such that 0.5
36
- # in output space maps to 0.5 in input space, and 0.5 + scale in output
37
- # space maps to 1.5 in input space.
38
- u = x / scale + 0.5 * (1 - 1 / scale)
39
-
40
- # What is the left-most pixel that can be involved in the computation?
41
- left = torch.floor(u - kernel_width / 2)
42
-
43
- # What is the maximum number of pixels that can be involved in the
44
- # computation? Note: it's OK to use an extra pixel here; if the
45
- # corresponding weights are all zero, it will be eliminated at the end
46
- # of this function.
47
- p = math.ceil(kernel_width) + 2
48
-
49
- # The indices of the input pixels involved in computing the k-th output
50
- # pixel are in row k of the indices matrix.
51
- indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
52
- out_length, p)
53
-
54
- # The weights used to compute the k-th output pixel are in row k of the
55
- # weights matrix.
56
- distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
57
-
58
- # apply cubic kernel
59
- if (scale < 1) and antialiasing:
60
- weights = scale * cubic(distance_to_center * scale)
61
- else:
62
- weights = cubic(distance_to_center)
63
-
64
- # Normalize the weights matrix so that each row sums to 1.
65
- weights_sum = torch.sum(weights, 1).view(out_length, 1)
66
- weights = weights / weights_sum.expand(out_length, p)
67
-
68
- # If a column in weights is all zero, get rid of it. only consider the
69
- # first and last column.
70
- weights_zero_tmp = torch.sum((weights == 0), 0)
71
- if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
72
- indices = indices.narrow(1, 1, p - 2)
73
- weights = weights.narrow(1, 1, p - 2)
74
- if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
75
- indices = indices.narrow(1, 0, p - 2)
76
- weights = weights.narrow(1, 0, p - 2)
77
- weights = weights.contiguous()
78
- indices = indices.contiguous()
79
- sym_len_s = -indices.min() + 1
80
- sym_len_e = indices.max() - in_length
81
- indices = indices + sym_len_s - 1
82
- return weights, indices, int(sym_len_s), int(sym_len_e)
83
-
84
-
85
- @torch.no_grad()
86
- def imresize(img, scale, antialiasing=True):
87
- """imresize function same as MATLAB.
88
-
89
- It now only supports bicubic.
90
- The same scale applies for both height and width.
91
-
92
- Args:
93
- img (Tensor | Numpy array):
94
- Tensor: Input image with shape (c, h, w), [0, 1] range.
95
- Numpy: Input image with shape (h, w, c), [0, 1] range.
96
- scale (float): Scale factor. The same scale applies for both height
97
- and width.
98
- antialisaing (bool): Whether to apply anti-aliasing when downsampling.
99
- Default: True.
100
-
101
- Returns:
102
- Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
103
- """
104
- squeeze_flag = False
105
- if type(img).__module__ == np.__name__: # numpy type
106
- numpy_type = True
107
- if img.ndim == 2:
108
- img = img[:, :, None]
109
- squeeze_flag = True
110
- img = torch.from_numpy(img.transpose(2, 0, 1)).float()
111
- else:
112
- numpy_type = False
113
- if img.ndim == 2:
114
- img = img.unsqueeze(0)
115
- squeeze_flag = True
116
-
117
- in_c, in_h, in_w = img.size()
118
- out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
119
- kernel_width = 4
120
- kernel = 'cubic'
121
-
122
- # get weights and indices
123
- weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
124
- antialiasing)
125
- weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
126
- antialiasing)
127
- # process H dimension
128
- # symmetric copying
129
- img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
130
- img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
131
-
132
- sym_patch = img[:, :sym_len_hs, :]
133
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
135
- img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
136
-
137
- sym_patch = img[:, -sym_len_he:, :]
138
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
139
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
140
- img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
141
-
142
- out_1 = torch.FloatTensor(in_c, out_h, in_w)
143
- kernel_width = weights_h.size(1)
144
- for i in range(out_h):
145
- idx = int(indices_h[i][0])
146
- for j in range(in_c):
147
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
148
-
149
- # process W dimension
150
- # symmetric copying
151
- out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
152
- out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
153
-
154
- sym_patch = out_1[:, :, :sym_len_ws]
155
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
156
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
157
- out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
158
-
159
- sym_patch = out_1[:, :, -sym_len_we:]
160
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
161
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
162
- out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
163
-
164
- out_2 = torch.FloatTensor(in_c, out_h, out_w)
165
- kernel_width = weights_w.size(1)
166
- for i in range(out_w):
167
- idx = int(indices_w[i][0])
168
- for j in range(in_c):
169
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
170
-
171
- if squeeze_flag:
172
- out_2 = out_2.squeeze(0)
173
- if numpy_type:
174
- out_2 = out_2.numpy()
175
- if not squeeze_flag:
176
- out_2 = out_2.transpose(1, 2, 0)
177
-
178
- return out_2
179
-
180
-
181
- def rgb2ycbcr(img, y_only=False):
182
- """Convert a RGB image to YCbCr image.
183
-
184
- This function produces the same results as Matlab's `rgb2ycbcr` function.
185
- It implements the ITU-R BT.601 conversion for standard-definition
186
- television. See more details in
187
- https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
188
-
189
- It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
190
- In OpenCV, it implements a JPEG conversion. See more details in
191
- https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
192
-
193
- Args:
194
- img (ndarray): The input image. It accepts:
195
- 1. np.uint8 type with range [0, 255];
196
- 2. np.float32 type with range [0, 1].
197
- y_only (bool): Whether to only return Y channel. Default: False.
198
-
199
- Returns:
200
- ndarray: The converted YCbCr image. The output image has the same type
201
- and range as input image.
202
- """
203
- img_type = img.dtype
204
- img = _convert_input_type_range(img)
205
- if y_only:
206
- out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
207
- else:
208
- out_img = np.matmul(
209
- img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
210
- out_img = _convert_output_type_range(out_img, img_type)
211
- return out_img
212
-
213
-
214
- def bgr2ycbcr(img, y_only=False):
215
- """Convert a BGR image to YCbCr image.
216
-
217
- The bgr version of rgb2ycbcr.
218
- It implements the ITU-R BT.601 conversion for standard-definition
219
- television. See more details in
220
- https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
221
-
222
- It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
223
- In OpenCV, it implements a JPEG conversion. See more details in
224
- https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
225
-
226
- Args:
227
- img (ndarray): The input image. It accepts:
228
- 1. np.uint8 type with range [0, 255];
229
- 2. np.float32 type with range [0, 1].
230
- y_only (bool): Whether to only return Y channel. Default: False.
231
-
232
- Returns:
233
- ndarray: The converted YCbCr image. The output image has the same type
234
- and range as input image.
235
- """
236
- img_type = img.dtype
237
- img = _convert_input_type_range(img)
238
- if y_only:
239
- out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
240
- else:
241
- out_img = np.matmul(
242
- img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
243
- out_img = _convert_output_type_range(out_img, img_type)
244
- return out_img
245
-
246
-
247
- def ycbcr2rgb(img):
248
- """Convert a YCbCr image to RGB image.
249
-
250
- This function produces the same results as Matlab's ycbcr2rgb function.
251
- It implements the ITU-R BT.601 conversion for standard-definition
252
- television. See more details in
253
- https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
254
-
255
- It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
256
- In OpenCV, it implements a JPEG conversion. See more details in
257
- https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
258
-
259
- Args:
260
- img (ndarray): The input image. It accepts:
261
- 1. np.uint8 type with range [0, 255];
262
- 2. np.float32 type with range [0, 1].
263
-
264
- Returns:
265
- ndarray: The converted RGB image. The output image has the same type
266
- and range as input image.
267
- """
268
- img_type = img.dtype
269
- img = _convert_input_type_range(img) * 255
270
- out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
271
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
272
- out_img = _convert_output_type_range(out_img, img_type)
273
- return out_img
274
-
275
-
276
- def ycbcr2bgr(img):
277
- """Convert a YCbCr image to BGR image.
278
-
279
- The bgr version of ycbcr2rgb.
280
- It implements the ITU-R BT.601 conversion for standard-definition
281
- television. See more details in
282
- https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
283
-
284
- It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
285
- In OpenCV, it implements a JPEG conversion. See more details in
286
- https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
287
-
288
- Args:
289
- img (ndarray): The input image. It accepts:
290
- 1. np.uint8 type with range [0, 255];
291
- 2. np.float32 type with range [0, 1].
292
-
293
- Returns:
294
- ndarray: The converted BGR image. The output image has the same type
295
- and range as input image.
296
- """
297
- img_type = img.dtype
298
- img = _convert_input_type_range(img) * 255
299
- out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
300
- [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
301
- out_img = _convert_output_type_range(out_img, img_type)
302
- return out_img
303
-
304
-
305
- def _convert_input_type_range(img):
306
- """Convert the type and range of the input image.
307
-
308
- It converts the input image to np.float32 type and range of [0, 1].
309
- It is mainly used for pre-processing the input image in colorspace
310
- conversion functions such as rgb2ycbcr and ycbcr2rgb.
311
-
312
- Args:
313
- img (ndarray): The input image. It accepts:
314
- 1. np.uint8 type with range [0, 255];
315
- 2. np.float32 type with range [0, 1].
316
-
317
- Returns:
318
- (ndarray): The converted image with type of np.float32 and range of
319
- [0, 1].
320
- """
321
- img_type = img.dtype
322
- img = img.astype(np.float32)
323
- if img_type == np.float32:
324
- pass
325
- elif img_type == np.uint8:
326
- img /= 255.
327
- else:
328
- raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
329
- return img
330
-
331
-
332
- def _convert_output_type_range(img, dst_type):
333
- """Convert the type and range of the image according to dst_type.
334
-
335
- It converts the image to desired type and range. If `dst_type` is np.uint8,
336
- images will be converted to np.uint8 type with range [0, 255]. If
337
- `dst_type` is np.float32, it converts the image to np.float32 type with
338
- range [0, 1].
339
- It is mainly used for post-processing images in colorspace conversion
340
- functions such as rgb2ycbcr and ycbcr2rgb.
341
-
342
- Args:
343
- img (ndarray): The image to be converted with np.float32 type and
344
- range [0, 255].
345
- dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
346
- converts the image to np.uint8 type with range [0, 255]. If
347
- dst_type is np.float32, it converts the image to np.float32 type
348
- with range [0, 1].
349
-
350
- Returns:
351
- (ndarray): The converted image with desired type and range.
352
- """
353
- if dst_type not in (np.uint8, np.float32):
354
- raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
355
- if dst_type == np.uint8:
356
- img = img.round()
357
- else:
358
- img /= 255.
359
- return img.astype(dst_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/misc.py DELETED
@@ -1,141 +0,0 @@
1
- import numpy as np
2
- import os
3
- import random
4
- import time
5
- import torch
6
- from os import path as osp
7
-
8
- from .dist_util import master_only
9
-
10
-
11
- def set_random_seed(seed):
12
- """Set random seeds."""
13
- random.seed(seed)
14
- np.random.seed(seed)
15
- torch.manual_seed(seed)
16
- torch.cuda.manual_seed(seed)
17
- torch.cuda.manual_seed_all(seed)
18
-
19
-
20
- def get_time_str():
21
- return time.strftime('%Y%m%d_%H%M%S', time.localtime())
22
-
23
-
24
- def mkdir_and_rename(path):
25
- """mkdirs. If path exists, rename it with timestamp and create a new one.
26
-
27
- Args:
28
- path (str): Folder path.
29
- """
30
- if osp.exists(path):
31
- new_name = path + '_archived_' + get_time_str()
32
- print(f'Path already exists. Rename it to {new_name}', flush=True)
33
- os.rename(path, new_name)
34
- os.makedirs(path, exist_ok=True)
35
-
36
-
37
- @master_only
38
- def make_exp_dirs(opt):
39
- """Make dirs for experiments."""
40
- path_opt = opt['path'].copy()
41
- if opt['is_train']:
42
- mkdir_and_rename(path_opt.pop('experiments_root'))
43
- else:
44
- mkdir_and_rename(path_opt.pop('results_root'))
45
- for key, path in path_opt.items():
46
- if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
47
- continue
48
- else:
49
- os.makedirs(path, exist_ok=True)
50
-
51
-
52
- def scandir(dir_path, suffix=None, recursive=False, full_path=False):
53
- """Scan a directory to find the interested files.
54
-
55
- Args:
56
- dir_path (str): Path of the directory.
57
- suffix (str | tuple(str), optional): File suffix that we are
58
- interested in. Default: None.
59
- recursive (bool, optional): If set to True, recursively scan the
60
- directory. Default: False.
61
- full_path (bool, optional): If set to True, include the dir_path.
62
- Default: False.
63
-
64
- Returns:
65
- A generator for all the interested files with relative paths.
66
- """
67
-
68
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
69
- raise TypeError('"suffix" must be a string or tuple of strings')
70
-
71
- root = dir_path
72
-
73
- def _scandir(dir_path, suffix, recursive):
74
- for entry in os.scandir(dir_path):
75
- if not entry.name.startswith('.') and entry.is_file():
76
- if full_path:
77
- return_path = entry.path
78
- else:
79
- return_path = osp.relpath(entry.path, root)
80
-
81
- if suffix is None:
82
- yield return_path
83
- elif return_path.endswith(suffix):
84
- yield return_path
85
- else:
86
- if recursive:
87
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
88
- else:
89
- continue
90
-
91
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
92
-
93
-
94
- def check_resume(opt, resume_iter):
95
- """Check resume states and pretrain_network paths.
96
-
97
- Args:
98
- opt (dict): Options.
99
- resume_iter (int): Resume iteration.
100
- """
101
- if opt['path']['resume_state']:
102
- # get all the networks
103
- networks = [key for key in opt.keys() if key.startswith('network_')]
104
- flag_pretrain = False
105
- for network in networks:
106
- if opt['path'].get(f'pretrain_{network}') is not None:
107
- flag_pretrain = True
108
- if flag_pretrain:
109
- print('pretrain_network path will be ignored during resuming.')
110
- # set pretrained model paths
111
- for network in networks:
112
- name = f'pretrain_{network}'
113
- basename = network.replace('network_', '')
114
- if opt['path'].get('ignore_resume_networks') is None or (network
115
- not in opt['path']['ignore_resume_networks']):
116
- opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
117
- print(f"Set {name} to {opt['path'][name]}")
118
-
119
- # change param_key to params in resume
120
- param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
121
- for param_key in param_keys:
122
- if opt['path'][param_key] == 'params_ema':
123
- opt['path'][param_key] = 'params'
124
- print(f'Set {param_key} to params')
125
-
126
-
127
- def sizeof_fmt(size, suffix='B'):
128
- """Get human readable file size.
129
-
130
- Args:
131
- size (int): File size.
132
- suffix (str): Suffix. Default: 'B'.
133
-
134
- Return:
135
- str: Formatted file siz.
136
- """
137
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
138
- if abs(size) < 1024.0:
139
- return f'{size:3.1f} {unit}{suffix}'
140
- size /= 1024.0
141
- return f'{size:3.1f} Y{suffix}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/options.py DELETED
@@ -1,194 +0,0 @@
1
- import argparse
2
- import random
3
- import torch
4
- import yaml
5
- from collections import OrderedDict
6
- from os import path as osp
7
-
8
- from basicsr.utils import set_random_seed
9
- from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
10
-
11
-
12
- def ordered_yaml():
13
- """Support OrderedDict for yaml.
14
-
15
- Returns:
16
- yaml Loader and Dumper.
17
- """
18
- try:
19
- from yaml import CDumper as Dumper
20
- from yaml import CLoader as Loader
21
- except ImportError:
22
- from yaml import Dumper, Loader
23
-
24
- _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25
-
26
- def dict_representer(dumper, data):
27
- return dumper.represent_dict(data.items())
28
-
29
- def dict_constructor(loader, node):
30
- return OrderedDict(loader.construct_pairs(node))
31
-
32
- Dumper.add_representer(OrderedDict, dict_representer)
33
- Loader.add_constructor(_mapping_tag, dict_constructor)
34
- return Loader, Dumper
35
-
36
-
37
- def dict2str(opt, indent_level=1):
38
- """dict to string for printing options.
39
-
40
- Args:
41
- opt (dict): Option dict.
42
- indent_level (int): Indent level. Default: 1.
43
-
44
- Return:
45
- (str): Option string for printing.
46
- """
47
- msg = '\n'
48
- for k, v in opt.items():
49
- if isinstance(v, dict):
50
- msg += ' ' * (indent_level * 2) + k + ':['
51
- msg += dict2str(v, indent_level + 1)
52
- msg += ' ' * (indent_level * 2) + ']\n'
53
- else:
54
- msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
55
- return msg
56
-
57
-
58
- def _postprocess_yml_value(value):
59
- # None
60
- if value == '~' or value.lower() == 'none':
61
- return None
62
- # bool
63
- if value.lower() == 'true':
64
- return True
65
- elif value.lower() == 'false':
66
- return False
67
- # !!float number
68
- if value.startswith('!!float'):
69
- return float(value.replace('!!float', ''))
70
- # number
71
- if value.isdigit():
72
- return int(value)
73
- elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
74
- return float(value)
75
- # list
76
- if value.startswith('['):
77
- return eval(value)
78
- # str
79
- return value
80
-
81
-
82
- def parse_options(root_path, is_train=True):
83
- parser = argparse.ArgumentParser()
84
- parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
85
- parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
86
- parser.add_argument('--auto_resume', action='store_true')
87
- parser.add_argument('--debug', action='store_true')
88
- parser.add_argument('--local_rank', type=int, default=0)
89
- parser.add_argument(
90
- '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
91
- args = parser.parse_args()
92
-
93
- # parse yml to dict
94
- with open(args.opt, mode='r') as f:
95
- opt = yaml.load(f, Loader=ordered_yaml()[0])
96
-
97
- # distributed settings
98
- if args.launcher == 'none':
99
- opt['dist'] = False
100
- print('Disable distributed.', flush=True)
101
- else:
102
- opt['dist'] = True
103
- if args.launcher == 'slurm' and 'dist_params' in opt:
104
- init_dist(args.launcher, **opt['dist_params'])
105
- else:
106
- init_dist(args.launcher)
107
- opt['rank'], opt['world_size'] = get_dist_info()
108
-
109
- # random seed
110
- seed = opt.get('manual_seed')
111
- if seed is None:
112
- seed = random.randint(1, 10000)
113
- opt['manual_seed'] = seed
114
- set_random_seed(seed + opt['rank'])
115
-
116
- # force to update yml options
117
- if args.force_yml is not None:
118
- for entry in args.force_yml:
119
- # now do not support creating new keys
120
- keys, value = entry.split('=')
121
- keys, value = keys.strip(), value.strip()
122
- value = _postprocess_yml_value(value)
123
- eval_str = 'opt'
124
- for key in keys.split(':'):
125
- eval_str += f'["{key}"]'
126
- eval_str += '=value'
127
- # using exec function
128
- exec(eval_str)
129
-
130
- opt['auto_resume'] = args.auto_resume
131
- opt['is_train'] = is_train
132
-
133
- # debug setting
134
- if args.debug and not opt['name'].startswith('debug'):
135
- opt['name'] = 'debug_' + opt['name']
136
-
137
- if opt['num_gpu'] == 'auto':
138
- opt['num_gpu'] = torch.cuda.device_count()
139
-
140
- # datasets
141
- for phase, dataset in opt['datasets'].items():
142
- # for multiple datasets, e.g., val_1, val_2; test_1, test_2
143
- phase = phase.split('_')[0]
144
- dataset['phase'] = phase
145
- if 'scale' in opt:
146
- dataset['scale'] = opt['scale']
147
- if dataset.get('dataroot_gt') is not None:
148
- dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
149
- if dataset.get('dataroot_lq') is not None:
150
- dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
151
-
152
- # paths
153
- for key, val in opt['path'].items():
154
- if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
155
- opt['path'][key] = osp.expanduser(val)
156
-
157
- if is_train:
158
- experiments_root = osp.join(root_path, 'experiments', opt['name'])
159
- opt['path']['experiments_root'] = experiments_root
160
- opt['path']['models'] = osp.join(experiments_root, 'models')
161
- opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
162
- opt['path']['log'] = experiments_root
163
- opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
164
-
165
- # change some options for debug mode
166
- if 'debug' in opt['name']:
167
- if 'val' in opt:
168
- opt['val']['val_freq'] = 8
169
- opt['logger']['print_freq'] = 1
170
- opt['logger']['save_checkpoint_freq'] = 8
171
- else: # test
172
- results_root = osp.join(root_path, 'results', opt['name'])
173
- opt['path']['results_root'] = results_root
174
- opt['path']['log'] = results_root
175
- opt['path']['visualization'] = osp.join(results_root, 'visualization')
176
-
177
- return opt, args
178
-
179
-
180
- @master_only
181
- def copy_opt_file(opt_file, experiments_root):
182
- # copy the yml file to the experiment root
183
- import sys
184
- import time
185
- from shutil import copyfile
186
- cmd = ' '.join(sys.argv)
187
- filename = osp.join(experiments_root, osp.basename(opt_file))
188
- copyfile(opt_file, filename)
189
-
190
- with open(filename, 'r+') as f:
191
- lines = f.readlines()
192
- lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
193
- f.seek(0)
194
- f.writelines(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/registry.py DELETED
@@ -1,82 +0,0 @@
1
- # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
-
3
-
4
- class Registry():
5
- """
6
- The registry that provides name -> object mapping, to support third-party
7
- users' custom modules.
8
-
9
- To create a registry (e.g. a backbone registry):
10
-
11
- .. code-block:: python
12
-
13
- BACKBONE_REGISTRY = Registry('BACKBONE')
14
-
15
- To register an object:
16
-
17
- .. code-block:: python
18
-
19
- @BACKBONE_REGISTRY.register()
20
- class MyBackbone():
21
- ...
22
-
23
- Or:
24
-
25
- .. code-block:: python
26
-
27
- BACKBONE_REGISTRY.register(MyBackbone)
28
- """
29
-
30
- def __init__(self, name):
31
- """
32
- Args:
33
- name (str): the name of this registry
34
- """
35
- self._name = name
36
- self._obj_map = {}
37
-
38
- def _do_register(self, name, obj):
39
- assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
40
- f"in '{self._name}' registry!")
41
- self._obj_map[name] = obj
42
-
43
- def register(self, obj=None):
44
- """
45
- Register the given object under the the name `obj.__name__`.
46
- Can be used as either a decorator or not.
47
- See docstring of this class for usage.
48
- """
49
- if obj is None:
50
- # used as a decorator
51
- def deco(func_or_class):
52
- name = func_or_class.__name__
53
- self._do_register(name, func_or_class)
54
- return func_or_class
55
-
56
- return deco
57
-
58
- # used as a function call
59
- name = obj.__name__
60
- self._do_register(name, obj)
61
-
62
- def get(self, name):
63
- ret = self._obj_map.get(name)
64
- if ret is None:
65
- raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
66
- return ret
67
-
68
- def __contains__(self, name):
69
- return name in self._obj_map
70
-
71
- def __iter__(self):
72
- return iter(self._obj_map.items())
73
-
74
- def keys(self):
75
- return self._obj_map.keys()
76
-
77
-
78
- DATASET_REGISTRY = Registry('dataset')
79
- ARCH_REGISTRY = Registry('arch')
80
- MODEL_REGISTRY = Registry('model')
81
- LOSS_REGISTRY = Registry('loss')
82
- METRIC_REGISTRY = Registry('metric')