Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import multiprocessing | |
| import random | |
# Command-line interface for the parameter sweep.
parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
parser.add_argument("--env", type=str, default="FrozenLake-v1", help="environment to run")
parser.add_argument(
    "--num_tests",
    type=int,
    default=10,
    help="number of tests to run for each parameter combination",
)
parser.add_argument(
    "--wandb_project", type=str, default=None, help="wandb project name to log to"
)
args = parser.parse_args()

# Expose the parsed options as module-level names used by the rest of the script.
env = args.env
num_tests = args.num_tests
wandb_project = args.wandb_project
# Agent under test and the hyper-parameter grid swept by this script.
agent = "MCAgent"

# Only "off_policy" is swept at the moment; "on_policy" is disabled.
# Original author's note: every-visit takes too long because of these
# environments' reward structure.
vals_update_type = ["off_policy"]
vals_gamma = [1.0]  # wider grid tried previously: [1.0, 0.98, 0.96, 0.94]
vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]  # single-value alternative: [0.5]
vals_size = [8, 16, 32, 64]

# Number of training episodes for each supported environment.
# (A per-environment max_steps of 200 / 200 / 500 was used previously and is
# currently disabled.)
_TRAIN_EPISODES_BY_ENV = {
    "CliffWalking-v0": 2500,
    "FrozenLake-v1": 25000,
    "Taxi-v3": 10000,
}

if env not in _TRAIN_EPISODES_BY_ENV:
    raise ValueError(f"Unsupported environment: {env}")
n_train_episodes = _TRAIN_EPISODES_BY_ENV[env]
def run_test(args):
    """Launch one training run of ``run.py`` as a child process.

    The fixed part of the command line is taken from the module-level
    ``agent``, ``env``, ``n_train_episodes`` and ``wandb_project`` settings.

    Args:
        args: Mapping of extra command-line flag names (without the leading
            ``--``) to their values; each pair is appended as ``--name value``.

    Returns:
        None. A child that cannot be started or exits with a non-zero status
        is reported on stderr instead of raising, so the remaining runs of
        the sweep still execute.
    """
    # Imported locally so the function does not rely on the file's import block.
    import subprocess
    import sys

    # Build an argument list rather than a shell string: no shell is involved,
    # so values containing spaces or shell metacharacters are passed verbatim.
    command = [
        "python3",
        "run.py",
        "--train",
        "--agent",
        str(agent),
        "--env",
        str(env),
        "--n_train_episodes",
        str(n_train_episodes),
    ]
    for flag, value in args.items():
        command += [f"--{flag}", str(value)]
    if wandb_project is not None:
        command += ["--wandb_project", str(wandb_project)]
    # NOTE(review): the source's indentation was lost; --no_save is treated as
    # unconditional (not tied to wandb logging) -- confirm.
    command.append("--no_save")

    try:
        returncode = subprocess.run(command, check=False).returncode
    except OSError as err:
        print(f"Could not start {command}: {err}", file=sys.stderr)
        return
    if returncode != 0:
        print(f"Run exited with status {returncode}: {command}", file=sys.stderr)
# Guard the driver so that worker processes which re-import this module
# (the "spawn" start method used on Windows and macOS) do not try to create
# their own pool and re-launch the whole sweep.
if __name__ == "__main__":
    # Only FrozenLake runs receive a "size" flag; other environments get a
    # single pass with no size entry.
    sizes = vals_size if env == "FrozenLake-v1" else [None]

    # One entry per (update_type, gamma, epsilon[, size]) combination and
    # repetition. Key order is kept because it determines flag order on the
    # command line built by run_test.
    tests = []
    for update_type in vals_update_type:
        for gamma in vals_gamma:
            for eps in vals_epsilon:
                for size in sizes:
                    for i in range(num_tests):
                        test = {
                            "gamma": gamma,
                            "epsilon": eps,
                            "update_type": update_type,
                        }
                        if size is not None:
                            test["size"] = size
                        test["run_name_suffix"] = i
                        tests.append(test)

    # Randomize execution order of the runs before handing them to the pool.
    random.shuffle(tests)

    # Run up to 8 training processes at a time.
    with multiprocessing.Pool(8) as p:
        p.map(run_test, tests)