import argparse

import train_ode
import train_resnet
import train_cnf
def main(args):
    """Dispatch to the training routine selected by ``args.model``.

    Args:
        args: Parsed command-line namespace. ``model`` is always read; the
            remaining fields depend on the branch taken:
            ``lr``, ``n_epoch``, ``batch_size`` and ``tol`` for ``'odenet'``;
            ``lr``, ``n_epoch`` and ``batch_size`` for ``'resnet'``;
            ``viz`` and ``sample_dataset`` for ``'cnf'``.

    Raises:
        ValueError: If ``args.model`` is not ``'odenet'``, ``'resnet'`` or
            ``'cnf'``. Argparse ``choices`` already restricts the CLI; this
            guard keeps programmatic callers from silently doing nothing.
    """
    if args.model == 'odenet':
        train_ode.train_and_evaluate(args.lr, args.n_epoch, args.batch_size, args.tol)
    elif args.model == 'resnet':
        train_resnet.train_and_evaluate(args.lr, args.n_epoch, args.batch_size)
    elif args.model == 'cnf':
        # NOTE(review): this branch ignores args.lr / args.n_epoch /
        # args.batch_size and passes hard-coded values instead. The meaning of
        # each positional argument (0.001, 1000, 512, 2, 32, 64, 0., 10.) is
        # defined by train_cnf.train, which is not visible here -- confirm
        # before exposing any of them as CLI options.
        train_cnf.train(0.001, 1000, 512, 2, 32, 64, 0., 10., args.viz, args.sample_dataset)
    else:
        raise ValueError(
            "Unknown model %r; expected 'odenet', 'resnet' or 'cnf'" % (args.model,)
        )
def build_parser():
    """Build the command-line parser for this training entry point.

    Returns:
        argparse.ArgumentParser: Parser defining ``--model``, ``--tol``,
        ``--lr``, ``--n_epoch``, ``--batch_size``, ``--sample_dataset`` and
        ``--viz``.
    """
    parser = argparse.ArgumentParser(
        description='Train an odenet, resnet or cnf model.')
    parser.add_argument("--model", type=str, choices=['odenet', 'resnet', 'cnf'], default="odenet",
                        help="Type of model")
    parser.add_argument("--tol", type=float, default=1e-1,
                        help="Error tolerance for the ODE solver (used only with odenet)")
    # main() hard-codes these for cnf, so they only affect odenet and resnet.
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate (odenet and resnet only)")
    parser.add_argument("--n_epoch", type=int, default=10,
                        help="Total number of epochs (odenet and resnet only)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Number of images per batch (odenet and resnet only)")
    parser.add_argument("--sample_dataset", type=str, choices=['circles', 'moons'], default="circles",
                        help="Sample dataset (used only with cnf)")
    parser.add_argument("--viz", action='store_true',
                        help="Enable visualization (used only with cnf)")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())