Delete utils
Browse files- utils/__init__.py +0 -30
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/dist_util.cpython-38.pyc +0 -0
- utils/__pycache__/file_client.cpython-38.pyc +0 -0
- utils/__pycache__/img_util.cpython-38.pyc +0 -0
- utils/__pycache__/logger.cpython-38.pyc +0 -0
- utils/__pycache__/matlab_functions.cpython-38.pyc +0 -0
- utils/__pycache__/misc.cpython-38.pyc +0 -0
- utils/__pycache__/options.cpython-38.pyc +0 -0
- utils/__pycache__/registry.cpython-38.pyc +0 -0
- utils/dist_util.py +0 -82
- utils/file_client.py +0 -167
- utils/img_util.py +0 -172
- utils/logger.py +0 -213
- utils/matlab_functions.py +0 -359
- utils/misc.py +0 -141
- utils/options.py +0 -194
- utils/registry.py +0 -82
utils/__init__.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
from .file_client import FileClient
|
| 2 |
-
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
|
| 3 |
-
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
| 4 |
-
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
| 5 |
-
|
| 6 |
-
__all__ = [
|
| 7 |
-
# file_client.py
|
| 8 |
-
'FileClient',
|
| 9 |
-
# img_util.py
|
| 10 |
-
'img2tensor',
|
| 11 |
-
'tensor2img',
|
| 12 |
-
'imfrombytes',
|
| 13 |
-
'imwrite',
|
| 14 |
-
'crop_border',
|
| 15 |
-
# logger.py
|
| 16 |
-
'MessageLogger',
|
| 17 |
-
'AvgTimer',
|
| 18 |
-
'init_tb_logger',
|
| 19 |
-
'init_wandb_logger',
|
| 20 |
-
'get_root_logger',
|
| 21 |
-
'get_env_info',
|
| 22 |
-
# misc.py
|
| 23 |
-
'set_random_seed',
|
| 24 |
-
'get_time_str',
|
| 25 |
-
'mkdir_and_rename',
|
| 26 |
-
'make_exp_dirs',
|
| 27 |
-
'scandir',
|
| 28 |
-
'check_resume',
|
| 29 |
-
'sizeof_fmt',
|
| 30 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/__pycache__/__init__.cpython-38.pyc
DELETED
|
Binary file (854 Bytes)
|
|
|
utils/__pycache__/dist_util.cpython-38.pyc
DELETED
|
Binary file (2.6 kB)
|
|
|
utils/__pycache__/file_client.cpython-38.pyc
DELETED
|
Binary file (6.5 kB)
|
|
|
utils/__pycache__/img_util.cpython-38.pyc
DELETED
|
Binary file (6.12 kB)
|
|
|
utils/__pycache__/logger.cpython-38.pyc
DELETED
|
Binary file (6.94 kB)
|
|
|
utils/__pycache__/matlab_functions.cpython-38.pyc
DELETED
|
Binary file (10.6 kB)
|
|
|
utils/__pycache__/misc.cpython-38.pyc
DELETED
|
Binary file (4.37 kB)
|
|
|
utils/__pycache__/options.cpython-38.pyc
DELETED
|
Binary file (5.11 kB)
|
|
|
utils/__pycache__/registry.cpython-38.pyc
DELETED
|
Binary file (2.61 kB)
|
|
|
utils/dist_util.py
DELETED
|
@@ -1,82 +0,0 @@
|
|
| 1 |
-
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
| 2 |
-
import functools
|
| 3 |
-
import os
|
| 4 |
-
import subprocess
|
| 5 |
-
import torch
|
| 6 |
-
import torch.distributed as dist
|
| 7 |
-
import torch.multiprocessing as mp
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def init_dist(launcher, backend='nccl', **kwargs):
|
| 11 |
-
if mp.get_start_method(allow_none=True) is None:
|
| 12 |
-
mp.set_start_method('spawn')
|
| 13 |
-
if launcher == 'pytorch':
|
| 14 |
-
_init_dist_pytorch(backend, **kwargs)
|
| 15 |
-
elif launcher == 'slurm':
|
| 16 |
-
_init_dist_slurm(backend, **kwargs)
|
| 17 |
-
else:
|
| 18 |
-
raise ValueError(f'Invalid launcher type: {launcher}')
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def _init_dist_pytorch(backend, **kwargs):
|
| 22 |
-
rank = int(os.environ['RANK'])
|
| 23 |
-
num_gpus = torch.cuda.device_count()
|
| 24 |
-
torch.cuda.set_device(rank % num_gpus)
|
| 25 |
-
dist.init_process_group(backend=backend, **kwargs)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _init_dist_slurm(backend, port=None):
|
| 29 |
-
"""Initialize slurm distributed training environment.
|
| 30 |
-
|
| 31 |
-
If argument ``port`` is not specified, then the master port will be system
|
| 32 |
-
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
| 33 |
-
environment variable, then a default port ``29500`` will be used.
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
backend (str): Backend of torch.distributed.
|
| 37 |
-
port (int, optional): Master port. Defaults to None.
|
| 38 |
-
"""
|
| 39 |
-
proc_id = int(os.environ['SLURM_PROCID'])
|
| 40 |
-
ntasks = int(os.environ['SLURM_NTASKS'])
|
| 41 |
-
node_list = os.environ['SLURM_NODELIST']
|
| 42 |
-
num_gpus = torch.cuda.device_count()
|
| 43 |
-
torch.cuda.set_device(proc_id % num_gpus)
|
| 44 |
-
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
|
| 45 |
-
# specify master port
|
| 46 |
-
if port is not None:
|
| 47 |
-
os.environ['MASTER_PORT'] = str(port)
|
| 48 |
-
elif 'MASTER_PORT' in os.environ:
|
| 49 |
-
pass # use MASTER_PORT in the environment variable
|
| 50 |
-
else:
|
| 51 |
-
# 29500 is torch.distributed default port
|
| 52 |
-
os.environ['MASTER_PORT'] = '29500'
|
| 53 |
-
os.environ['MASTER_ADDR'] = addr
|
| 54 |
-
os.environ['WORLD_SIZE'] = str(ntasks)
|
| 55 |
-
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
| 56 |
-
os.environ['RANK'] = str(proc_id)
|
| 57 |
-
dist.init_process_group(backend=backend)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def get_dist_info():
|
| 61 |
-
if dist.is_available():
|
| 62 |
-
initialized = dist.is_initialized()
|
| 63 |
-
else:
|
| 64 |
-
initialized = False
|
| 65 |
-
if initialized:
|
| 66 |
-
rank = dist.get_rank()
|
| 67 |
-
world_size = dist.get_world_size()
|
| 68 |
-
else:
|
| 69 |
-
rank = 0
|
| 70 |
-
world_size = 1
|
| 71 |
-
return rank, world_size
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def master_only(func):
|
| 75 |
-
|
| 76 |
-
@functools.wraps(func)
|
| 77 |
-
def wrapper(*args, **kwargs):
|
| 78 |
-
rank, _ = get_dist_info()
|
| 79 |
-
if rank == 0:
|
| 80 |
-
return func(*args, **kwargs)
|
| 81 |
-
|
| 82 |
-
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/file_client.py
DELETED
|
@@ -1,167 +0,0 @@
|
|
| 1 |
-
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
| 2 |
-
from abc import ABCMeta, abstractmethod
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class BaseStorageBackend(metaclass=ABCMeta):
|
| 6 |
-
"""Abstract class of storage backends.
|
| 7 |
-
|
| 8 |
-
All backends need to implement two apis: ``get()`` and ``get_text()``.
|
| 9 |
-
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
|
| 10 |
-
as texts.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
@abstractmethod
|
| 14 |
-
def get(self, filepath):
|
| 15 |
-
pass
|
| 16 |
-
|
| 17 |
-
@abstractmethod
|
| 18 |
-
def get_text(self, filepath):
|
| 19 |
-
pass
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class MemcachedBackend(BaseStorageBackend):
|
| 23 |
-
"""Memcached storage backend.
|
| 24 |
-
|
| 25 |
-
Attributes:
|
| 26 |
-
server_list_cfg (str): Config file for memcached server list.
|
| 27 |
-
client_cfg (str): Config file for memcached client.
|
| 28 |
-
sys_path (str | None): Additional path to be appended to `sys.path`.
|
| 29 |
-
Default: None.
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
|
| 33 |
-
if sys_path is not None:
|
| 34 |
-
import sys
|
| 35 |
-
sys.path.append(sys_path)
|
| 36 |
-
try:
|
| 37 |
-
import mc
|
| 38 |
-
except ImportError:
|
| 39 |
-
raise ImportError('Please install memcached to enable MemcachedBackend.')
|
| 40 |
-
|
| 41 |
-
self.server_list_cfg = server_list_cfg
|
| 42 |
-
self.client_cfg = client_cfg
|
| 43 |
-
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
|
| 44 |
-
# mc.pyvector servers as a point which points to a memory cache
|
| 45 |
-
self._mc_buffer = mc.pyvector()
|
| 46 |
-
|
| 47 |
-
def get(self, filepath):
|
| 48 |
-
filepath = str(filepath)
|
| 49 |
-
import mc
|
| 50 |
-
self._client.Get(filepath, self._mc_buffer)
|
| 51 |
-
value_buf = mc.ConvertBuffer(self._mc_buffer)
|
| 52 |
-
return value_buf
|
| 53 |
-
|
| 54 |
-
def get_text(self, filepath):
|
| 55 |
-
raise NotImplementedError
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class HardDiskBackend(BaseStorageBackend):
|
| 59 |
-
"""Raw hard disks storage backend."""
|
| 60 |
-
|
| 61 |
-
def get(self, filepath):
|
| 62 |
-
filepath = str(filepath)
|
| 63 |
-
with open(filepath, 'rb') as f:
|
| 64 |
-
value_buf = f.read()
|
| 65 |
-
return value_buf
|
| 66 |
-
|
| 67 |
-
def get_text(self, filepath):
|
| 68 |
-
filepath = str(filepath)
|
| 69 |
-
with open(filepath, 'r') as f:
|
| 70 |
-
value_buf = f.read()
|
| 71 |
-
return value_buf
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
class LmdbBackend(BaseStorageBackend):
|
| 75 |
-
"""Lmdb storage backend.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
db_paths (str | list[str]): Lmdb database paths.
|
| 79 |
-
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
|
| 80 |
-
readonly (bool, optional): Lmdb environment parameter. If True,
|
| 81 |
-
disallow any write operations. Default: True.
|
| 82 |
-
lock (bool, optional): Lmdb environment parameter. If False, when
|
| 83 |
-
concurrent access occurs, do not lock the database. Default: False.
|
| 84 |
-
readahead (bool, optional): Lmdb environment parameter. If False,
|
| 85 |
-
disable the OS filesystem readahead mechanism, which may improve
|
| 86 |
-
random read performance when a database is larger than RAM.
|
| 87 |
-
Default: False.
|
| 88 |
-
|
| 89 |
-
Attributes:
|
| 90 |
-
db_paths (list): Lmdb database path.
|
| 91 |
-
_client (list): A list of several lmdb envs.
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
|
| 95 |
-
try:
|
| 96 |
-
import lmdb
|
| 97 |
-
except ImportError:
|
| 98 |
-
raise ImportError('Please install lmdb to enable LmdbBackend.')
|
| 99 |
-
|
| 100 |
-
if isinstance(client_keys, str):
|
| 101 |
-
client_keys = [client_keys]
|
| 102 |
-
|
| 103 |
-
if isinstance(db_paths, list):
|
| 104 |
-
self.db_paths = [str(v) for v in db_paths]
|
| 105 |
-
elif isinstance(db_paths, str):
|
| 106 |
-
self.db_paths = [str(db_paths)]
|
| 107 |
-
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
|
| 108 |
-
f'but received {len(client_keys)} and {len(self.db_paths)}.')
|
| 109 |
-
|
| 110 |
-
self._client = {}
|
| 111 |
-
for client, path in zip(client_keys, self.db_paths):
|
| 112 |
-
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
|
| 113 |
-
|
| 114 |
-
def get(self, filepath, client_key):
|
| 115 |
-
"""Get values according to the filepath from one lmdb named client_key.
|
| 116 |
-
|
| 117 |
-
Args:
|
| 118 |
-
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
|
| 119 |
-
client_key (str): Used for distinguishing different lmdb envs.
|
| 120 |
-
"""
|
| 121 |
-
filepath = str(filepath)
|
| 122 |
-
assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
|
| 123 |
-
client = self._client[client_key]
|
| 124 |
-
with client.begin(write=False) as txn:
|
| 125 |
-
value_buf = txn.get(filepath.encode('ascii'))
|
| 126 |
-
return value_buf
|
| 127 |
-
|
| 128 |
-
def get_text(self, filepath):
|
| 129 |
-
raise NotImplementedError
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class FileClient(object):
|
| 133 |
-
"""A general file client to access files in different backend.
|
| 134 |
-
|
| 135 |
-
The client loads a file or text in a specified backend from its path
|
| 136 |
-
and return it as a binary file. it can also register other backend
|
| 137 |
-
accessor with a given name and backend class.
|
| 138 |
-
|
| 139 |
-
Attributes:
|
| 140 |
-
backend (str): The storage backend type. Options are "disk",
|
| 141 |
-
"memcached" and "lmdb".
|
| 142 |
-
client (:obj:`BaseStorageBackend`): The backend object.
|
| 143 |
-
"""
|
| 144 |
-
|
| 145 |
-
_backends = {
|
| 146 |
-
'disk': HardDiskBackend,
|
| 147 |
-
'memcached': MemcachedBackend,
|
| 148 |
-
'lmdb': LmdbBackend,
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
def __init__(self, backend='disk', **kwargs):
|
| 152 |
-
if backend not in self._backends:
|
| 153 |
-
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
|
| 154 |
-
f' are {list(self._backends.keys())}')
|
| 155 |
-
self.backend = backend
|
| 156 |
-
self.client = self._backends[backend](**kwargs)
|
| 157 |
-
|
| 158 |
-
def get(self, filepath, client_key='default'):
|
| 159 |
-
# client_key is used only for lmdb, where different fileclients have
|
| 160 |
-
# different lmdb environments.
|
| 161 |
-
if self.backend == 'lmdb':
|
| 162 |
-
return self.client.get(filepath, client_key)
|
| 163 |
-
else:
|
| 164 |
-
return self.client.get(filepath)
|
| 165 |
-
|
| 166 |
-
def get_text(self, filepath):
|
| 167 |
-
return self.client.get_text(filepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/img_util.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
import cv2
|
| 2 |
-
import math
|
| 3 |
-
import numpy as np
|
| 4 |
-
import os
|
| 5 |
-
import torch
|
| 6 |
-
from torchvision.utils import make_grid
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
| 10 |
-
"""Numpy array to tensor.
|
| 11 |
-
|
| 12 |
-
Args:
|
| 13 |
-
imgs (list[ndarray] | ndarray): Input images.
|
| 14 |
-
bgr2rgb (bool): Whether to change bgr to rgb.
|
| 15 |
-
float32 (bool): Whether to change to float32.
|
| 16 |
-
|
| 17 |
-
Returns:
|
| 18 |
-
list[tensor] | tensor: Tensor images. If returned results only have
|
| 19 |
-
one element, just return tensor.
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
def _totensor(img, bgr2rgb, float32):
|
| 23 |
-
if img.shape[2] == 3 and bgr2rgb:
|
| 24 |
-
if img.dtype == 'float64':
|
| 25 |
-
img = img.astype('float32')
|
| 26 |
-
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 27 |
-
img = torch.from_numpy(img.transpose(2, 0, 1))
|
| 28 |
-
if float32:
|
| 29 |
-
img = img.float()
|
| 30 |
-
return img
|
| 31 |
-
|
| 32 |
-
if isinstance(imgs, list):
|
| 33 |
-
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
| 34 |
-
else:
|
| 35 |
-
return _totensor(imgs, bgr2rgb, float32)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
|
| 39 |
-
"""Convert torch Tensors into image numpy arrays.
|
| 40 |
-
|
| 41 |
-
After clamping to [min, max], values will be normalized to [0, 1].
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
tensor (Tensor or list[Tensor]): Accept shapes:
|
| 45 |
-
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
|
| 46 |
-
2) 3D Tensor of shape (3/1 x H x W);
|
| 47 |
-
3) 2D Tensor of shape (H x W).
|
| 48 |
-
Tensor channel should be in RGB order.
|
| 49 |
-
rgb2bgr (bool): Whether to change rgb to bgr.
|
| 50 |
-
out_type (numpy type): output types. If ``np.uint8``, transform outputs
|
| 51 |
-
to uint8 type with range [0, 255]; otherwise, float type with
|
| 52 |
-
range [0, 1]. Default: ``np.uint8``.
|
| 53 |
-
min_max (tuple[int]): min and max values for clamp.
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
|
| 57 |
-
shape (H x W). The channel order is BGR.
|
| 58 |
-
"""
|
| 59 |
-
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
|
| 60 |
-
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
|
| 61 |
-
|
| 62 |
-
if torch.is_tensor(tensor):
|
| 63 |
-
tensor = [tensor]
|
| 64 |
-
result = []
|
| 65 |
-
for _tensor in tensor:
|
| 66 |
-
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
| 67 |
-
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
|
| 68 |
-
|
| 69 |
-
n_dim = _tensor.dim()
|
| 70 |
-
if n_dim == 4:
|
| 71 |
-
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
|
| 72 |
-
img_np = img_np.transpose(1, 2, 0)
|
| 73 |
-
if rgb2bgr:
|
| 74 |
-
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 75 |
-
elif n_dim == 3:
|
| 76 |
-
img_np = _tensor.numpy()
|
| 77 |
-
img_np = img_np.transpose(1, 2, 0)
|
| 78 |
-
if img_np.shape[2] == 1: # gray image
|
| 79 |
-
img_np = np.squeeze(img_np, axis=2)
|
| 80 |
-
else:
|
| 81 |
-
if rgb2bgr:
|
| 82 |
-
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 83 |
-
elif n_dim == 2:
|
| 84 |
-
img_np = _tensor.numpy()
|
| 85 |
-
else:
|
| 86 |
-
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
|
| 87 |
-
if out_type == np.uint8:
|
| 88 |
-
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
|
| 89 |
-
img_np = (img_np * 255.0).round()
|
| 90 |
-
img_np = img_np.astype(out_type)
|
| 91 |
-
result.append(img_np)
|
| 92 |
-
if len(result) == 1:
|
| 93 |
-
result = result[0]
|
| 94 |
-
return result
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
|
| 98 |
-
"""This implementation is slightly faster than tensor2img.
|
| 99 |
-
It now only supports torch tensor with shape (1, c, h, w).
|
| 100 |
-
|
| 101 |
-
Args:
|
| 102 |
-
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
|
| 103 |
-
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
|
| 104 |
-
min_max (tuple[int]): min and max values for clamp.
|
| 105 |
-
"""
|
| 106 |
-
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
|
| 107 |
-
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
|
| 108 |
-
output = output.type(torch.uint8).cpu().numpy()
|
| 109 |
-
if rgb2bgr:
|
| 110 |
-
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
| 111 |
-
return output
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def imfrombytes(content, flag='color', float32=False):
|
| 115 |
-
"""Read an image from bytes.
|
| 116 |
-
|
| 117 |
-
Args:
|
| 118 |
-
content (bytes): Image bytes got from files or other streams.
|
| 119 |
-
flag (str): Flags specifying the color type of a loaded image,
|
| 120 |
-
candidates are `color`, `grayscale` and `unchanged`.
|
| 121 |
-
float32 (bool): Whether to change to float32., If True, will also norm
|
| 122 |
-
to [0, 1]. Default: False.
|
| 123 |
-
|
| 124 |
-
Returns:
|
| 125 |
-
ndarray: Loaded image array.
|
| 126 |
-
"""
|
| 127 |
-
img_np = np.frombuffer(content, np.uint8)
|
| 128 |
-
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
|
| 129 |
-
img = cv2.imdecode(img_np, imread_flags[flag])
|
| 130 |
-
if float32:
|
| 131 |
-
img = img.astype(np.float32) / 255.
|
| 132 |
-
return img
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def imwrite(img, file_path, params=None, auto_mkdir=True):
|
| 136 |
-
"""Write image to file.
|
| 137 |
-
|
| 138 |
-
Args:
|
| 139 |
-
img (ndarray): Image array to be written.
|
| 140 |
-
file_path (str): Image file path.
|
| 141 |
-
params (None or list): Same as opencv's :func:`imwrite` interface.
|
| 142 |
-
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
|
| 143 |
-
whether to create it automatically.
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
bool: Successful or not.
|
| 147 |
-
"""
|
| 148 |
-
if auto_mkdir:
|
| 149 |
-
dir_name = os.path.abspath(os.path.dirname(file_path))
|
| 150 |
-
os.makedirs(dir_name, exist_ok=True)
|
| 151 |
-
ok = cv2.imwrite(file_path, img, params)
|
| 152 |
-
if not ok:
|
| 153 |
-
raise IOError('Failed in writing images.')
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def crop_border(imgs, crop_border):
|
| 157 |
-
"""Crop borders of images.
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
|
| 161 |
-
crop_border (int): Crop border for each end of height and weight.
|
| 162 |
-
|
| 163 |
-
Returns:
|
| 164 |
-
list[ndarray]: Cropped images.
|
| 165 |
-
"""
|
| 166 |
-
if crop_border == 0:
|
| 167 |
-
return imgs
|
| 168 |
-
else:
|
| 169 |
-
if isinstance(imgs, list):
|
| 170 |
-
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
|
| 171 |
-
else:
|
| 172 |
-
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/logger.py
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
import datetime
|
| 2 |
-
import logging
|
| 3 |
-
import time
|
| 4 |
-
|
| 5 |
-
from .dist_util import get_dist_info, master_only
|
| 6 |
-
|
| 7 |
-
initialized_logger = {}
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class AvgTimer():
|
| 11 |
-
|
| 12 |
-
def __init__(self, window=200):
|
| 13 |
-
self.window = window # average window
|
| 14 |
-
self.current_time = 0
|
| 15 |
-
self.total_time = 0
|
| 16 |
-
self.count = 0
|
| 17 |
-
self.avg_time = 0
|
| 18 |
-
self.start()
|
| 19 |
-
|
| 20 |
-
def start(self):
|
| 21 |
-
self.start_time = self.tic = time.time()
|
| 22 |
-
|
| 23 |
-
def record(self):
|
| 24 |
-
self.count += 1
|
| 25 |
-
self.toc = time.time()
|
| 26 |
-
self.current_time = self.toc - self.tic
|
| 27 |
-
self.total_time += self.current_time
|
| 28 |
-
# calculate average time
|
| 29 |
-
self.avg_time = self.total_time / self.count
|
| 30 |
-
|
| 31 |
-
# reset
|
| 32 |
-
if self.count > self.window:
|
| 33 |
-
self.count = 0
|
| 34 |
-
self.total_time = 0
|
| 35 |
-
|
| 36 |
-
self.tic = time.time()
|
| 37 |
-
|
| 38 |
-
def get_current_time(self):
|
| 39 |
-
return self.current_time
|
| 40 |
-
|
| 41 |
-
def get_avg_time(self):
|
| 42 |
-
return self.avg_time
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class MessageLogger():
|
| 46 |
-
"""Message logger for printing.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
opt (dict): Config. It contains the following keys:
|
| 50 |
-
name (str): Exp name.
|
| 51 |
-
logger (dict): Contains 'print_freq' (str) for logger interval.
|
| 52 |
-
train (dict): Contains 'total_iter' (int) for total iters.
|
| 53 |
-
use_tb_logger (bool): Use tensorboard logger.
|
| 54 |
-
start_iter (int): Start iter. Default: 1.
|
| 55 |
-
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
|
| 56 |
-
"""
|
| 57 |
-
|
| 58 |
-
def __init__(self, opt, start_iter=1, tb_logger=None):
|
| 59 |
-
self.exp_name = opt['name']
|
| 60 |
-
self.interval = opt['logger']['print_freq']
|
| 61 |
-
self.start_iter = start_iter
|
| 62 |
-
self.max_iters = opt['train']['total_iter']
|
| 63 |
-
self.use_tb_logger = opt['logger']['use_tb_logger']
|
| 64 |
-
self.tb_logger = tb_logger
|
| 65 |
-
self.start_time = time.time()
|
| 66 |
-
self.logger = get_root_logger()
|
| 67 |
-
|
| 68 |
-
def reset_start_time(self):
|
| 69 |
-
self.start_time = time.time()
|
| 70 |
-
|
| 71 |
-
@master_only
|
| 72 |
-
def __call__(self, log_vars):
|
| 73 |
-
"""Format logging message.
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
log_vars (dict): It contains the following keys:
|
| 77 |
-
epoch (int): Epoch number.
|
| 78 |
-
iter (int): Current iter.
|
| 79 |
-
lrs (list): List for learning rates.
|
| 80 |
-
|
| 81 |
-
time (float): Iter time.
|
| 82 |
-
data_time (float): Data time for each iter.
|
| 83 |
-
"""
|
| 84 |
-
# epoch, iter, learning rates
|
| 85 |
-
epoch = log_vars.pop('epoch')
|
| 86 |
-
current_iter = log_vars.pop('iter')
|
| 87 |
-
lrs = log_vars.pop('lrs')
|
| 88 |
-
|
| 89 |
-
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
|
| 90 |
-
for v in lrs:
|
| 91 |
-
message += f'{v:.3e},'
|
| 92 |
-
message += ')] '
|
| 93 |
-
|
| 94 |
-
# time and estimated time
|
| 95 |
-
if 'time' in log_vars.keys():
|
| 96 |
-
iter_time = log_vars.pop('time')
|
| 97 |
-
data_time = log_vars.pop('data_time')
|
| 98 |
-
|
| 99 |
-
total_time = time.time() - self.start_time
|
| 100 |
-
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
|
| 101 |
-
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
|
| 102 |
-
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
| 103 |
-
message += f'[eta: {eta_str}, '
|
| 104 |
-
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
|
| 105 |
-
|
| 106 |
-
# other items, especially losses
|
| 107 |
-
for k, v in log_vars.items():
|
| 108 |
-
message += f'{k}: {v:.4e} '
|
| 109 |
-
# tensorboard logger
|
| 110 |
-
if self.use_tb_logger and 'debug' not in self.exp_name:
|
| 111 |
-
if k.startswith('l_'):
|
| 112 |
-
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
|
| 113 |
-
else:
|
| 114 |
-
self.tb_logger.add_scalar(k, v, current_iter)
|
| 115 |
-
self.logger.info(message)
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
@master_only
|
| 119 |
-
def init_tb_logger(log_dir):
|
| 120 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 121 |
-
tb_logger = SummaryWriter(log_dir=log_dir)
|
| 122 |
-
return tb_logger
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@master_only
|
| 126 |
-
def init_wandb_logger(opt):
|
| 127 |
-
"""We now only use wandb to sync tensorboard log."""
|
| 128 |
-
import wandb
|
| 129 |
-
logger = get_root_logger()
|
| 130 |
-
|
| 131 |
-
project = opt['logger']['wandb']['project']
|
| 132 |
-
resume_id = opt['logger']['wandb'].get('resume_id')
|
| 133 |
-
if resume_id:
|
| 134 |
-
wandb_id = resume_id
|
| 135 |
-
resume = 'allow'
|
| 136 |
-
logger.warning(f'Resume wandb logger with id={wandb_id}.')
|
| 137 |
-
else:
|
| 138 |
-
wandb_id = wandb.util.generate_id()
|
| 139 |
-
resume = 'never'
|
| 140 |
-
|
| 141 |
-
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
|
| 142 |
-
|
| 143 |
-
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
| 147 |
-
"""Get the root logger.
|
| 148 |
-
|
| 149 |
-
The logger will be initialized if it has not been initialized. By default a
|
| 150 |
-
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
| 151 |
-
also be added.
|
| 152 |
-
|
| 153 |
-
Args:
|
| 154 |
-
logger_name (str): root logger name. Default: 'basicsr'.
|
| 155 |
-
log_file (str | None): The log filename. If specified, a FileHandler
|
| 156 |
-
will be added to the root logger.
|
| 157 |
-
log_level (int): The root logger level. Note that only the process of
|
| 158 |
-
rank 0 is affected, while other processes will set the level to
|
| 159 |
-
"Error" and be silent most of the time.
|
| 160 |
-
|
| 161 |
-
Returns:
|
| 162 |
-
logging.Logger: The root logger.
|
| 163 |
-
"""
|
| 164 |
-
logger = logging.getLogger(logger_name)
|
| 165 |
-
# if the logger has been initialized, just return it
|
| 166 |
-
if logger_name in initialized_logger:
|
| 167 |
-
return logger
|
| 168 |
-
|
| 169 |
-
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
| 170 |
-
stream_handler = logging.StreamHandler()
|
| 171 |
-
stream_handler.setFormatter(logging.Formatter(format_str))
|
| 172 |
-
logger.addHandler(stream_handler)
|
| 173 |
-
logger.propagate = False
|
| 174 |
-
rank, _ = get_dist_info()
|
| 175 |
-
if rank != 0:
|
| 176 |
-
logger.setLevel('ERROR')
|
| 177 |
-
elif log_file is not None:
|
| 178 |
-
logger.setLevel(log_level)
|
| 179 |
-
# add file handler
|
| 180 |
-
file_handler = logging.FileHandler(log_file, 'w')
|
| 181 |
-
file_handler.setFormatter(logging.Formatter(format_str))
|
| 182 |
-
file_handler.setLevel(log_level)
|
| 183 |
-
logger.addHandler(file_handler)
|
| 184 |
-
initialized_logger[logger_name] = True
|
| 185 |
-
return logger
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def get_env_info():
|
| 189 |
-
"""Get environment information.
|
| 190 |
-
|
| 191 |
-
Currently, only log the software version.
|
| 192 |
-
"""
|
| 193 |
-
import torch
|
| 194 |
-
import torchvision
|
| 195 |
-
|
| 196 |
-
from basicsr.version import __version__
|
| 197 |
-
msg = r"""
|
| 198 |
-
____ _ _____ ____
|
| 199 |
-
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
|
| 200 |
-
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
|
| 201 |
-
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
|
| 202 |
-
/_____/ \__,_//____//_/ \___//____//_/ |_|
|
| 203 |
-
______ __ __ __ __
|
| 204 |
-
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
|
| 205 |
-
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
|
| 206 |
-
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
|
| 207 |
-
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
|
| 208 |
-
"""
|
| 209 |
-
msg += ('\nVersion Information: '
|
| 210 |
-
f'\n\tBasicSR: {__version__}'
|
| 211 |
-
f'\n\tPyTorch: {torch.__version__}'
|
| 212 |
-
f'\n\tTorchVision: {torchvision.__version__}')
|
| 213 |
-
return msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/matlab_functions.py
DELETED
|
@@ -1,359 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import numpy as np
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
def cubic(x):
|
| 7 |
-
"""cubic function used for calculate_weights_indices."""
|
| 8 |
-
absx = torch.abs(x)
|
| 9 |
-
absx2 = absx**2
|
| 10 |
-
absx3 = absx**3
|
| 11 |
-
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
|
| 12 |
-
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
|
| 13 |
-
(absx <= 2)).type_as(absx))
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
| 17 |
-
"""Calculate weights and indices, used for imresize function.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
in_length (int): Input length.
|
| 21 |
-
out_length (int): Output length.
|
| 22 |
-
scale (float): Scale factor.
|
| 23 |
-
kernel_width (int): Kernel width.
|
| 24 |
-
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
if (scale < 1) and antialiasing:
|
| 28 |
-
# Use a modified kernel (larger kernel width) to simultaneously
|
| 29 |
-
# interpolate and antialias
|
| 30 |
-
kernel_width = kernel_width / scale
|
| 31 |
-
|
| 32 |
-
# Output-space coordinates
|
| 33 |
-
x = torch.linspace(1, out_length, out_length)
|
| 34 |
-
|
| 35 |
-
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
| 36 |
-
# in output space maps to 0.5 in input space, and 0.5 + scale in output
|
| 37 |
-
# space maps to 1.5 in input space.
|
| 38 |
-
u = x / scale + 0.5 * (1 - 1 / scale)
|
| 39 |
-
|
| 40 |
-
# What is the left-most pixel that can be involved in the computation?
|
| 41 |
-
left = torch.floor(u - kernel_width / 2)
|
| 42 |
-
|
| 43 |
-
# What is the maximum number of pixels that can be involved in the
|
| 44 |
-
# computation? Note: it's OK to use an extra pixel here; if the
|
| 45 |
-
# corresponding weights are all zero, it will be eliminated at the end
|
| 46 |
-
# of this function.
|
| 47 |
-
p = math.ceil(kernel_width) + 2
|
| 48 |
-
|
| 49 |
-
# The indices of the input pixels involved in computing the k-th output
|
| 50 |
-
# pixel are in row k of the indices matrix.
|
| 51 |
-
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
|
| 52 |
-
out_length, p)
|
| 53 |
-
|
| 54 |
-
# The weights used to compute the k-th output pixel are in row k of the
|
| 55 |
-
# weights matrix.
|
| 56 |
-
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
|
| 57 |
-
|
| 58 |
-
# apply cubic kernel
|
| 59 |
-
if (scale < 1) and antialiasing:
|
| 60 |
-
weights = scale * cubic(distance_to_center * scale)
|
| 61 |
-
else:
|
| 62 |
-
weights = cubic(distance_to_center)
|
| 63 |
-
|
| 64 |
-
# Normalize the weights matrix so that each row sums to 1.
|
| 65 |
-
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
| 66 |
-
weights = weights / weights_sum.expand(out_length, p)
|
| 67 |
-
|
| 68 |
-
# If a column in weights is all zero, get rid of it. only consider the
|
| 69 |
-
# first and last column.
|
| 70 |
-
weights_zero_tmp = torch.sum((weights == 0), 0)
|
| 71 |
-
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
| 72 |
-
indices = indices.narrow(1, 1, p - 2)
|
| 73 |
-
weights = weights.narrow(1, 1, p - 2)
|
| 74 |
-
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
| 75 |
-
indices = indices.narrow(1, 0, p - 2)
|
| 76 |
-
weights = weights.narrow(1, 0, p - 2)
|
| 77 |
-
weights = weights.contiguous()
|
| 78 |
-
indices = indices.contiguous()
|
| 79 |
-
sym_len_s = -indices.min() + 1
|
| 80 |
-
sym_len_e = indices.max() - in_length
|
| 81 |
-
indices = indices + sym_len_s - 1
|
| 82 |
-
return weights, indices, int(sym_len_s), int(sym_len_e)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
@torch.no_grad()
|
| 86 |
-
def imresize(img, scale, antialiasing=True):
|
| 87 |
-
"""imresize function same as MATLAB.
|
| 88 |
-
|
| 89 |
-
It now only supports bicubic.
|
| 90 |
-
The same scale applies for both height and width.
|
| 91 |
-
|
| 92 |
-
Args:
|
| 93 |
-
img (Tensor | Numpy array):
|
| 94 |
-
Tensor: Input image with shape (c, h, w), [0, 1] range.
|
| 95 |
-
Numpy: Input image with shape (h, w, c), [0, 1] range.
|
| 96 |
-
scale (float): Scale factor. The same scale applies for both height
|
| 97 |
-
and width.
|
| 98 |
-
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
| 99 |
-
Default: True.
|
| 100 |
-
|
| 101 |
-
Returns:
|
| 102 |
-
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
|
| 103 |
-
"""
|
| 104 |
-
squeeze_flag = False
|
| 105 |
-
if type(img).__module__ == np.__name__: # numpy type
|
| 106 |
-
numpy_type = True
|
| 107 |
-
if img.ndim == 2:
|
| 108 |
-
img = img[:, :, None]
|
| 109 |
-
squeeze_flag = True
|
| 110 |
-
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
|
| 111 |
-
else:
|
| 112 |
-
numpy_type = False
|
| 113 |
-
if img.ndim == 2:
|
| 114 |
-
img = img.unsqueeze(0)
|
| 115 |
-
squeeze_flag = True
|
| 116 |
-
|
| 117 |
-
in_c, in_h, in_w = img.size()
|
| 118 |
-
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
|
| 119 |
-
kernel_width = 4
|
| 120 |
-
kernel = 'cubic'
|
| 121 |
-
|
| 122 |
-
# get weights and indices
|
| 123 |
-
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
|
| 124 |
-
antialiasing)
|
| 125 |
-
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
|
| 126 |
-
antialiasing)
|
| 127 |
-
# process H dimension
|
| 128 |
-
# symmetric copying
|
| 129 |
-
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
|
| 130 |
-
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
|
| 131 |
-
|
| 132 |
-
sym_patch = img[:, :sym_len_hs, :]
|
| 133 |
-
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
| 134 |
-
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
| 135 |
-
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
|
| 136 |
-
|
| 137 |
-
sym_patch = img[:, -sym_len_he:, :]
|
| 138 |
-
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
| 139 |
-
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
| 140 |
-
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
|
| 141 |
-
|
| 142 |
-
out_1 = torch.FloatTensor(in_c, out_h, in_w)
|
| 143 |
-
kernel_width = weights_h.size(1)
|
| 144 |
-
for i in range(out_h):
|
| 145 |
-
idx = int(indices_h[i][0])
|
| 146 |
-
for j in range(in_c):
|
| 147 |
-
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
|
| 148 |
-
|
| 149 |
-
# process W dimension
|
| 150 |
-
# symmetric copying
|
| 151 |
-
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
|
| 152 |
-
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
|
| 153 |
-
|
| 154 |
-
sym_patch = out_1[:, :, :sym_len_ws]
|
| 155 |
-
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
| 156 |
-
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
| 157 |
-
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
|
| 158 |
-
|
| 159 |
-
sym_patch = out_1[:, :, -sym_len_we:]
|
| 160 |
-
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
| 161 |
-
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
| 162 |
-
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
|
| 163 |
-
|
| 164 |
-
out_2 = torch.FloatTensor(in_c, out_h, out_w)
|
| 165 |
-
kernel_width = weights_w.size(1)
|
| 166 |
-
for i in range(out_w):
|
| 167 |
-
idx = int(indices_w[i][0])
|
| 168 |
-
for j in range(in_c):
|
| 169 |
-
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
|
| 170 |
-
|
| 171 |
-
if squeeze_flag:
|
| 172 |
-
out_2 = out_2.squeeze(0)
|
| 173 |
-
if numpy_type:
|
| 174 |
-
out_2 = out_2.numpy()
|
| 175 |
-
if not squeeze_flag:
|
| 176 |
-
out_2 = out_2.transpose(1, 2, 0)
|
| 177 |
-
|
| 178 |
-
return out_2
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
def rgb2ycbcr(img, y_only=False):
|
| 182 |
-
"""Convert a RGB image to YCbCr image.
|
| 183 |
-
|
| 184 |
-
This function produces the same results as Matlab's `rgb2ycbcr` function.
|
| 185 |
-
It implements the ITU-R BT.601 conversion for standard-definition
|
| 186 |
-
television. See more details in
|
| 187 |
-
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
| 188 |
-
|
| 189 |
-
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
|
| 190 |
-
In OpenCV, it implements a JPEG conversion. See more details in
|
| 191 |
-
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
| 192 |
-
|
| 193 |
-
Args:
|
| 194 |
-
img (ndarray): The input image. It accepts:
|
| 195 |
-
1. np.uint8 type with range [0, 255];
|
| 196 |
-
2. np.float32 type with range [0, 1].
|
| 197 |
-
y_only (bool): Whether to only return Y channel. Default: False.
|
| 198 |
-
|
| 199 |
-
Returns:
|
| 200 |
-
ndarray: The converted YCbCr image. The output image has the same type
|
| 201 |
-
and range as input image.
|
| 202 |
-
"""
|
| 203 |
-
img_type = img.dtype
|
| 204 |
-
img = _convert_input_type_range(img)
|
| 205 |
-
if y_only:
|
| 206 |
-
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
|
| 207 |
-
else:
|
| 208 |
-
out_img = np.matmul(
|
| 209 |
-
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
|
| 210 |
-
out_img = _convert_output_type_range(out_img, img_type)
|
| 211 |
-
return out_img
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
def bgr2ycbcr(img, y_only=False):
|
| 215 |
-
"""Convert a BGR image to YCbCr image.
|
| 216 |
-
|
| 217 |
-
The bgr version of rgb2ycbcr.
|
| 218 |
-
It implements the ITU-R BT.601 conversion for standard-definition
|
| 219 |
-
television. See more details in
|
| 220 |
-
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
| 221 |
-
|
| 222 |
-
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
|
| 223 |
-
In OpenCV, it implements a JPEG conversion. See more details in
|
| 224 |
-
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
img (ndarray): The input image. It accepts:
|
| 228 |
-
1. np.uint8 type with range [0, 255];
|
| 229 |
-
2. np.float32 type with range [0, 1].
|
| 230 |
-
y_only (bool): Whether to only return Y channel. Default: False.
|
| 231 |
-
|
| 232 |
-
Returns:
|
| 233 |
-
ndarray: The converted YCbCr image. The output image has the same type
|
| 234 |
-
and range as input image.
|
| 235 |
-
"""
|
| 236 |
-
img_type = img.dtype
|
| 237 |
-
img = _convert_input_type_range(img)
|
| 238 |
-
if y_only:
|
| 239 |
-
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
|
| 240 |
-
else:
|
| 241 |
-
out_img = np.matmul(
|
| 242 |
-
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
|
| 243 |
-
out_img = _convert_output_type_range(out_img, img_type)
|
| 244 |
-
return out_img
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def ycbcr2rgb(img):
|
| 248 |
-
"""Convert a YCbCr image to RGB image.
|
| 249 |
-
|
| 250 |
-
This function produces the same results as Matlab's ycbcr2rgb function.
|
| 251 |
-
It implements the ITU-R BT.601 conversion for standard-definition
|
| 252 |
-
television. See more details in
|
| 253 |
-
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
| 254 |
-
|
| 255 |
-
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
|
| 256 |
-
In OpenCV, it implements a JPEG conversion. See more details in
|
| 257 |
-
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
| 258 |
-
|
| 259 |
-
Args:
|
| 260 |
-
img (ndarray): The input image. It accepts:
|
| 261 |
-
1. np.uint8 type with range [0, 255];
|
| 262 |
-
2. np.float32 type with range [0, 1].
|
| 263 |
-
|
| 264 |
-
Returns:
|
| 265 |
-
ndarray: The converted RGB image. The output image has the same type
|
| 266 |
-
and range as input image.
|
| 267 |
-
"""
|
| 268 |
-
img_type = img.dtype
|
| 269 |
-
img = _convert_input_type_range(img) * 255
|
| 270 |
-
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
| 271 |
-
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
|
| 272 |
-
out_img = _convert_output_type_range(out_img, img_type)
|
| 273 |
-
return out_img
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
def ycbcr2bgr(img):
|
| 277 |
-
"""Convert a YCbCr image to BGR image.
|
| 278 |
-
|
| 279 |
-
The bgr version of ycbcr2rgb.
|
| 280 |
-
It implements the ITU-R BT.601 conversion for standard-definition
|
| 281 |
-
television. See more details in
|
| 282 |
-
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
| 283 |
-
|
| 284 |
-
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
|
| 285 |
-
In OpenCV, it implements a JPEG conversion. See more details in
|
| 286 |
-
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
| 287 |
-
|
| 288 |
-
Args:
|
| 289 |
-
img (ndarray): The input image. It accepts:
|
| 290 |
-
1. np.uint8 type with range [0, 255];
|
| 291 |
-
2. np.float32 type with range [0, 1].
|
| 292 |
-
|
| 293 |
-
Returns:
|
| 294 |
-
ndarray: The converted BGR image. The output image has the same type
|
| 295 |
-
and range as input image.
|
| 296 |
-
"""
|
| 297 |
-
img_type = img.dtype
|
| 298 |
-
img = _convert_input_type_range(img) * 255
|
| 299 |
-
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
|
| 300 |
-
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
|
| 301 |
-
out_img = _convert_output_type_range(out_img, img_type)
|
| 302 |
-
return out_img
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
def _convert_input_type_range(img):
|
| 306 |
-
"""Convert the type and range of the input image.
|
| 307 |
-
|
| 308 |
-
It converts the input image to np.float32 type and range of [0, 1].
|
| 309 |
-
It is mainly used for pre-processing the input image in colorspace
|
| 310 |
-
conversion functions such as rgb2ycbcr and ycbcr2rgb.
|
| 311 |
-
|
| 312 |
-
Args:
|
| 313 |
-
img (ndarray): The input image. It accepts:
|
| 314 |
-
1. np.uint8 type with range [0, 255];
|
| 315 |
-
2. np.float32 type with range [0, 1].
|
| 316 |
-
|
| 317 |
-
Returns:
|
| 318 |
-
(ndarray): The converted image with type of np.float32 and range of
|
| 319 |
-
[0, 1].
|
| 320 |
-
"""
|
| 321 |
-
img_type = img.dtype
|
| 322 |
-
img = img.astype(np.float32)
|
| 323 |
-
if img_type == np.float32:
|
| 324 |
-
pass
|
| 325 |
-
elif img_type == np.uint8:
|
| 326 |
-
img /= 255.
|
| 327 |
-
else:
|
| 328 |
-
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
|
| 329 |
-
return img
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
def _convert_output_type_range(img, dst_type):
|
| 333 |
-
"""Convert the type and range of the image according to dst_type.
|
| 334 |
-
|
| 335 |
-
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
| 336 |
-
images will be converted to np.uint8 type with range [0, 255]. If
|
| 337 |
-
`dst_type` is np.float32, it converts the image to np.float32 type with
|
| 338 |
-
range [0, 1].
|
| 339 |
-
It is mainly used for post-processing images in colorspace conversion
|
| 340 |
-
functions such as rgb2ycbcr and ycbcr2rgb.
|
| 341 |
-
|
| 342 |
-
Args:
|
| 343 |
-
img (ndarray): The image to be converted with np.float32 type and
|
| 344 |
-
range [0, 255].
|
| 345 |
-
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
| 346 |
-
converts the image to np.uint8 type with range [0, 255]. If
|
| 347 |
-
dst_type is np.float32, it converts the image to np.float32 type
|
| 348 |
-
with range [0, 1].
|
| 349 |
-
|
| 350 |
-
Returns:
|
| 351 |
-
(ndarray): The converted image with desired type and range.
|
| 352 |
-
"""
|
| 353 |
-
if dst_type not in (np.uint8, np.float32):
|
| 354 |
-
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
|
| 355 |
-
if dst_type == np.uint8:
|
| 356 |
-
img = img.round()
|
| 357 |
-
else:
|
| 358 |
-
img /= 255.
|
| 359 |
-
return img.astype(dst_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/misc.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
import time
|
| 5 |
-
import torch
|
| 6 |
-
from os import path as osp
|
| 7 |
-
|
| 8 |
-
from .dist_util import master_only
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def set_random_seed(seed):
|
| 12 |
-
"""Set random seeds."""
|
| 13 |
-
random.seed(seed)
|
| 14 |
-
np.random.seed(seed)
|
| 15 |
-
torch.manual_seed(seed)
|
| 16 |
-
torch.cuda.manual_seed(seed)
|
| 17 |
-
torch.cuda.manual_seed_all(seed)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def get_time_str():
|
| 21 |
-
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def mkdir_and_rename(path):
|
| 25 |
-
"""mkdirs. If path exists, rename it with timestamp and create a new one.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
path (str): Folder path.
|
| 29 |
-
"""
|
| 30 |
-
if osp.exists(path):
|
| 31 |
-
new_name = path + '_archived_' + get_time_str()
|
| 32 |
-
print(f'Path already exists. Rename it to {new_name}', flush=True)
|
| 33 |
-
os.rename(path, new_name)
|
| 34 |
-
os.makedirs(path, exist_ok=True)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
@master_only
|
| 38 |
-
def make_exp_dirs(opt):
|
| 39 |
-
"""Make dirs for experiments."""
|
| 40 |
-
path_opt = opt['path'].copy()
|
| 41 |
-
if opt['is_train']:
|
| 42 |
-
mkdir_and_rename(path_opt.pop('experiments_root'))
|
| 43 |
-
else:
|
| 44 |
-
mkdir_and_rename(path_opt.pop('results_root'))
|
| 45 |
-
for key, path in path_opt.items():
|
| 46 |
-
if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
|
| 47 |
-
continue
|
| 48 |
-
else:
|
| 49 |
-
os.makedirs(path, exist_ok=True)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
| 53 |
-
"""Scan a directory to find the interested files.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
dir_path (str): Path of the directory.
|
| 57 |
-
suffix (str | tuple(str), optional): File suffix that we are
|
| 58 |
-
interested in. Default: None.
|
| 59 |
-
recursive (bool, optional): If set to True, recursively scan the
|
| 60 |
-
directory. Default: False.
|
| 61 |
-
full_path (bool, optional): If set to True, include the dir_path.
|
| 62 |
-
Default: False.
|
| 63 |
-
|
| 64 |
-
Returns:
|
| 65 |
-
A generator for all the interested files with relative paths.
|
| 66 |
-
"""
|
| 67 |
-
|
| 68 |
-
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
| 69 |
-
raise TypeError('"suffix" must be a string or tuple of strings')
|
| 70 |
-
|
| 71 |
-
root = dir_path
|
| 72 |
-
|
| 73 |
-
def _scandir(dir_path, suffix, recursive):
|
| 74 |
-
for entry in os.scandir(dir_path):
|
| 75 |
-
if not entry.name.startswith('.') and entry.is_file():
|
| 76 |
-
if full_path:
|
| 77 |
-
return_path = entry.path
|
| 78 |
-
else:
|
| 79 |
-
return_path = osp.relpath(entry.path, root)
|
| 80 |
-
|
| 81 |
-
if suffix is None:
|
| 82 |
-
yield return_path
|
| 83 |
-
elif return_path.endswith(suffix):
|
| 84 |
-
yield return_path
|
| 85 |
-
else:
|
| 86 |
-
if recursive:
|
| 87 |
-
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
| 88 |
-
else:
|
| 89 |
-
continue
|
| 90 |
-
|
| 91 |
-
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def check_resume(opt, resume_iter):
|
| 95 |
-
"""Check resume states and pretrain_network paths.
|
| 96 |
-
|
| 97 |
-
Args:
|
| 98 |
-
opt (dict): Options.
|
| 99 |
-
resume_iter (int): Resume iteration.
|
| 100 |
-
"""
|
| 101 |
-
if opt['path']['resume_state']:
|
| 102 |
-
# get all the networks
|
| 103 |
-
networks = [key for key in opt.keys() if key.startswith('network_')]
|
| 104 |
-
flag_pretrain = False
|
| 105 |
-
for network in networks:
|
| 106 |
-
if opt['path'].get(f'pretrain_{network}') is not None:
|
| 107 |
-
flag_pretrain = True
|
| 108 |
-
if flag_pretrain:
|
| 109 |
-
print('pretrain_network path will be ignored during resuming.')
|
| 110 |
-
# set pretrained model paths
|
| 111 |
-
for network in networks:
|
| 112 |
-
name = f'pretrain_{network}'
|
| 113 |
-
basename = network.replace('network_', '')
|
| 114 |
-
if opt['path'].get('ignore_resume_networks') is None or (network
|
| 115 |
-
not in opt['path']['ignore_resume_networks']):
|
| 116 |
-
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
|
| 117 |
-
print(f"Set {name} to {opt['path'][name]}")
|
| 118 |
-
|
| 119 |
-
# change param_key to params in resume
|
| 120 |
-
param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
|
| 121 |
-
for param_key in param_keys:
|
| 122 |
-
if opt['path'][param_key] == 'params_ema':
|
| 123 |
-
opt['path'][param_key] = 'params'
|
| 124 |
-
print(f'Set {param_key} to params')
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def sizeof_fmt(size, suffix='B'):
|
| 128 |
-
"""Get human readable file size.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
size (int): File size.
|
| 132 |
-
suffix (str): Suffix. Default: 'B'.
|
| 133 |
-
|
| 134 |
-
Return:
|
| 135 |
-
str: Formatted file siz.
|
| 136 |
-
"""
|
| 137 |
-
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
| 138 |
-
if abs(size) < 1024.0:
|
| 139 |
-
return f'{size:3.1f} {unit}{suffix}'
|
| 140 |
-
size /= 1024.0
|
| 141 |
-
return f'{size:3.1f} Y{suffix}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/options.py
DELETED
|
@@ -1,194 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import random
|
| 3 |
-
import torch
|
| 4 |
-
import yaml
|
| 5 |
-
from collections import OrderedDict
|
| 6 |
-
from os import path as osp
|
| 7 |
-
|
| 8 |
-
from basicsr.utils import set_random_seed
|
| 9 |
-
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def ordered_yaml():
|
| 13 |
-
"""Support OrderedDict for yaml.
|
| 14 |
-
|
| 15 |
-
Returns:
|
| 16 |
-
yaml Loader and Dumper.
|
| 17 |
-
"""
|
| 18 |
-
try:
|
| 19 |
-
from yaml import CDumper as Dumper
|
| 20 |
-
from yaml import CLoader as Loader
|
| 21 |
-
except ImportError:
|
| 22 |
-
from yaml import Dumper, Loader
|
| 23 |
-
|
| 24 |
-
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
| 25 |
-
|
| 26 |
-
def dict_representer(dumper, data):
|
| 27 |
-
return dumper.represent_dict(data.items())
|
| 28 |
-
|
| 29 |
-
def dict_constructor(loader, node):
|
| 30 |
-
return OrderedDict(loader.construct_pairs(node))
|
| 31 |
-
|
| 32 |
-
Dumper.add_representer(OrderedDict, dict_representer)
|
| 33 |
-
Loader.add_constructor(_mapping_tag, dict_constructor)
|
| 34 |
-
return Loader, Dumper
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def dict2str(opt, indent_level=1):
|
| 38 |
-
"""dict to string for printing options.
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
opt (dict): Option dict.
|
| 42 |
-
indent_level (int): Indent level. Default: 1.
|
| 43 |
-
|
| 44 |
-
Return:
|
| 45 |
-
(str): Option string for printing.
|
| 46 |
-
"""
|
| 47 |
-
msg = '\n'
|
| 48 |
-
for k, v in opt.items():
|
| 49 |
-
if isinstance(v, dict):
|
| 50 |
-
msg += ' ' * (indent_level * 2) + k + ':['
|
| 51 |
-
msg += dict2str(v, indent_level + 1)
|
| 52 |
-
msg += ' ' * (indent_level * 2) + ']\n'
|
| 53 |
-
else:
|
| 54 |
-
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
|
| 55 |
-
return msg
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def _postprocess_yml_value(value):
|
| 59 |
-
# None
|
| 60 |
-
if value == '~' or value.lower() == 'none':
|
| 61 |
-
return None
|
| 62 |
-
# bool
|
| 63 |
-
if value.lower() == 'true':
|
| 64 |
-
return True
|
| 65 |
-
elif value.lower() == 'false':
|
| 66 |
-
return False
|
| 67 |
-
# !!float number
|
| 68 |
-
if value.startswith('!!float'):
|
| 69 |
-
return float(value.replace('!!float', ''))
|
| 70 |
-
# number
|
| 71 |
-
if value.isdigit():
|
| 72 |
-
return int(value)
|
| 73 |
-
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
|
| 74 |
-
return float(value)
|
| 75 |
-
# list
|
| 76 |
-
if value.startswith('['):
|
| 77 |
-
return eval(value)
|
| 78 |
-
# str
|
| 79 |
-
return value
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def parse_options(root_path, is_train=True):
|
| 83 |
-
parser = argparse.ArgumentParser()
|
| 84 |
-
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
|
| 85 |
-
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
|
| 86 |
-
parser.add_argument('--auto_resume', action='store_true')
|
| 87 |
-
parser.add_argument('--debug', action='store_true')
|
| 88 |
-
parser.add_argument('--local_rank', type=int, default=0)
|
| 89 |
-
parser.add_argument(
|
| 90 |
-
'--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
|
| 91 |
-
args = parser.parse_args()
|
| 92 |
-
|
| 93 |
-
# parse yml to dict
|
| 94 |
-
with open(args.opt, mode='r') as f:
|
| 95 |
-
opt = yaml.load(f, Loader=ordered_yaml()[0])
|
| 96 |
-
|
| 97 |
-
# distributed settings
|
| 98 |
-
if args.launcher == 'none':
|
| 99 |
-
opt['dist'] = False
|
| 100 |
-
print('Disable distributed.', flush=True)
|
| 101 |
-
else:
|
| 102 |
-
opt['dist'] = True
|
| 103 |
-
if args.launcher == 'slurm' and 'dist_params' in opt:
|
| 104 |
-
init_dist(args.launcher, **opt['dist_params'])
|
| 105 |
-
else:
|
| 106 |
-
init_dist(args.launcher)
|
| 107 |
-
opt['rank'], opt['world_size'] = get_dist_info()
|
| 108 |
-
|
| 109 |
-
# random seed
|
| 110 |
-
seed = opt.get('manual_seed')
|
| 111 |
-
if seed is None:
|
| 112 |
-
seed = random.randint(1, 10000)
|
| 113 |
-
opt['manual_seed'] = seed
|
| 114 |
-
set_random_seed(seed + opt['rank'])
|
| 115 |
-
|
| 116 |
-
# force to update yml options
|
| 117 |
-
if args.force_yml is not None:
|
| 118 |
-
for entry in args.force_yml:
|
| 119 |
-
# now do not support creating new keys
|
| 120 |
-
keys, value = entry.split('=')
|
| 121 |
-
keys, value = keys.strip(), value.strip()
|
| 122 |
-
value = _postprocess_yml_value(value)
|
| 123 |
-
eval_str = 'opt'
|
| 124 |
-
for key in keys.split(':'):
|
| 125 |
-
eval_str += f'["{key}"]'
|
| 126 |
-
eval_str += '=value'
|
| 127 |
-
# using exec function
|
| 128 |
-
exec(eval_str)
|
| 129 |
-
|
| 130 |
-
opt['auto_resume'] = args.auto_resume
|
| 131 |
-
opt['is_train'] = is_train
|
| 132 |
-
|
| 133 |
-
# debug setting
|
| 134 |
-
if args.debug and not opt['name'].startswith('debug'):
|
| 135 |
-
opt['name'] = 'debug_' + opt['name']
|
| 136 |
-
|
| 137 |
-
if opt['num_gpu'] == 'auto':
|
| 138 |
-
opt['num_gpu'] = torch.cuda.device_count()
|
| 139 |
-
|
| 140 |
-
# datasets
|
| 141 |
-
for phase, dataset in opt['datasets'].items():
|
| 142 |
-
# for multiple datasets, e.g., val_1, val_2; test_1, test_2
|
| 143 |
-
phase = phase.split('_')[0]
|
| 144 |
-
dataset['phase'] = phase
|
| 145 |
-
if 'scale' in opt:
|
| 146 |
-
dataset['scale'] = opt['scale']
|
| 147 |
-
if dataset.get('dataroot_gt') is not None:
|
| 148 |
-
dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
|
| 149 |
-
if dataset.get('dataroot_lq') is not None:
|
| 150 |
-
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
|
| 151 |
-
|
| 152 |
-
# paths
|
| 153 |
-
for key, val in opt['path'].items():
|
| 154 |
-
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
|
| 155 |
-
opt['path'][key] = osp.expanduser(val)
|
| 156 |
-
|
| 157 |
-
if is_train:
|
| 158 |
-
experiments_root = osp.join(root_path, 'experiments', opt['name'])
|
| 159 |
-
opt['path']['experiments_root'] = experiments_root
|
| 160 |
-
opt['path']['models'] = osp.join(experiments_root, 'models')
|
| 161 |
-
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
|
| 162 |
-
opt['path']['log'] = experiments_root
|
| 163 |
-
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
|
| 164 |
-
|
| 165 |
-
# change some options for debug mode
|
| 166 |
-
if 'debug' in opt['name']:
|
| 167 |
-
if 'val' in opt:
|
| 168 |
-
opt['val']['val_freq'] = 8
|
| 169 |
-
opt['logger']['print_freq'] = 1
|
| 170 |
-
opt['logger']['save_checkpoint_freq'] = 8
|
| 171 |
-
else: # test
|
| 172 |
-
results_root = osp.join(root_path, 'results', opt['name'])
|
| 173 |
-
opt['path']['results_root'] = results_root
|
| 174 |
-
opt['path']['log'] = results_root
|
| 175 |
-
opt['path']['visualization'] = osp.join(results_root, 'visualization')
|
| 176 |
-
|
| 177 |
-
return opt, args
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
@master_only
|
| 181 |
-
def copy_opt_file(opt_file, experiments_root):
|
| 182 |
-
# copy the yml file to the experiment root
|
| 183 |
-
import sys
|
| 184 |
-
import time
|
| 185 |
-
from shutil import copyfile
|
| 186 |
-
cmd = ' '.join(sys.argv)
|
| 187 |
-
filename = osp.join(experiments_root, osp.basename(opt_file))
|
| 188 |
-
copyfile(opt_file, filename)
|
| 189 |
-
|
| 190 |
-
with open(filename, 'r+') as f:
|
| 191 |
-
lines = f.readlines()
|
| 192 |
-
lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
|
| 193 |
-
f.seek(0)
|
| 194 |
-
f.writelines(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/registry.py
DELETED
|
@@ -1,82 +0,0 @@
|
|
| 1 |
-
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class Registry():
|
| 5 |
-
"""
|
| 6 |
-
The registry that provides name -> object mapping, to support third-party
|
| 7 |
-
users' custom modules.
|
| 8 |
-
|
| 9 |
-
To create a registry (e.g. a backbone registry):
|
| 10 |
-
|
| 11 |
-
.. code-block:: python
|
| 12 |
-
|
| 13 |
-
BACKBONE_REGISTRY = Registry('BACKBONE')
|
| 14 |
-
|
| 15 |
-
To register an object:
|
| 16 |
-
|
| 17 |
-
.. code-block:: python
|
| 18 |
-
|
| 19 |
-
@BACKBONE_REGISTRY.register()
|
| 20 |
-
class MyBackbone():
|
| 21 |
-
...
|
| 22 |
-
|
| 23 |
-
Or:
|
| 24 |
-
|
| 25 |
-
.. code-block:: python
|
| 26 |
-
|
| 27 |
-
BACKBONE_REGISTRY.register(MyBackbone)
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
def __init__(self, name):
|
| 31 |
-
"""
|
| 32 |
-
Args:
|
| 33 |
-
name (str): the name of this registry
|
| 34 |
-
"""
|
| 35 |
-
self._name = name
|
| 36 |
-
self._obj_map = {}
|
| 37 |
-
|
| 38 |
-
def _do_register(self, name, obj):
|
| 39 |
-
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
|
| 40 |
-
f"in '{self._name}' registry!")
|
| 41 |
-
self._obj_map[name] = obj
|
| 42 |
-
|
| 43 |
-
def register(self, obj=None):
|
| 44 |
-
"""
|
| 45 |
-
Register the given object under the the name `obj.__name__`.
|
| 46 |
-
Can be used as either a decorator or not.
|
| 47 |
-
See docstring of this class for usage.
|
| 48 |
-
"""
|
| 49 |
-
if obj is None:
|
| 50 |
-
# used as a decorator
|
| 51 |
-
def deco(func_or_class):
|
| 52 |
-
name = func_or_class.__name__
|
| 53 |
-
self._do_register(name, func_or_class)
|
| 54 |
-
return func_or_class
|
| 55 |
-
|
| 56 |
-
return deco
|
| 57 |
-
|
| 58 |
-
# used as a function call
|
| 59 |
-
name = obj.__name__
|
| 60 |
-
self._do_register(name, obj)
|
| 61 |
-
|
| 62 |
-
def get(self, name):
|
| 63 |
-
ret = self._obj_map.get(name)
|
| 64 |
-
if ret is None:
|
| 65 |
-
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
|
| 66 |
-
return ret
|
| 67 |
-
|
| 68 |
-
def __contains__(self, name):
|
| 69 |
-
return name in self._obj_map
|
| 70 |
-
|
| 71 |
-
def __iter__(self):
|
| 72 |
-
return iter(self._obj_map.items())
|
| 73 |
-
|
| 74 |
-
def keys(self):
|
| 75 |
-
return self._obj_map.keys()
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
DATASET_REGISTRY = Registry('dataset')
|
| 79 |
-
ARCH_REGISTRY = Registry('arch')
|
| 80 |
-
MODEL_REGISTRY = Registry('model')
|
| 81 |
-
LOSS_REGISTRY = Registry('loss')
|
| 82 |
-
METRIC_REGISTRY = Registry('metric')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|