| | import os |
| | import re |
| | import random |
| | import time |
| | import torch |
| | import torch.nn as nn |
| | import logging |
| | import numpy as np |
| | from os import path as osp |
| |
|
def constant_init(module, val, bias=0):
    """Fill a module's weight and bias with constant values.

    Args:
        module: Object (typically an ``nn.Module``) that may expose
            ``weight`` and/or ``bias`` tensors.
        val: Constant written into ``module.weight``.
        bias: Constant written into ``module.bias``. Default: 0.

    Attributes that are missing or set to ``None`` are left untouched.
    """
    weight_tensor = getattr(module, 'weight', None)
    if weight_tensor is not None:
        nn.init.constant_(weight_tensor, val)

    bias_tensor = getattr(module, 'bias', None)
    if bias_tensor is not None:
        nn.init.constant_(bias_tensor, bias)
| |
|
# Names of loggers that get_root_logger() has already configured; used to
# avoid attaching duplicate handlers on repeated calls.
initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the named logger, configuring it on first use.

    The first call for a given ``logger_name`` attaches a ``StreamHandler``,
    disables propagation to ancestor loggers and sets the logger level to
    ``log_level``. If ``log_file`` is given, a ``FileHandler`` opened in
    append mode is attached as well. Later calls with the same name return
    the already-configured logger unchanged (their ``log_level`` and
    ``log_file`` arguments are ignored).

    Args:
        logger_name (str): Name of the logger. Default: 'basicsr'.
        log_level (int): Level applied to the logger and to the file
            handler. Default: ``logging.INFO``.
        log_file (str | None): Log filename. If specified, a FileHandler
            will be added to the logger. Default: None.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(logger_name)

    # Already configured by an earlier call: do not stack more handlers.
    if logger_name in initialized_logger:
        return logger

    formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.propagate = False

    # Apply the level whether or not a log file is requested. Otherwise the
    # logger stays at NOTSET and inherits the root logger's level (WARNING by
    # default), silently dropping INFO records on the stream handler.
    logger.setLevel(log_level)

    if log_file is not None:
        # Append mode keeps the contents of a previous log.
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
| |
|
| |
|
def _torch_version_at_least(version, minimum=(1, 12, 0)):
    """Return True if ``version`` is at least ``minimum``.

    Only the leading ``major.minor[.patch]`` digits are considered; any
    trailing pre-release, dev or local-build suffix (e.g. ``a0``,
    ``.dev20230101``, ``+cu118``, ``+gitabc``) is ignored. A missing patch
    component counts as 0. Strings that do not start with ``major.minor``
    yield False instead of raising.

    Args:
        version (str): Version string, e.g. ``torch.__version__``.
        minimum (tuple[int, int, int]): Lowest accepted version.

    Returns:
        bool: Whether the parsed version is >= ``minimum``.
    """
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        return False
    parsed = tuple(int(part or 0) for part in match.groups())
    return parsed >= tuple(minimum)


# True when the installed torch is >= 1.12.0. The functions in this module
# only query ``torch.backends.mps`` when this flag is set.
IS_HIGH_VERSION = _torch_version_at_least(torch.__version__)
| |
|
def gpu_is_available():
    """Report whether an accelerator backend can be used.

    Returns:
        bool: True if ``IS_HIGH_VERSION`` is set and the MPS backend is
        available, or if both CUDA and cuDNN are available; False otherwise.
    """
    # torch.backends.mps is only queried when IS_HIGH_VERSION is set.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True

    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ready)
| |
|
def get_device(gpu_id=None):
    """Pick the torch device to run on.

    Args:
        gpu_id (int | None): Optional device index appended to the device
            name as ``:<gpu_id>``. Default: None (no index).

    Returns:
        torch.device: ``mps`` if ``IS_HIGH_VERSION`` is set and the MPS
        backend is available; otherwise ``cuda`` if both CUDA and cuDNN are
        available; otherwise ``cpu`` (the index is not applied to ``cpu``).

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    # torch.backends.mps is only queried when IS_HIGH_VERSION is set.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + index_suffix)

    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + index_suffix)
    return torch.device('cpu')
| |
|
| |
|
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): Seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
| |
|
| |
|
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
| |
|
| |
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Regular files whose name starts with ``.`` are skipped. With
    ``recursive=True`` every sub-directory (including ones whose name starts
    with ``.``) is descended into; entries that are neither visible regular
    files nor directories are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None (all files).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, yield paths that include
            ``dir_path``; otherwise yield paths relative to ``dir_path``.
            Default: False.

    Returns:
        A generator yielding the matching file paths.

    Raises:
        TypeError: If ``suffix`` is not None, a string or a tuple.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(current_dir):
        # The context manager closes the directory iterator deterministically.
        with os.scandir(current_dir) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into real directories: hidden files and
                    # broken symlinks also land in this branch, and calling
                    # os.scandir() on them would raise.
                    yield from _scandir(entry.path)

    return _scandir(dir_path)