"""Wrapper to train/test models.""" import os import pytz from datetime import datetime from utils.config import Config def update_config(cfg, exp_name='', job_name=''): """ Update some configs. Args: cfg: from submit_config.config """ tz_NY = pytz.timezone('America/New_York') if 'lemon' in cfg.out_root: cfg.out_dir = os.path.join(cfg.root_dir_lemon, cfg.out_dir) else: cfg.out_dir = os.path.join(cfg.root_dir_yogurt_out, cfg.out_dir) cfg.vis_itr = int(cfg.vis_itr) if cfg.eval_only: cfg.out_dir = os.path.join(cfg.out_dir, 'Test', exp_name, job_name, datetime.now(tz_NY).strftime("%m%d-%H%M")) else: cfg.out_dir = os.path.join(cfg.out_dir, exp_name, job_name, datetime.now(tz_NY).strftime("%m%d-%H%M")) return cfg def merge_and_update_from_dict(cfg, dct): """ (Compatible for submitit's Dict as attribute trick) Merge dict as dict() to config as CfgNode(). Args: cfg: dict dct: dict """ if dct is not None: for key, value in dct.items(): if isinstance(value, dict): if key in cfg.keys(): sub_cfgnode = cfg[key] else: sub_cfgnode = dict() cfg.__setattr__(key, sub_cfgnode) sub_cfgnode = merge_and_update_from_dict(sub_cfgnode, value) else: cfg[key] = value return cfg def load_config(default_cfg_file, add_cfg_files = [], cfg_dir = ''): cfg = Config(default_cfg_file) for cfg_file in add_cfg_files: if os.path.isabs(cfg_file): add_cfg = Config(cfg_file) else: assert os.path.isabs(cfg_dir) if not cfg_file.endswith('.yaml'): cfg_file += '.yaml' add_cfg = Config(os.path.join(cfg_dir, cfg_file)) cfg = merge_and_update_from_dict(cfg, add_cfg) if "exp_name" in cfg: return update_config(cfg, exp_name=cfg["exp_name"], job_name = cfg["job_name"]) else: return cfg