| import sys | |
| import torch | |
| import yaml | |
def load_yaml_config(path):
    """Load and return the YAML document stored at ``path``.

    Args:
        path: ``str`` or path-like pointing at a YAML file.

    Returns:
        The parsed document as produced by ``yaml.full_load`` (presumably a
        ``dict`` for config files -- TODO confirm with callers).
    """
    # YAML streams are Unicode; decode explicitly as UTF-8 instead of relying
    # on the platform's locale-dependent default encoding.
    # NOTE(review): full_load is less restrictive than safe_load; only use this
    # on trusted config files. safe_load may reject files written by
    # save_config_to_yaml, since yaml.dump can emit python/* tags.
    with open(path, encoding='utf-8') as f:
        return yaml.full_load(f)
def save_config_to_yaml(config, path):
    """Serialize ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: Object to serialize.
        path: Destination file, ``str`` or path-like; must end in ``.yaml``.
            An existing file is overwritten.

    Raises:
        ValueError: If ``path`` does not end with ``.yaml``.
    """
    # Use an explicit raise rather than assert so the check survives
    # ``python -O``. str() lets pathlib.Path objects through as well.
    if not str(path).endswith('.yaml'):
        raise ValueError(
            'config path must end with .yaml, got {!r}'.format(path))
    # yaml.dump writes straight into the stream; the with-block closes it.
    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)
def write_args(args, path):
    """Append a run-provenance record to the text file at ``path``.

    The record contains the torch version, the value returned by
    ``torch.backends.cudnn.version()``, the raw command line (``sys.argv``),
    and every attribute of ``args`` whose name does not start with ``_``.
    Attributes are sorted by name and written one ``name: value`` pair per
    line.

    Args:
        args: Object whose public attributes are recorded (presumably an
            ``argparse.Namespace`` -- TODO confirm). Any name listed by
            ``dir(args)`` without a leading underscore is written, including
            class-level attributes.
        path: Log file to append to; it is created if it does not exist.
    """
    # dir() (not vars()) is used so class-level attributes are captured too.
    args_dict = {name: getattr(args, name)
                 for name in dir(args) if not name.startswith('_')}
    lines = [
        '==> torch version: {}\n'.format(torch.__version__),
        '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()),
        '==> Cmd:\n',
        str(sys.argv),
        '\n==> args:\n',
    ]
    lines.extend(' %s: %s\n' % (str(k), str(v))
                 for k, v in sorted(args_dict.items()))
    # Append mode lets repeated runs accumulate in one file. UTF-8 is set
    # because str(v) may contain non-ASCII text. The with-block closes the
    # file, so no explicit close() is needed.
    with open(path, 'a', encoding='utf-8') as args_file:
        args_file.writelines(lines)