| import sys |
|
|
| import torch |
| import yaml |
|
|
|
|
def load_yaml_config(path):
    """Load a YAML file and return the parsed document.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The Python object produced by ``yaml.full_load`` for the file's
        document (a dict for a mapping document; ``None`` for an empty file).

    Note:
        ``yaml.full_load`` uses PyYAML's ``FullLoader``, which resolves more
        tags than ``yaml.safe_load``; only load files you trust. It is kept
        instead of ``safe_load`` so that files written by ``yaml.dump`` with
        Python-specific tags (e.g. ``!!python/tuple``) still load.
    """
    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so the same config file parses identically on every machine.
    with open(path, encoding="utf-8") as f:
        return yaml.full_load(f)
|
|
|
|
def save_config_to_yaml(config, path):
    """Serialize ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: Object to serialize.
        path: Destination file path; must end with ``.yaml``. An existing
            file at this path is overwritten.

    Raises:
        ValueError: If ``path`` does not end with ``.yaml``.
    """
    # Explicit check instead of ``assert``: ``python -O`` strips assert
    # statements, which would silently disable this guard.
    if not path.endswith(".yaml"):
        raise ValueError(
            "config path must end with '.yaml', got {!r}".format(path)
        )
    # yaml.dump escapes non-ASCII by default, so the encoding is simply made
    # explicit here; the ``with`` block closes the file on exit.
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
|
|
|
|
def write_args(args, path):
    """Append a run summary (environment, command line, arguments) to a file.

    Writes the torch version, the cuDNN version reported by
    ``torch.backends.cudnn.version()``, the raw ``sys.argv`` list, and every
    attribute of ``args`` whose name does not start with ``_``, sorted by name.

    Args:
        args: Object holding the run options (presumably an
            ``argparse.Namespace``; attributes are discovered via ``dir``).
        path: File to append to; created if it does not exist.
    """
    args_dict = {
        name: getattr(args, name) for name in dir(args) if not name.startswith("_")
    }
    # UTF-8 so non-ASCII argument values (paths, names) can be written
    # regardless of the platform's locale encoding.
    with open(path, "a", encoding="utf-8") as args_file:
        args_file.write("==> torch version: {}\n".format(torch.__version__))
        args_file.write(
            "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
        )
        args_file.write("==> Cmd:\n")
        args_file.write(str(sys.argv))
        args_file.write("\n==> args:\n")
        for k, v in sorted(args_dict.items()):
            args_file.write(" %s: %s\n" % (str(k), str(v)))
|
|