import gym import tensorflow as tf import numpy as np from functools import partial from baselines.common.vec_env.dummy_vec_env import DummyVecEnv from baselines.common.tf_util import make_session from baselines.ppo2.ppo2 import learn from baselines.ppo2.microbatched_model import MicrobatchedModel def test_microbatches(): def env_fn(): env = gym.make('CartPole-v0') env.seed(0) return env learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0) env_ref = DummyVecEnv([env_fn]) sess_ref = make_session(make_default=True, graph=tf.Graph()) learn_fn(env=env_ref) vars_ref = {v.name: sess_ref.run(v) for v in tf.compat.v1.trainable_variables()} env_test = DummyVecEnv([env_fn]) sess_test = make_session(make_default=True, graph=tf.Graph()) learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2)) # learn_fn(env=env_test) vars_test = {v.name: sess_test.run(v) for v in tf.compat.v1.trainable_variables()} for v in vars_ref: np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=3e-3) if __name__ == '__main__': test_microbatches()