codewraith / data /source_files /clean /3ec8fd0b4a93.py
slenk's picture
Upload folder using huggingface_hub
eeef81e verified
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from rlstructures import logging
from rlstructures.env_wrappers import GymEnv, GymEnvInf
from rlstructures.tools import weight_init
import torch.nn as nn
import copy
import torch
import time
import numpy as np
import torch.nn.functional as F
from tutorial.tutorial_recurrent_policy.agent import RecurrentAgent
from tutorial.tutorial_recurrent_policy.a2c import A2C
import gym
from gym.wrappers import TimeLimit
# We write the 'create_env' and 'create_agent' function in the main file to allow these functions to be used with pickle when creating the batcher processes
def create_gym_env(env_name):
return gym.make(env_name)
def create_env(n_envs, env_name=None, max_episode_steps=None, seed=None):
envs = []
for k in range(n_envs):
e = create_gym_env(env_name)
e = TimeLimit(e, max_episode_steps=max_episode_steps)
envs.append(e)
return GymEnv(envs, seed)
def create_train_env(n_envs, env_name=None, max_episode_steps=None, seed=None):
envs = []
for k in range(n_envs):
e = create_gym_env(env_name)
e = TimeLimit(e, max_episode_steps=max_episode_steps)
envs.append(e)
return GymEnvInf(envs, seed)
def create_agent(model, n_actions=1):
return RecurrentAgent(model=model, n_actions=n_actions)
class Experiment(A2C):
def __init__(self, config, create_env, create_train_env, create_agent):
super().__init__(config, create_env, create_train_env, create_agent)
if __name__ == "__main__":
# We use spawn mode such that most of the environment will run in multiple processes
import torch.multiprocessing as mp
mp.set_start_method("spawn")
config = {
"env_name": "CartPole-v0",
"a2c_timesteps": 3,
"n_envs": 4,
"max_episode_steps": 100,
"env_seed": 42,
"n_threads": 4,
"n_evaluation_threads": 2,
"n_evaluation_episodes": 256,
"time_limit": 3600,
"lr": 0.001,
"discount_factor": 0.95,
"critic_coef": 1.0,
"entropy_coef": 0.01,
"a2c_coef": 1.0,
"logdir": "./results",
}
exp = Experiment(config, create_env, create_train_env, create_agent)
exp.run()