Spaces:
Running
Running
| import os | |
| from collections import OrderedDict | |
| from datetime import datetime | |
| import json | |
| import re | |
| import glob | |
| ''' | |
| # -------------------------------------------- | |
| # Kai Zhang (github: https://github.com/cszn) | |
| # 03/Mar/2019 | |
| # -------------------------------------------- | |
| # https://github.com/xinntao/BasicSR | |
| # -------------------------------------------- | |
| ''' | |
def get_timestamp():
    """Return the current local time as a string of the form '_YYMMDD_HHMMSS'.

    The leading underscore lets callers append it directly to a file stem.
    """
    stamp_format = '_%y%m%d_%H%M%S'
    current = datetime.now()
    return current.strftime(stamp_format)
def _strip_line_comment(line):
    """Remove a trailing ``//`` comment from one line of JSON-with-comments.

    A ``//`` that occurs inside a double-quoted JSON string (for example the
    ``https://`` of a URL) is kept. Only the first ``//`` found outside a
    string starts a comment. If no comment is found, the line is returned
    unchanged, including its newline.
    """
    in_string = False
    escaped = False
    for idx, ch in enumerate(line):
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif line.startswith('//', idx):
            return line[:idx]
    return line


def parse(opt_path, is_train=True):
    """Load a JSON option file that may contain ``//`` comments, and fill in defaults.

    Args:
        opt_path: path to the JSON option file. Text after ``//`` that is
            outside a JSON string is treated as a comment and discarded.
        is_train: whether the options are for training. It is stored in
            ``opt['is_train']`` and selects which output sub-directories are
            registered under ``opt['path']``.

    Returns:
        OrderedDict holding the parsed options plus defaults and derived paths.

    Side effects:
        Sets ``os.environ['CUDA_VISIBLE_DEVICES']`` from ``opt['gpu_ids']``
        and prints that value and the GPU count.

    Raises:
        KeyError: if a required key is missing. The required keys are
            ``datasets``, ``n_channels``, ``path``, ``path.root``, ``task``,
            ``netG``, ``gpu_ids`` and ``train``.
        json.JSONDecodeError: if the file is not valid JSON once comments
            are removed.
    """
    # ----------------------------------------
    # remove comments starting with '//' (but keep '//' inside strings)
    # ----------------------------------------
    with open(opt_path, 'r') as f:
        json_str = ''.join(_strip_line_comment(line) + '\n' for line in f)

    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)
    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1
    opt.setdefault('scale', 1)

    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        # keep only the part before the first '_' of the phase key
        dataset['phase'] = phase.split('_')[0]
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        for root_key in ('dataroot_H', 'dataroot_L'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = os.path.expanduser(dataset[root_key])

    # ----------------------------------------
    # path
    # ----------------------------------------
    paths = opt['path']
    for key, path in list(paths.items()):
        if path:
            paths[key] = os.path.expanduser(path)

    path_task = os.path.join(paths['root'], opt['task'])
    paths['task'] = path_task
    paths['log'] = path_task
    paths['options'] = os.path.join(path_task, 'options')
    if is_train:
        paths['models'] = os.path.join(path_task, 'models')
        paths['images'] = os.path.join(path_task, 'images')
    else:  # test
        paths['images'] = os.path.join(path_task, 'test_images')

    # ----------------------------------------
    # network
    # ----------------------------------------
    opt['netG']['scale'] = opt['scale']  # 'scale' is guaranteed by the defaults above

    # ----------------------------------------
    # GPU devices
    # ----------------------------------------
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    # ----------------------------------------
    # default setting for distributeddataparallel
    # ----------------------------------------
    opt.setdefault('find_unused_parameters', True)
    opt.setdefault('use_static_graph', False)
    opt.setdefault('dist', False)
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))

    train = opt['train']
    has_net_d = 'netD' in opt

    # ----------------------------------------
    # default setting for perceptual loss
    # ----------------------------------------
    train.setdefault('F_feature_layer', 34)  # 25; [2,7,16,25,34]
    train.setdefault('F_weights', 1.0)  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    train.setdefault('F_lossfn_type', 'l1')
    train.setdefault('F_use_input_norm', True)
    train.setdefault('F_use_range_norm', False)

    # ----------------------------------------
    # default setting for optimizer
    # ----------------------------------------
    train.setdefault('G_optimizer_type', "adam")
    train.setdefault('G_optimizer_betas', [0.9, 0.999])
    train.setdefault('G_scheduler_restart_weights', 1)
    train.setdefault('G_optimizer_wd', 0)
    train.setdefault('G_optimizer_reuse', False)
    if has_net_d:
        train.setdefault('D_optimizer_reuse', False)

    # ----------------------------------------
    # default setting of strict for model loading
    # ----------------------------------------
    # All *_param_strict flags are read from and written to opt['train'], so a
    # value supplied by the user there is respected rather than overwritten.
    train.setdefault('G_param_strict', True)
    if has_net_d:
        train.setdefault('D_param_strict', True)
    train.setdefault('E_param_strict', True)

    # ----------------------------------------
    # Exponential Moving Average
    # ----------------------------------------
    train.setdefault('E_decay', 0)

    # ----------------------------------------
    # default setting for discriminator
    # ----------------------------------------
    if has_net_d:
        net_d = opt['netD']
        net_d.setdefault('net_type', 'discriminator_patchgan')  # discriminator_unet
        net_d.setdefault('in_nc', 3)
        net_d.setdefault('base_nc', 64)
        net_d.setdefault('n_layers', 3)
        net_d.setdefault('norm_type', 'spectral')

    return opt
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """Find the checkpoint with the largest iteration number in ``save_dir``.

    Checkpoints are expected to be named ``<iter>_<net_type>.pth`` (e.g.
    ``10000_G.pth``). Files that match ``*_<net_type>.pth`` but whose prefix
    is not purely numeric (e.g. ``best_G.pth``) are ignored instead of
    raising.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any
            numbered model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when no numbered checkpoint exists)
        init_path: model path (pretrained_path when no numbered checkpoint exists)
    """
    # Match against the file name only, with net_type and '.' escaped, so
    # digits in the directory part or regex metacharacters cannot interfere.
    name_pattern = re.compile(r'(\d+)_{}\.pth'.format(re.escape(net_type)))
    # glob.escape makes '[', '*', '?' in the directory name match literally.
    glob_pattern = os.path.join(glob.escape(save_dir), '*_{}.pth'.format(glob.escape(net_type)))

    best = None
    for file_ in glob.glob(glob_pattern):
        file_name = os.path.basename(file_)
        match = name_pattern.fullmatch(file_name)
        if match is None:
            continue
        iter_current = int(match.group(1))
        if best is None or iter_current > best[0]:
            # Keep the real file name so zero-padded names still resolve.
            best = (iter_current, os.path.join(save_dir, file_name))

    if best is None:
        return 0, pretrained_path
    init_iter, init_path = best
    return init_iter, init_path
| ''' | |
| # -------------------------------------------- | |
| # convert the opt into json file | |
| # -------------------------------------------- | |
| ''' | |
def save(opt):
    """Dump ``opt`` as indented JSON into the ``opt['path']['options']`` folder.

    The output file reuses the name of the original option file
    (``opt['opt_path']``), with the string returned by ``get_timestamp()``
    inserted between the stem and the extension.
    """
    source_path = opt['opt_path']
    options_dir = opt['path']['options']
    stem, extension = os.path.splitext(os.path.basename(source_path))
    target_path = os.path.join(options_dir, stem + get_timestamp() + extension)
    with open(target_path, 'w') as handle:
        json.dump(opt, handle, indent=2)
| ''' | |
| # -------------------------------------------- | |
| # dict to string for logger | |
| # -------------------------------------------- | |
| ''' | |
def dict2str(opt, indent_l=1):
    """Render a (possibly nested) dict as an indented multi-line string for logging.

    Each nesting level adds two spaces. A nested dict is written as
    ``key:[`` followed by its entries and a closing ``]``. Any other value is
    written as ``key: str(value)``. Every emitted line ends with a newline.
    """
    pad = ' ' * (indent_l * 2)
    pieces = []
    for key, value in opt.items():
        if not isinstance(value, dict):
            pieces.append(pad + key + ': ' + str(value) + '\n')
            continue
        pieces.append(pad + key + ':[\n')
        pieces.append(dict2str(value, indent_l + 1))
        pieces.append(pad + ']\n')
    return ''.join(pieces)
| ''' | |
| # -------------------------------------------- | |
| # convert OrderedDict to NoneDict, | |
| # return None for missing key | |
| # -------------------------------------------- | |
| ''' | |
def dict_to_nonedict(opt):
    """Recursively convert every dict in ``opt`` into a ``NoneDict``.

    Dicts, including those nested inside lists, become ``NoneDict`` with
    their values converted recursively. Lists are rebuilt with converted
    items. Any other value is returned unchanged.
    """
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    if not isinstance(opt, dict):
        return opt
    converted = {key: dict_to_nonedict(value) for key, value in opt.items()}
    return NoneDict(**converted)
class NoneDict(dict):
    """A ``dict`` whose subscript lookups return ``None`` for missing keys.

    Only ``d[key]`` is affected. ``get``, ``in`` and iteration behave exactly
    like a plain ``dict``, and a missing key is not inserted on lookup.
    """

    def __missing__(self, key):
        # dict.__getitem__ calls this hook on a miss; returning None
        # replaces the KeyError that would otherwise be raised.
        return None