Spaces:
Runtime error
Runtime error
| import os | |
| import argparse | |
| import torch | |
| from solver_2 import Solver | |
| from data_loader import get_loader | |
| from hparams_autopst import hparams, hparams_debug_string | |
def str2bool(v):
    """Parse a command-line string into a bool.

    Intended for use as an ``argparse`` ``type=`` converter. Matching is
    case-insensitive and ignores surrounding whitespace. The spellings
    ``'true'``, ``'t'``, ``'yes'``, ``'y'`` and ``'1'`` map to ``True``;
    every other string maps to ``False``. A value that is already a
    ``bool`` is returned unchanged.

    Args:
        v: The raw command-line value (a string, or a bool).

    Returns:
        ``True`` if *v* is one of the accepted truthy spellings, else ``False``.
    """
    if isinstance(v, bool):
        return v
    # NOTE: the previous ``('true')`` was a plain string, not a tuple, so
    # ``in`` performed a substring test and '', 'r', 'ru', 'rue' were True.
    return v.strip().lower() in ('true', 't', 'yes', 'y', '1')
def main(config):
    """Construct the data loader and the solver, then run training.

    Args:
        config: Parsed command-line namespace. It is passed to ``Solver``
            together with the module-level ``hparams``.
    """
    # The loader is built from hparams alone; the solver gets both the
    # CLI options and the hyper-parameters.
    trainer = Solver(get_loader(hparams), config, hparams)
    trainer.train()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Integer CLI options as (flag, default). The order matters: it fixes
    # how the options appear in --help and in the printed namespace.
    for _flag, _default in (
        ('--num_iters', 1000000),  # training configuration
        ('--device_id', 0),        # miscellaneous
        ('--log_step', 10),        # step size
    ):
        parser.add_argument(_flag, type=int, default=_default)
    config = parser.parse_args()
    # Echo the effective CLI options and hyper-parameters before training.
    print(config)
    print(hparams_debug_string())
    main(config)