Spaces:
Sleeping
Sleeping
| # | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| from rlstructures import logging | |
| from rlstructures.env_wrappers import GymEnv, GymEnvInf | |
| from rlstructures.tools import weight_init | |
| import torch.nn as nn | |
| import copy | |
| import torch | |
| import time | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from tutorial.tutorial_recurrent_policy.agent import RecurrentAgent | |
| from tutorial.tutorial_recurrent_policy.a2c import A2C | |
| import gym | |
| from gym.wrappers import TimeLimit | |
| # The 'create_env', 'create_train_env' and 'create_agent' functions are defined in the main file so that they can be pickled when the batcher processes are created | |
def create_gym_env(env_name):
    """Build a single gym environment.

    Args:
        env_name: Environment id passed to ``gym.make``.

    Returns:
        The environment object returned by ``gym.make``.
    """
    env = gym.make(env_name)
    return env
def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None):
    """Create ``n_envs`` time-limited environments wrapped in a ``GymEnv``.

    Args:
        n_envs: Number of environments to create.
        env_name: Environment id forwarded to ``create_gym_env``.
        max_episode_steps: Step limit given to each ``TimeLimit`` wrapper.
        seed: Passed unchanged as the second argument of ``GymEnv``.

    Returns:
        A ``GymEnv`` built from the list of ``TimeLimit``-wrapped environments.
    """
    wrapped_envs = [
        TimeLimit(create_gym_env(env_name), max_episode_steps=max_episode_steps)
        for _ in range(n_envs)
    ]
    return GymEnv(wrapped_envs, seed)
def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None):
    """Create ``n_envs`` time-limited environments wrapped in a ``GymEnvInf``.

    Args:
        n_envs: Number of environments to create.
        env_name: Environment id forwarded to ``create_gym_env``.
        max_episode_steps: Step limit given to each ``TimeLimit`` wrapper.
        seed: Passed unchanged as the second argument of ``GymEnvInf``.

    Returns:
        A ``GymEnvInf`` built from the list of ``TimeLimit``-wrapped
        environments.
    """
    wrapped_envs = [
        TimeLimit(create_gym_env(env_name), max_episode_steps=max_episode_steps)
        for _ in range(n_envs)
    ]
    return GymEnvInf(wrapped_envs, seed)
def create_agent(model, n_actions=1):
    """Build a ``RecurrentAgent`` around ``model``.

    Args:
        model: Forwarded as the ``model`` keyword of ``RecurrentAgent``.
        n_actions: Forwarded as the ``n_actions`` keyword (default 1).

    Returns:
        The newly constructed ``RecurrentAgent``.
    """
    agent = RecurrentAgent(model=model, n_actions=n_actions)
    return agent
class Experiment(A2C):
    """A2C experiment wired to the factory functions of this script.

    The class adds no behavior of its own; construction is delegated entirely
    to ``A2C``.
    """

    def __init__(self, config, create_env, create_train_env, create_agent):
        """Pass every argument through, unchanged, to the ``A2C`` initializer.

        Args:
            config: Experiment configuration (a dict in this script).
            create_env: Factory handed to ``A2C`` as its second argument.
            create_train_env: Factory handed to ``A2C`` as its third argument.
            create_agent: Factory handed to ``A2C`` as its fourth argument.
        """
        super(Experiment, self).__init__(
            config, create_env, create_train_env, create_agent
        )
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Use the "spawn" start method so that the environments can run in
    # separate processes.
    mp.set_start_method("spawn")

    # Settings consumed by the A2C experiment; keys and values are passed
    # through as-is.
    config = {
        "env_name": "CartPole-v0",
        "a2c_timesteps": 3,
        "n_envs": 4,
        "max_episode_steps": 100,
        "env_seed": 42,
        "n_threads": 4,
        "n_evaluation_threads": 2,
        "n_evaluation_episodes": 256,
        "time_limit": 3600,
        "lr": 0.001,
        "discount_factor": 0.95,
        "critic_coef": 1.0,
        "entropy_coef": 0.01,
        "a2c_coef": 1.0,
        "logdir": "./results",
    }

    experiment = Experiment(config, create_env, create_train_env, create_agent)
    experiment.run()