Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| from omegaconf import OmegaConf | |
| from trainers import trainers_dict | |
def make_args(argv=None):
    """Build the launcher's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            previous behaviour. Passing an explicit list lets other code and
            tests drive the launcher without touching ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``config``, ``name``, ``tag``,
        ``resume``, ``force_replace``, ``comet``, ``save_root`` and
        ``eval_only``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/_.yaml')
    parser.add_argument('--name', '-n', default=None)
    parser.add_argument('--tag', '-t', default=None)
    parser.add_argument('--resume', '-r', action='store_true')
    parser.add_argument('--force-replace', '-f', action='store_true')
    parser.add_argument('--comet', '-c', action='store_true', help='Enable Comet ML logging')
    parser.add_argument('--save-root', default='save')
    parser.add_argument('--eval-only', action='store_true')
    return parser.parse_args(argv)
def parse_config(config):
    """Recursively resolve ``__base__`` inheritance in an OmegaConf config.

    If ``config`` has a non-null ``__base__`` entry (a single path string or a
    list of paths), each referenced file is loaded with ``OmegaConf.load`` and
    resolved recursively. The resolved bases are merged left to right, and
    ``config`` itself is then merged on top, so keys in ``config`` override
    those inherited from its bases.

    Args:
        config: An OmegaConf config, e.g. as returned by ``OmegaConf.load``.
            NOTE: when present, ``__base__`` is popped from this object in
            place.

    Returns:
        The merged config, without the ``__base__`` key. If there is nothing
        to inherit, ``config`` itself is returned.
    """
    if config.get('__base__') is None:
        return config
    filenames = config.pop('__base__')
    if isinstance(filenames, str):
        filenames = [filenames]
    # Base paths are opened exactly as written, so relative paths resolve
    # against the current working directory, not against the including file.
    bases = [parse_config(OmegaConf.load(filename)) for filename in filenames]
    if not bases:
        # `__base__: []` -> nothing to inherit. Calling OmegaConf.merge() with
        # no arguments would fail, so return the config unchanged.
        return config
    base_config = OmegaConf.merge(*bases)
    return OmegaConf.merge(base_config, config)
def make_env(args):
    """Derive the run environment dictionary from parsed CLI arguments.

    The experiment name is ``args.name`` when provided. Otherwise it is the
    basename of ``args.config`` with its extension stripped. A non-None
    ``args.tag`` is appended after an underscore.

    Returns:
        dict with keys ``exp_name``, ``save_dir`` (``args.save_root`` joined
        with the experiment name), ``comet``, ``resume`` and
        ``force_replace``, the last three copied straight from ``args``.
    """
    base_name = args.name
    if base_name is None:
        config_file = os.path.basename(args.config)
        base_name = os.path.splitext(config_file)[0]
    suffix = '' if args.tag is None else '_' + args.tag
    exp_name = base_name + suffix
    return {
        'exp_name': exp_name,
        'save_dir': os.path.join(args.save_root, exp_name),
        'comet': args.comet,
        'resume': args.resume,
        'force_replace': args.force_replace,
    }
if __name__ == '__main__':
    # Entry point: parse CLI args, resolve the (possibly inherited) config,
    # instantiate the trainer named by `config.trainer`, and run it.
    args = make_args()
    env = make_env(args)
    config = parse_config(OmegaConf.load(args.config))
    trainer_name = config.trainer
    if trainer_name not in trainers_dict:
        # Fail with an actionable message instead of a bare KeyError.
        available = ', '.join(sorted(map(str, trainers_dict)))
        raise SystemExit(
            f"Unknown trainer {trainer_name!r} in {args.config}; "
            f"available trainers: {available}"
        )
    trainer = trainers_dict[trainer_name](env, config)
    trainer.run(eval_only=args.eval_only)