| | import yaml |
| | import time |
| | import os |
| | from collections import OrderedDict |
| | from os import path as osp |
| | from basicsr.utils.misc import get_time_str |
| | import argparse |
| | import random |
| | import torch |
| | from collections import OrderedDict |
| |
|
| | from basicsr.utils import set_random_seed |
| | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only |
| |
|
| |
|
def ordered_yaml():
    """Build a yaml (Loader, Dumper) pair that preserves mapping order.

    libyaml-backed ``CLoader``/``CDumper`` are used when importable; the
    pure-Python ``Loader``/``Dumper`` are the fallback. Mapping nodes load as
    ``OrderedDict`` and ``OrderedDict`` values dump as ordinary mappings.

    Note: the constructor/representer are registered on the imported yaml
    classes themselves (not on copies), so the registration is shared by every
    user of those classes.

    Returns:
        tuple: ``(Loader, Dumper)`` classes.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def _represent_ordered(dumper, data):
        # Serialize the OrderedDict as a plain mapping, keeping key order.
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        # Build an OrderedDict from the (key, value) pairs in document order.
        return OrderedDict(loader.construct_pairs(node))

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
| |
|
| |
|
def yaml_load(f):
    """Load yaml from a file path, a yaml string, or an open stream.

    Mappings are loaded as ``OrderedDict`` via the loader from
    ``ordered_yaml()``.

    NOTE(review): that loader is yaml's full ``Loader``/``CLoader``, not
    ``SafeLoader``; only use this on trusted option files.

    Args:
        f (str | os.PathLike | file-like): Path to a yaml file, a python string
            containing yaml, or an object with a ``read`` method.

    Returns:
        The loaded document (a dict/OrderedDict for option files).
    """
    loader = ordered_yaml()[0]
    # An already-open stream is handed straight to yaml; os.path.isfile()
    # would raise TypeError on it.
    if hasattr(f, 'read'):
        return yaml.load(f, Loader=loader)
    if os.path.isfile(f):
        with open(f, 'r') as stream:
            return yaml.load(stream, Loader=loader)
    # Not an existing file: treat the argument as yaml text.
    return yaml.load(f, Loader=loader)
| |
|
| |
|
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Each entry is rendered on its own line, indented by two spaces per
    level. Nested dicts are wrapped as ``key:[`` ... ``]`` and rendered one
    level deeper. Keys and values are converted with ``str()``, so
    non-string keys (e.g. integer keys loaded from yaml) are supported.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level (two spaces per level). Default: 1.

    Return:
        (str): Option string for printing; always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    parts = ['\n']
    for k, v in opt.items():
        if isinstance(v, dict):
            parts.append(indent + str(k) + ':[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(indent + ']\n')
        else:
            parts.append(indent + str(k) + ': ' + str(v) + '\n')
    return ''.join(parts)
| |
|
| |
|
| | def _postprocess_yml_value(value): |
| | |
| | if value == '~' or value.lower() == 'none': |
| | return None |
| | |
| | if value.lower() == 'true': |
| | return True |
| | elif value.lower() == 'false': |
| | return False |
| | |
| | if value.startswith('!!float'): |
| | return float(value.replace('!!float', '')) |
| | |
| | if value.isdigit(): |
| | return int(value) |
| | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: |
| | return float(value) |
| | |
| | if value.startswith('['): |
| | return eval(value) |
| | |
| | return value |
| |
|
| |
|
def _apply_force_yml_entry(opt, entry):
    """Apply one ``--force_yml`` override of the form ``k1:k2:...=value`` in place.

    The key path is walked through the nested dicts of ``opt`` and the last
    key is assigned the value converted by ``_postprocess_yml_value``. The
    dicts are indexed directly, instead of building a python statement for
    ``exec``, so keys containing quotes work and no command-line text is
    executed. Only the first ``=`` separates key path from value, so values
    may themselves contain ``=``.

    Raises:
        ValueError: if ``entry`` contains no ``=``.
        KeyError: if an intermediate key of the path does not exist.
    """
    keys, sep, value = entry.partition('=')
    if not sep:
        raise ValueError(f'Invalid --force_yml entry (expected key=value): {entry!r}')
    key_path = keys.strip().split(':')
    target = opt
    for key in key_path[:-1]:
        target = target[key]
    target[key_path[-1]] = _postprocess_yml_value(value.strip())


def parse_options(root_path, is_train=True):
    """Parse command-line arguments and the option yaml file.

    Reads ``-opt`` (option yaml path) and the launcher/resume/debug/override
    flags from the command line, loads the yaml, sets up distributed info,
    seeds the RNGs, applies ``--force_yml`` overrides and fills in derived
    dataset and path entries.

    Args:
        root_path (str): Root directory used for the default ``experiments``
            (training) or ``results`` (testing) directory.
        is_train (bool): Whether the options are for training. Default: True.

    Returns:
        tuple: ``(opt, args)`` — the option dict and the parsed
        ``argparse.Namespace``.

    Raises:
        ValueError: if a ``--force_yml`` entry has no ``=``.
        KeyError: if a ``--force_yml`` key path or a required option key
            (e.g. ``name``, ``num_gpu``, ``datasets``, ``path``) is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to option YAML file.')
    parser.add_argument(
        '--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: draw one if unset, then offset by rank so each process
    # is seeded differently
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # command-line overrides of yml entries, e.g. train:ema_decay=0.999
    if args.force_yml is not None:
        for entry in args.force_yml:
            _apply_force_yml_entry(opt, entry)

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug mode is signalled through the experiment name (checked below)
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # keys such as 'train_1' / 'val_2' map to phase 'train' / 'val'
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = opt['path'].get('experiments_root')
        if experiments_root is None:
            experiments_root = osp.join(root_path, 'experiments')
        experiments_root = osp.join(experiments_root, opt['name'])

        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(
            experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(
            experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:
        results_root = opt['path'].get('results_root')
        if results_root is None:
            results_root = osp.join(root_path, 'results')
        results_root = osp.join(results_root, opt['name'])

        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
| |
|
| |
|
def parse(opt_path, root_path, is_train=True):
    """Load an option yaml file and fill in derived dataset/path entries.

    This variant reads no command-line arguments and does not set up
    distributed info or random seeds. The experiment/result roots are always
    built under ``root_path``.

    Args:
        opt_path (str): Path to the option yaml file.
        root_path (str): Root directory for the ``experiments`` (training) or
            ``results`` (testing) directory.
        is_train (bool): Whether the options are for training. Default: True.

    Returns:
        (dict): Options.
    """
    with open(opt_path, mode='r') as stream:
        loader = ordered_yaml()[0]
        opt = yaml.load(stream, Loader=loader)

    opt['is_train'] = is_train

    # datasets: derive the phase from the key ('train_1' -> 'train'),
    # propagate the global scale and expand '~' in data roots
    for phase_key, dataset in opt['datasets'].items():
        dataset['phase'] = phase_key.split('_')[0]
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        for root_key in ('dataroot_gt', 'dataroot_lq'):
            root_dir = dataset.get(root_key)
            if root_dir is not None:
                dataset[root_key] = osp.expanduser(root_dir)

    # expand '~' only in resume / pretrained-network paths
    paths = opt['path']
    for path_key, path_val in paths.items():
        if path_val is None:
            continue
        if 'resume_state' in path_key or 'pretrain_network' in path_key:
            paths[path_key] = osp.expanduser(path_val)

    if not is_train:
        results_root = osp.join(root_path, 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = results_root
        paths['visualization'] = osp.join(results_root, 'visualization')
        return opt

    experiments_root = osp.join(root_path, 'experiments', opt['name'])
    paths.update({
        'experiments_root': experiments_root,
        'models': osp.join(experiments_root, 'models'),
        'training_states': osp.join(experiments_root, 'training_states'),
        'log': experiments_root,
        'visualization': osp.join(experiments_root, 'visualization'),
    })

    # shorter validation / logging / checkpoint intervals for debug runs
    if 'debug' in opt['name']:
        if 'val' in opt:
            opt['val']['val_freq'] = 8
        logger_opt = opt['logger']
        logger_opt['print_freq'] = 1
        logger_opt['save_checkpoint_freq'] = 8

    return opt
| |
|
| |
|
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Each entry is rendered on its own line, indented by two spaces per
    level. Nested dicts are wrapped as ``key:[`` ... ``]`` and rendered one
    level deeper. Keys and values are converted with ``str()``, so
    non-string keys (e.g. integer keys loaded from yaml) are supported.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level (two spaces per level). Default: 1.

    Return:
        (str): Option string for printing; always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    parts = ['\n']
    for k, v in opt.items():
        if isinstance(v, dict):
            parts.append(indent + str(k) + ':[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(indent + ']\n')
        else:
            parts.append(indent + str(k) + ': ' + str(v) + '\n')
    return ''.join(parts)
| |
|