File size: 1,793 Bytes
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
067b9b6
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
067b9b6
37163a6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import os

from omegaconf import OmegaConf

from trainers import trainers_dict


def make_args(argv=None):
    """Build the command-line parser and parse the options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the previous behaviour. Passing an explicit list lets
            callers and tests drive the parser without touching sys.argv.

    Returns:
        An ``argparse.Namespace`` with the attributes ``config``, ``name``,
        ``tag``, ``resume``, ``force_replace``, ``comet``, ``save_root`` and
        ``eval_only``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/_.yaml')
    parser.add_argument('--name', '-n', default=None)
    parser.add_argument('--tag', '-t', default=None)
    parser.add_argument('--resume', '-r', action='store_true')
    parser.add_argument('--force-replace', '-f', action='store_true')
    parser.add_argument('--comet', '-c', action='store_true', help='Enable Comet ML logging')
    parser.add_argument('--save-root', default='save')
    parser.add_argument('--eval-only', action='store_true')
    return parser.parse_args(argv)


def parse_config(config):
    """Resolve ``__base__`` inheritance in a loaded config.

    If ``config`` has a non-null ``__base__`` entry (a single path or a list
    of paths), each referenced file is loaded with ``OmegaConf.load`` and
    resolved recursively. The bases are merged in order, so later bases
    override earlier ones. Finally, ``config``'s own keys are merged on top,
    so they take precedence over everything inherited.

    Base paths are passed to ``OmegaConf.load`` unchanged. Relative paths
    are therefore resolved against the process's current working directory,
    not the directory of the file that references them.

    Args:
        config: A config mapping, typically an OmegaConf ``DictConfig``.
            NOTE: its ``__base__`` key is removed in place via ``pop``.

    Returns:
        The merged config, or ``config`` itself when it declares no bases.
    """
    if config.get('__base__') is None:
        return config
    filenames = config.pop('__base__')
    if isinstance(filenames, str):
        filenames = [filenames]
    # OmegaConf.merge needs at least one config to merge. An empty
    # ``__base__`` list therefore means "no bases" instead of crashing.
    if not filenames:
        return config
    base_config = OmegaConf.merge(*[
        parse_config(OmegaConf.load(filename))
        for filename in filenames
    ])
    return OmegaConf.merge(base_config, config)


def make_env(args):
    """Derive the run environment dict from parsed command-line arguments.

    The experiment name is ``args.name`` when given. Otherwise it is the
    config file's base name with its extension stripped. If ``args.tag`` is
    set, ``'_' + tag`` is appended. The save directory is
    ``args.save_root/<exp_name>``. The ``comet``, ``resume`` and
    ``force_replace`` flags are copied through unchanged.

    Returns:
        A dict with keys ``exp_name``, ``save_dir``, ``comet``, ``resume``
        and ``force_replace``.
    """
    run_name = args.name
    if run_name is None:
        config_file = os.path.basename(args.config)
        run_name, _ext = os.path.splitext(config_file)

    if args.tag is not None:
        run_name = run_name + '_' + args.tag

    return {
        'exp_name': run_name,
        'save_dir': os.path.join(args.save_root, run_name),
        'comet': args.comet,
        'resume': args.resume,
        'force_replace': args.force_replace,
    }


if __name__ == '__main__':
    args = make_args()
    env = make_env(args)
    config = parse_config(OmegaConf.load(args.config))
    # Use .get so a config without a 'trainer' key yields None, and we
    # report it below. Attribute access would raise a bare KeyError or
    # ConfigAttributeError instead.
    trainer_name = config.get('trainer')
    if trainer_name not in trainers_dict:
        # Fail with an actionable message instead of a bare KeyError
        # traceback when the config names an unregistered trainer.
        # Assumes trainers_dict is a mapping keyed by trainer name;
        # verify against trainers/__init__.
        available = ', '.join(sorted(map(str, trainers_dict)))
        raise SystemExit(
            f"Unknown trainer {trainer_name!r} in config {args.config!r}; "
            f"available trainers: {available}"
        )
    trainer = trainers_dict[trainer_name](env, config)
    trainer.run(eval_only=args.eval_only)