"""Wrapper to train/test models."""

import os
import pytz
from datetime import datetime

from utils.config import Config

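def update_config(cfg, exp_name='', job_name=''):
    """
    Resolve the output directory and normalize a few config values.

    Args:
        cfg: <Config> from submit_config.config
        exp_name: experiment name, used as a sub-directory of the output path
        job_name: job name, used as a sub-directory below exp_name
    Returns:
        The same cfg, updated in place.
    """
    tz_NY = pytz.timezone('America/New_York')
    # Compute the timestamp once so both branches below use the same value.
    timestamp = datetime.now(tz_NY).strftime("%m%d-%H%M")

    # Pick the storage root based on the configured output root.
    if 'lemon' in cfg.out_root:
        cfg.out_dir = os.path.join(cfg.root_dir_lemon, cfg.out_dir)
    else:
        cfg.out_dir = os.path.join(cfg.root_dir_yogurt_out, cfg.out_dir)

    cfg.vis_itr = int(cfg.vis_itr)

    # Evaluation-only runs are grouped under a separate 'Test' sub-directory.
    if cfg.eval_only:
        cfg.out_dir = os.path.join(cfg.out_dir, 'Test', exp_name, job_name, timestamp)
    else:
        cfg.out_dir = os.path.join(cfg.out_dir, exp_name, job_name, timestamp)
    return cfg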


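def merge_and_update_from_dict(cfg, dct):
    """
    Recursively merge dct into cfg, in place.
    (Compatible with submitit's Dict-as-attribute trick.)

    Nested dicts are merged key by key; any other value overwrites the
    existing entry in cfg.

    Args:
        cfg: dict-like config to update
        dct: dict of overrides, or None to leave cfg unchanged
    Returns:
        The updated cfg.
    """
    if dct is not None:
        for key, value in dct.items():
            if isinstance(value, dict):
                if key in cfg.keys():
                    sub_cfgnode = cfg[key]
                else:
                    # Use item assignment: plain dicts do not support
                    # attribute assignment via __setattr__.
                    sub_cfgnode = dict()
                    cfg[key] = sub_cfgnode
                merge_and_update_from_dict(sub_cfgnode, value)
            else:
                cfg[key] = value
    return cfg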


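def load_config(default_cfg_file, add_cfg_files=None, cfg_dir=''):
    """
    Load the default config, then merge additional config files on top of it.

    Args:
        default_cfg_file: path to the default config file
        add_cfg_files: config files to merge in order; relative names are
            resolved against cfg_dir, and '.yaml' is appended if missing
        cfg_dir: absolute directory used to resolve relative config names
    """
    cfg = Config(default_cfg_file)
    # Avoid a mutable default argument; None means no additional files.
    for cfg_file in add_cfg_files or []:
        if os.path.isabs(cfg_file):
            add_cfg = Config(cfg_file)
        else:
            assert os.path.isabs(cfg_dir), 'cfg_dir must be an absolute path'
            if not cfg_file.endswith('.yaml'):
                cfg_file += '.yaml'
            add_cfg = Config(os.path.join(cfg_dir, cfg_file))
        cfg = merge_and_update_from_dict(cfg, add_cfg)
    if "exp_name" in cfg:
        return update_config(cfg, exp_name=cfg["exp_name"], job_name=cfg["job_name"])
    else:
        return cfg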