| | import os
|
| | import re
|
| | import random
|
| | import time
|
| | import torch
|
| | import torch.nn as nn
|
| | import logging
|
| | import numpy as np
|
| | from os import path as osp
|
| |
|
def constant_init(module, val, bias=0):
    """Fill a module's parameters with constants, in place.

    ``module.weight`` is filled with ``val`` and ``module.bias`` with ``bias``.
    Each attribute is skipped when the module lacks it or it is ``None``.

    Args:
        module: Object (typically an ``nn.Module``) whose ``weight`` / ``bias``
            tensors are filled via ``nn.init.constant_``.
        val (float): Value written into ``weight``.
        bias (float): Value written into ``bias``. Default: 0.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)
|
| |
|
# Names of loggers already configured by get_root_logger (value is always True).
initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get a named logger, configuring it on the first call for that name.

    On the first call for ``logger_name`` the logger:

    * gets a StreamHandler using the format ``'%(asctime)s %(levelname)s: %(message)s'``,
    * stops propagating records to ancestor loggers,
    * has its level set to ``log_level``,
    * and, if ``log_file`` is given, gets a FileHandler (append mode) at
      ``log_level`` with the same format.

    Later calls with the same name return the already-configured logger
    unchanged; their ``log_level`` and ``log_file`` arguments are ignored.

    Args:
        logger_name (str): Logger name. Default: 'basicsr'.
        log_level (int): Level applied to the logger and, when present, to the
            file handler. Default: logging.INFO.
        log_file (str | None): If specified, records are also appended to
            this file. Default: None.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(logger_name)

    # Configure each named logger only once so repeated calls do not stack
    # duplicate handlers.
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    # Keep records away from the root logger's handlers (no duplicate output).
    logger.propagate = False
    # Set the level even without a log file. Otherwise the logger stays NOTSET
    # and falls back to its ancestors' effective level (WARNING for an
    # unconfigured root), silently dropping INFO records.
    logger.setLevel(log_level)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
|
| |
|
# Detect whether the installed torch is >= 1.12.0. IS_HIGH_VERSION gates the
# torch.backends.mps checks in this module.
match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
if match is None:
    # Unparsable version string: warn and conservatively treat it as old.
    logger = get_root_logger()
    logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
    IS_HIGH_VERSION = False
else:
    version_tuple = match.groups()
    # Compare numerically, component by component (so 1.9 < 1.12).
    IS_HIGH_VERSION = tuple(map(int, version_tuple)) >= (1, 12, 0)
|
| |
|
def gpu_is_available():
    """Report whether a GPU backend usable by torch is present.

    On torch >= 1.12.0 (``IS_HIGH_VERSION``), an available MPS backend counts.
    Otherwise both CUDA and cuDNN must be available.

    Returns:
        bool: True if MPS (when checked) or CUDA+cuDNN is available.
    """
    # Short-circuit keeps torch.backends.mps untouched on older torch.
    mps_ready = IS_HIGH_VERSION and torch.backends.mps.is_available()
    if mps_ready:
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())
|
| |
|
def get_device(gpu_id=None):
    """Pick the torch device to run on.

    Args:
        gpu_id (int | None): Optional device index appended to the GPU
            backend name (e.g. ``'cuda:1'``). It is not used when falling
            back to the CPU. Default: None.

    Returns:
        torch.device: ``mps`` when torch >= 1.12.0 and MPS is available;
        otherwise ``cuda`` when CUDA and cuDNN are available; otherwise ``cpu``.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + suffix)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + suffix)
    return torch.device('cpu')
|
| |
|
| |
|
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU and CUDA) RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
| |
|
| |
|
def get_time_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
|
| |
|
| |
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Files whose names start with '.' are skipped. When ``recursive`` is True,
    every sub-directory (hidden ones included) is descended into; other
    non-file entries are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix(es) to keep; files
            not ending with it are skipped. Default: None (keep all files).
        recursive (bool, optional): If True, recursively scan sub-directories.
            Default: False.
        full_path (bool, optional): If True, yield paths joined onto
            ``dir_path``; otherwise yield paths relative to ``dir_path``.
            Default: False.

    Returns:
        A generator over the matching file paths.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. This is
            raised when ``scandir`` is called, not when iterating.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the directory handle once the listing
        # is exhausted or the generator is closed.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    return_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into directories: passing hidden files,
                    # broken symlinks, etc. to os.scandir would raise.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)