Spaces:
Build error
Build error
| import numpy as np | |
| import os, sys, time | |
| import torch | |
| import random | |
| import string | |
| import yaml | |
| import utils.util as util | |
| import time | |
| from utils.util import EasyDict as edict | |
| # torch.backends.cudnn.enabled = False | |
| # torch.backends.cudnn.benchmark = False | |
| # torch.backends.cudnn.deterministic = True | |
| def parse_arguments(args): | |
| # parse from command line (syntax: --key1.key2.key3=value) | |
| opt_cmd = {} | |
| for arg in args: | |
| assert(arg.startswith("--")) | |
| if "=" not in arg[2:]: # --key means key=True, --key! means key=False | |
| key_str, value = (arg[2:-1], "false") if arg[-1]=="!" else (arg[2:], "true") | |
| else: | |
| key_str, value = arg[2:].split("=") | |
| keys_sub = key_str.split(".") | |
| opt_sub = opt_cmd | |
| for k in keys_sub[:-1]: | |
| if k not in opt_sub: opt_sub[k] = {} | |
| opt_sub = opt_sub[k] | |
| # if opt_cmd['key1']['key2']['key3'] already exist for key1.key2.key3, print key3 as error msg | |
| assert keys_sub[-1] not in opt_sub, keys_sub[-1] | |
| opt_sub[keys_sub[-1]] = yaml.safe_load(value) | |
| opt_cmd = edict(opt_cmd) | |
| return opt_cmd | |
| def set(opt_cmd={}, verbose=True, safe_check=True): | |
| print("setting configurations...") | |
| fname = opt_cmd.yaml # load from yaml file | |
| opt_base = load_options(fname) | |
| # override with command line arguments | |
| opt = override_options(opt_base, opt_cmd, key_stack=[], safe_check=safe_check) | |
| process_options(opt) | |
| if verbose: | |
| def print_options(opt, level=0): | |
| for key, value in sorted(opt.items()): | |
| if isinstance(value, (dict, edict)): | |
| print(" "*level+"* "+key+":") | |
| print_options(value, level+1) | |
| else: | |
| print(" "*level+"* "+key+":", value) | |
| print_options(opt) | |
| return opt | |
| def load_options(fname): | |
| with open(fname) as file: | |
| opt = edict(yaml.safe_load(file)) | |
| if "_parent_" in opt: | |
| # load parent yaml file(s) as base options | |
| parent_fnames = opt.pop("_parent_") | |
| if type(parent_fnames) is str: | |
| parent_fnames = [parent_fnames] | |
| for parent_fname in parent_fnames: | |
| opt_parent = load_options(parent_fname) | |
| opt_parent = override_options(opt_parent, opt, key_stack=[]) | |
| opt = opt_parent | |
| print("loading {}...".format(fname)) | |
| return opt | |
| def override_options(opt, opt_over, key_stack=None, safe_check=False): | |
| for key, value in opt_over.items(): | |
| if isinstance(value, dict): | |
| # parse child options (until leaf nodes are reached) | |
| opt[key] = override_options(opt.get(key, dict()), value, key_stack=key_stack+[key], safe_check=safe_check) | |
| else: | |
| # ensure command line argument to override is also in yaml file | |
| if safe_check and key not in opt: | |
| add_new = None | |
| while add_new not in ["y", "n"]: | |
| key_str = ".".join(key_stack+[key]) | |
| add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str)) | |
| if add_new=="n": | |
| print("safe exiting...") | |
| exit() | |
| opt[key] = value | |
| return opt | |
| def process_options(opt): | |
| # set seed | |
| if opt.seed is not None: | |
| random.seed(opt.seed) | |
| np.random.seed(opt.seed) | |
| torch.manual_seed(opt.seed) | |
| torch.cuda.manual_seed_all(opt.seed) | |
| else: | |
| # create random string as run ID | |
| randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4)) | |
| opt.name += "_{}".format(randkey) | |
| # other default options | |
| opt.output_path = "{0}/{1}/{2}".format(opt.output_root, opt.group, opt.name) | |
| os.makedirs(opt.output_path, exist_ok=True) | |
| opt.H, opt.W = opt.image_size | |
| if opt.freq.eval is None: | |
| opt.freq.eval = max(opt.max_epoch // 20, 1) | |
| if 'loss_weight' in opt: | |
| opt.get_depth = False | |
| opt.get_normal = False | |
| def save_options_file(opt): | |
| opt_fname = "{}/options.yaml".format(opt.output_path) | |
| if os.path.isfile(opt_fname): | |
| with open(opt_fname) as file: | |
| opt_old = yaml.safe_load(file) | |
| if opt!=opt_old: | |
| # prompt if options are not identical | |
| opt_new_fname = "{}/options_temp.yaml".format(opt.output_path) | |
| with open(opt_new_fname, "w") as file: | |
| yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4) | |
| print("existing options file found (different from current one)...") | |
| os.system("diff {} {}".format(opt_fname, opt_new_fname)) | |
| os.system("rm {}".format(opt_new_fname)) | |
| if not opt.debug: | |
| print("please cancel within 10 seconds if you do not want to override...") | |
| time.sleep(10) | |
| else: print("existing options file found (identical)") | |
| else: print("(creating new options file...)") | |
| with open(opt_fname, "w") as file: | |
| yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4) | |