| """Wrapper to train/test models.""" | |
| import os | |
| import pytz | |
| from datetime import datetime | |
| from utils.config import Config | |
def update_config(cfg, exp_name='', job_name=''):
    """Finalise the run-specific fields of ``cfg`` and hand it back.

    The output directory is rebuilt in two steps. First ``cfg.out_dir`` is
    joined onto ``cfg.root_dir_lemon`` when the substring ``'lemon'`` occurs
    in ``cfg.out_root``, and onto ``cfg.root_dir_yogurt_out`` otherwise.
    Then ``exp_name``, ``job_name`` and a ``MMDD-HHMM`` stamp taken in the
    America/New_York time zone are appended, preceded by a ``Test``
    component when ``cfg.eval_only`` is truthy. ``cfg.vis_itr`` is coerced
    to ``int`` along the way.

    Args:
        cfg: config object exposing the attributes read above (the original
            note describes it as a <Config> from submit_config.config).
        exp_name: experiment name used as a path component.
        job_name: job name used as a path component.

    Returns:
        The same ``cfg`` object, modified in place.
    """
    eastern = pytz.timezone('America/New_York')

    # Choose the storage root from the configured output root.
    if 'lemon' in cfg.out_root:
        storage_root = cfg.root_dir_lemon
    else:
        storage_root = cfg.root_dir_yogurt_out
    cfg.out_dir = os.path.join(storage_root, cfg.out_dir)

    cfg.vis_itr = int(cfg.vis_itr)

    # Evaluation-only runs are grouped under an extra 'Test' directory.
    segments = ['Test'] if cfg.eval_only else []
    stamp = datetime.now(eastern).strftime("%m%d-%H%M")
    segments.extend((exp_name, job_name, stamp))
    cfg.out_dir = os.path.join(cfg.out_dir, *segments)
    return cfg
def merge_and_update_from_dict(cfg, dct):
    """Recursively merge ``dct`` into ``cfg`` in place and return ``cfg``.

    (Kept compatible with attribute-style dict containers such as the one
    the original note attributes to submitit.)

    For every ``key, value`` pair of ``dct``:

    * if ``value`` is a ``dict``, it is merged key by key into ``cfg[key]``
      when that key already exists; otherwise a fresh ``dict`` is built from
      ``value`` and stored under ``key``;
    * any other ``value`` overwrites ``cfg[key]``.

    Args:
        cfg: mapping to update; must support ``in``, ``cfg[key]`` and
            ``cfg[key] = value``.
        dct: mapping with the overriding values, or ``None`` for a no-op.

    Returns:
        ``cfg`` itself, after the update.
    """
    if dct is None:
        return cfg

    for key, value in dct.items():
        if not isinstance(value, dict):
            cfg[key] = value
        elif key in cfg:
            # Existing sub-config: update it in place.
            merge_and_update_from_dict(cfg[key], value)
        else:
            # New sub-config. Item assignment is used because a plain dict
            # rejects ``__setattr__``; the node is attached only after it has
            # been filled, in case the container copies dicts on assignment.
            cfg[key] = merge_and_update_from_dict({}, value)
    return cfg
def load_config(default_cfg_file, add_cfg_files=None, cfg_dir=''):
    """Load the default config, overlay extra config files, and finalise it.

    Args:
        default_cfg_file: path handed to ``Config`` to build the base config.
        add_cfg_files: optional iterable of extra config files, merged in the
            given order so later files override earlier ones. Absolute paths
            are used as they are; any other entry is resolved against
            ``cfg_dir`` and gets a ``.yaml`` suffix when it lacks one.
            ``None`` or an empty iterable means no overlay.
        cfg_dir: absolute directory used to resolve non-absolute entries of
            ``add_cfg_files``.

    Returns:
        The merged config. When it contains an ``"exp_name"`` key it is
        first passed through ``update_config`` with that experiment name and
        the ``"job_name"`` entry (empty string when absent).

    Raises:
        ValueError: if a non-absolute entry is given while ``cfg_dir`` is not
            an absolute path.
    """
    cfg = Config(default_cfg_file)

    for cfg_file in add_cfg_files or ():
        if not os.path.isabs(cfg_file):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not os.path.isabs(cfg_dir):
                raise ValueError(
                    "cfg_dir must be an absolute path to resolve relative "
                    "config file {!r}, got {!r}".format(cfg_file, cfg_dir)
                )
            if not cfg_file.endswith('.yaml'):
                cfg_file += '.yaml'
            cfg_file = os.path.join(cfg_dir, cfg_file)
        cfg = merge_and_update_from_dict(cfg, Config(cfg_file))

    if "exp_name" not in cfg:
        return cfg

    # Fall back to update_config's own default when no job name is configured.
    job_name = cfg["job_name"] if "job_name" in cfg else ''
    return update_config(cfg, exp_name=cfg["exp_name"], job_name=job_name)