primepake
add training flowvae
4f877a2
raw
history blame
1.79 kB
import argparse
import os
from omegaconf import OmegaConf
from trainers import trainers_dict
def make_args():
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='configs/_.yaml')
parser.add_argument('--name', '-n', default=None)
parser.add_argument('--tag', '-t', default=None)
parser.add_argument('--resume', '-r', action='store_true')
parser.add_argument('--force-replace', '-f', action='store_true')
parser.add_argument('--comet', '-c', action='store_true', help='Enable Comet ML logging')
parser.add_argument('--save-root', default='save')
parser.add_argument('--eval-only', action='store_true')
args = parser.parse_args()
return args
def parse_config(config):
if config.get('__base__') is not None:
filenames = config.pop('__base__')
if isinstance(filenames, str):
filenames = [filenames]
base_config = OmegaConf.merge(*[
parse_config(OmegaConf.load(_))
for _ in filenames
])
config = OmegaConf.merge(base_config, config)
return config
def make_env(args):
env = dict()
if args.name is None:
exp_name = os.path.splitext(os.path.basename(args.config))[0]
else:
exp_name = args.name
if args.tag is not None:
exp_name += '_' + args.tag
env['exp_name'] = exp_name
env['save_dir'] = os.path.join(args.save_root, exp_name)
env['comet'] = args.comet
env['resume'] = args.resume
env['force_replace'] = args.force_replace
return env
if __name__ == '__main__':
args = make_args()
env = make_env(args)
config = parse_config(OmegaConf.load(args.config))
trainer = trainers_dict[config.trainer](env, config)
trainer.run(eval_only=args.eval_only)