# File size: 1,172 Bytes
# 5960497
import gym
import tensorflow as tf
import numpy as np
from functools import partial
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.tf_util import make_session
from baselines.ppo2.ppo2 import learn
from baselines.ppo2.microbatched_model import MicrobatchedModel
def _train_and_snapshot(env_fn, learn_fn, **learn_kwargs):
    """Run one training job in a fresh graph and return its trained weights.

    Args:
        env_fn: Zero-argument callable that builds a single environment; it is
            wrapped in a one-element ``DummyVecEnv``.
        learn_fn: Training entry point, called as
            ``learn_fn(env=..., **learn_kwargs)``.
        **learn_kwargs: Extra keyword arguments forwarded to ``learn_fn``.

    Returns:
        A dict mapping each trainable variable's name to the value obtained
        by evaluating it in the session created for this run.
    """
    env = DummyVecEnv([env_fn])
    try:
        # A brand-new graph per run keeps the two trainings from sharing
        # variables. The session is made default so that the training code
        # and trainable_variables() below are expected to see this graph.
        sess = make_session(make_default=True, graph=tf.Graph())
        try:
            learn_fn(env=env, **learn_kwargs)
            return {
                v.name: sess.run(v)
                for v in tf.compat.v1.trainable_variables()
            }
        finally:
            # Release the session before the next run creates its own, so two
            # default sessions are never alive at the same time.
            sess.close()
    finally:
        env.close()


def test_microbatches():
    """Check that microbatched PPO2 training matches regular PPO2 training.

    Trains the same seeded MLP policy on CartPole twice -- once with the
    default model and once with ``MicrobatchedModel`` (microbatch size 2) --
    and asserts that every trainable variable of the reference run is
    reproduced by the microbatched run to within ``atol=3e-3``.
    """
    def env_fn():
        # Seed the environment so both runs see the same episode.
        env = gym.make('CartPole-v0')
        env.seed(0)
        return env

    learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0)

    vars_ref = _train_and_snapshot(env_fn, learn_fn)
    vars_test = _train_and_snapshot(
        env_fn,
        learn_fn,
        model_fn=partial(MicrobatchedModel, microbatch_size=2),
    )

    # Without this guard an empty snapshot would make the loop below a no-op
    # and the test would pass without comparing anything.
    assert vars_ref, 'reference run produced no trainable variables'

    # Report absent variables by name instead of failing with a bare KeyError.
    missing = sorted(set(vars_ref) - set(vars_test))
    assert not missing, 'variables missing from microbatched run: %s' % missing

    for name in vars_ref:
        np.testing.assert_allclose(
            vars_ref[name], vars_test[name], atol=3e-3, err_msg=name)
# Script entry point: lets the test be run directly with the interpreter
# instead of through a test runner.
if __name__ == '__main__':
    test_microbatches()
|