semantic_rl/baselines/ppo2/test_microbatches.py
leonepson's picture
Upload 254 files
5960497 verified
import gym
import tensorflow as tf
import numpy as np
from functools import partial
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.tf_util import make_session
from baselines.ppo2.ppo2 import learn
from baselines.ppo2.microbatched_model import MicrobatchedModel
def test_microbatches():
    """Check that microbatched PPO2 training matches the reference model.

    PPO2 is trained twice on identically seeded ``CartPole-v0`` environments:
    once with the default model and once with ``MicrobatchedModel``
    (``microbatch_size=2``). Each run gets its own fresh graph and session.
    The trainable variables produced by the two runs must agree within an
    absolute tolerance of ``3e-3``.
    """
    def env_fn():
        # Seed the environment so both runs see the same rollouts.
        env = gym.make('CartPole-v0')
        env.seed(0)
        return env

    learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0)

    def train_and_snapshot(**learn_kwargs):
        """Train in an isolated graph/session and return {var name: value}."""
        env = DummyVecEnv([env_fn])
        try:
            # NOTE(review): make_default=True is expected to install the
            # session and its graph as the defaults, which is what
            # tf.compat.v1.trainable_variables() reads below -- confirm
            # against make_session.
            sess = make_session(make_default=True, graph=tf.Graph())
            try:
                learn_fn(env=env, **learn_kwargs)
                # Materialise the values as arrays before the session closes.
                return {
                    v.name: sess.run(v)
                    for v in tf.compat.v1.trainable_variables()
                }
            finally:
                # Release the session so the next run does not stack a second
                # default session on top of this one.
                sess.close()
        finally:
            env.close()

    vars_ref = train_and_snapshot()
    vars_test = train_and_snapshot(
        model_fn=partial(MicrobatchedModel, microbatch_size=2)
    )

    # Guard against a vacuous pass: with nothing to compare, the loop below
    # would succeed without checking anything.
    assert vars_ref, 'reference run produced no trainable variables'

    for name, ref_value in vars_ref.items():
        assert name in vars_test, (
            'variable %r is missing from the microbatched run' % name
        )
        np.testing.assert_allclose(
            ref_value,
            vars_test[name],
            atol=3e-3,
            err_msg='mismatch in variable %r' % name,
        )
# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_microbatches()