| | import os |
| | import re |
| | import random |
| | import time |
| | import torch |
| | import numpy as np |
| | from os import path as osp |
| |
|
| | from .dist_util import master_only |
| | from .logger import get_root_logger |
| |
|
def torch_version_at_least(version, minimum=(1, 12, 0)):
    """Return True if a PyTorch version string is >= ``minimum``.

    Only the leading ``major.minor[.patch]`` numbers are compared. Local or
    pre-release tags such as ``+cu121``, ``+rocm5.4.2``, ``a0+gitabc`` or
    ``.dev20220101`` are ignored. A missing patch number counts as 0.

    Args:
        version (str): Version string, e.g. ``torch.__version__``.
        minimum (tuple[int, ...]): Minimum (major, minor, patch).
            Default: (1, 12, 0).

    Returns:
        bool: Whether ``version`` >= ``minimum``; False if ``version`` does
        not start with ``<int>.<int>``.
    """
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        return False
    parsed = tuple(int(part) if part is not None else 0 for part in match.groups())
    return parsed >= tuple(minimum)


# True on torch >= 1.12.0; used to gate the ``torch.backends.mps`` probes,
# which older releases lack. Parsing is prefix-based, so builds with local
# version suffixes (e.g. ROCm ``2.0.1+rocm5.4.2``) do not break the import.
IS_HIGH_VERSION = torch_version_at_least(torch.__version__)
| |
|
| |
|
def gpu_is_available():
    """Report whether an accelerator backend can be used.

    Returns:
        bool: True when Apple MPS is available (probed only when
        ``IS_HIGH_VERSION`` is true), or when both CUDA and cuDNN are
        available; False otherwise.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ready)
| |
|
| |
|
def get_device(gpu_id=None):
    """Select the torch device to run on.

    Args:
        gpu_id (int | None): Optional device index, appended as ``:<id>`` to
            an accelerator device name. Default: None (no index).

    Returns:
        torch.device: ``mps[:id]`` when MPS is available (probed only when
        ``IS_HIGH_VERSION`` is true); otherwise ``cuda[:id]`` when CUDA and
        cuDNN are both available; otherwise ``cpu`` (the index is dropped).

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + index_suffix)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + index_suffix)
    return torch.device('cpu')
| |
|
| |
|
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + CUDA).

    Args:
        seed (int): Seed value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
| |
|
| |
|
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
| |
|
| |
|
def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    An existing ``path`` is renamed to ``<path>_archived_<YYYYmmdd_HHMMSS>``
    before an empty directory is (re)created at ``path``. If that archive
    name is already taken (e.g. two calls within the same second), ``_1``,
    ``_2``, ... is appended so the rename never collides.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        # Normalize first: for 'exp/' the naive name would be
        # 'exp/_archived_...', i.e. a target inside the directory being moved,
        # which os.rename rejects.
        src = osp.normpath(path)
        base_name = src + '_archived_' + get_time_str()
        new_name = base_name
        counter = 1
        while osp.exists(new_name):
            new_name = f'{base_name}_{counter}'
            counter += 1
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(src, new_name)
    os.makedirs(path, exist_ok=True)
| |
|
| |
|
@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments.

    The root folder is created via ``mkdir_and_rename``, which archives an
    existing one first. The root is ``experiments_root`` when
    ``opt['is_train']`` is truthy and ``results_root`` otherwise. Every other
    entry of ``opt['path']`` is created with ``os.makedirs``. Entries whose
    key contains ``strict_load``, ``pretrain_network``, ``resume`` or
    ``param_key`` are skipped, since those options are not directories.
    ``opt['path']`` itself is not modified (a shallow copy is popped from).

    Args:
        opt (dict): Options; reads ``opt['path']`` and ``opt['is_train']``.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    skip_markers = ('strict_load', 'pretrain_network', 'resume', 'param_key')
    for key, dir_path in remaining.items():
        if any(marker in key for marker in skip_markers):
            continue
        os.makedirs(dir_path, exist_ok=True)
| |
|
| |
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with ``.``) are skipped. With
    ``recursive=True``, sub-directories, including hidden ones, are descended
    into. Entries that are neither regular files nor directories (e.g. broken
    symlinks) are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with paths relative to
        ``dir_path`` (or full paths when ``full_path`` is True).

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. This is
            raised eagerly, before the generator is iterated.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # Context manager closes the OS directory handle even when the
        # consumer stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        continue
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into directories; recursing into hidden
                    # files used to crash os.scandir with NotADirectoryError.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
| |
|
| |
|
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When ``opt['path']['resume_state']`` is truthy, each ``network_*`` key of
    ``opt`` has its ``opt['path']['pretrain_network_*']`` entry overwritten
    with ``<opt['path']['models']>/net_<name>_<resume_iter>.pth``. Here
    ``<name>`` is the key with ``network_`` removed. Networks whose ``<name>``
    appears in ``opt['path']['ignore_resume_networks']`` are left untouched.
    A warning is logged if any such pretrain path was already set. ``opt`` is
    modified in place.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    networks = [key for key in opt if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{net}') is not None for net in networks):
        logger.warning(
            'pretrain_network path will be ignored during resuming.')

    ignored = path_opt.get('ignore_resume_networks')
    for net in networks:
        basename = net.replace('network_', '')
        if ignored is not None and basename in ignored:
            continue
        name = f'pretrain_{net}'
        path_opt[name] = osp.join(path_opt['models'], f'net_{basename}_{resume_iter}.pth')
        logger.info(f"Set {name} to {opt['path'][name]}")
| |
|
| |
|
def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        size (int | float): File size.
        suffix (str): Text appended after the unit prefix. Default: 'B'.

    Return:
        str: Formatted file size, e.g. ``'1.5 KB'``; values beyond the
        ``Z`` range are expressed in ``Y``.
    """
    for prefix in ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z'):
        if abs(size) < 1024.0:
            break
        size /= 1024.0
    else:
        prefix = 'Y'
    return f'{size:3.1f} {prefix}{suffix}'
| |
|