| | import os |
| | import gym |
| | import tempfile |
| | import pytest |
| | import tensorflow as tf |
| | import numpy as np |
| |
|
| | from baselines.common.tests.envs.mnist_env import MnistEnv |
| | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv |
| | from baselines.run import get_learn_function |
| | from baselines.common.tf_util import make_session, get_session |
| |
|
| | from functools import partial |
| |
|
| |
|
# Extra keyword arguments for each algorithm's learn function, keyed by
# algorithm name. The tests below merge the matching entry into the call
# (after the network entry, so these win on a key clash) and also use the
# keys to parametrize over algorithms.
learn_kwargs = dict(
    deepq={},
    a2c={},
    acktr={},
    acer={},
    ppo2=dict(nminibatches=1, nsteps=10),
    trpo_mpi={},
)
| |
|
# Extra keyword arguments associated with each network type, keyed by
# network name. test_serialization parametrizes over these keys and merges
# the matching entry into the keyword arguments handed to the learn function.
network_kwargs = dict(
    mlp={},
    cnn=dict(pad='SAME'),
    lstm={},
    cnn_lnlstm=dict(pad='SAME'),
)
| |
|
| |
|
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
@pytest.mark.parametrize("network_fn", network_kwargs.keys())
def test_serialization(learn_fn, network_fn):
    '''
    Test if the trained model can be serialized.

    A model is trained for 100 timesteps and saved to a temporary directory,
    then loaded again (with zero further training) inside a fresh graph and
    session. The test passes when every trainable variable of the first model
    matches the loaded one (atol=0.01) and the mean/std of the actions taken
    on one fixed observation agree (atol=0.5).
    '''
    if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
        # Report the combination as skipped rather than returning early,
        # which pytest would count as a pass even though nothing was checked.
        pytest.skip('{} is not exercised with recurrent network {}'.format(learn_fn, network_fn))

    def make_env():
        env = MnistEnv(episode_len=100)
        env.seed(10)
        return env

    env = DummyVecEnv([make_env])
    try:
        # Copy the first observation so both models are probed with the same
        # input regardless of what the environment does to its buffers later.
        ob = env.reset().copy()
        learn = get_learn_function(learn_fn)

        kwargs = {}
        kwargs.update(network_kwargs[network_fn])
        kwargs.update(learn_kwargs[learn_fn])

        learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)

        with tempfile.TemporaryDirectory() as td:
            model_path = os.path.join(td, 'serialization_test_model')

            # First graph/session: train, save, and record reference values.
            with tf.Graph().as_default(), make_session().as_default():
                model = learn(total_timesteps=100)
                model.save(model_path)
                mean1, std1 = _get_action_stats(model, ob)
                variables_dict1 = _serialize_variables()

            # Second, independent graph/session: load without training.
            with tf.Graph().as_default(), make_session().as_default():
                model = learn(total_timesteps=0, load_path=model_path)
                mean2, std2 = _get_action_stats(model, ob)
                variables_dict2 = _serialize_variables()

            for k, v in variables_dict1.items():
                np.testing.assert_allclose(
                    v, variables_dict2[k], atol=0.01,
                    err_msg='saved and loaded variable {} value mismatch'.format(k))

            np.testing.assert_allclose(mean1, mean2, atol=0.5)
            np.testing.assert_allclose(std1, std2, atol=0.5)
    finally:
        # Release the environment even when an assertion above fails.
        env.close()
| |
|
| |
|
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
@pytest.mark.parametrize("network_fn", ['mlp'])
def test_coexistence(learn_fn, network_fn):
    '''
    Test if more than one model can exist at a time.

    Two models are built with different seeds, each after installing a new
    default session on its own graph, and both are then asked to step on a
    sampled observation.
    '''
    if learn_fn == 'deepq':
        # Report as skipped rather than returning early, which pytest would
        # count as a pass even though nothing was checked.
        pytest.skip('deepq is not exercised by the coexistence test')

    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
        pytest.skip('{} is not exercised with recurrent network {}'.format(learn_fn, network_fn))

    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
    try:
        learn = get_learn_function(learn_fn)

        kwargs = {}
        kwargs.update(network_kwargs[network_fn])
        kwargs.update(learn_kwargs[learn_fn])

        # total_timesteps=0: the models are only constructed, not trained.
        learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
        make_session(make_default=True, graph=tf.Graph())
        model1 = learn(seed=1)
        make_session(make_default=True, graph=tf.Graph())
        model2 = learn(seed=2)

        # model1 is stepped after model2's session became the default one.
        model1.step(env.observation_space.sample())
        model2.step(env.observation_space.sample())
    finally:
        # Release the environment even when building or stepping a model fails.
        env.close()
| |
|
| |
|
| |
|
def _serialize_variables():
    '''
    Snapshot the current values of all trainable variables.

    Evaluates every variable returned by tf.compat.v1.trainable_variables()
    in the session returned by get_session() and returns a dict mapping each
    variable's name to its evaluated value.
    '''
    session = get_session()
    trainable = tf.compat.v1.trainable_variables()
    evaluated = session.run(trainable)
    names = (variable.name for variable in trainable)
    return dict(zip(names, evaluated))
| |
|
| |
|
| | def _get_action_stats(model, ob): |
| | ntrials = 1000 |
| | if model.initial_state is None or model.initial_state == []: |
| | actions = np.array([model.step(ob)[0] for _ in range(ntrials)]) |
| | else: |
| | actions = np.array([model.step(ob, S=model.initial_state, M=[False])[0] for _ in range(ntrials)]) |
| |
|
| | mean = np.mean(actions, axis=0) |
| | std = np.std(actions, axis=0) |
| |
|
| | return mean, std |
| |
|
| |
|