import pytest
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
from baselines.run import get_learn_function
from baselines.common.tests.util import simple_test
from baselines.common.tests import mark_slow

# Keyword arguments passed to every algorithm's learn function.
common_kwargs = dict(
    total_timesteps=30000,
    network='mlp',
    gamma=0.9,
    seed=0,
)

# Per-algorithm keyword arguments. common_kwargs is applied on top of these,
# so on a key conflict the value from common_kwargs wins.
learn_kwargs = {
    'a2c' : {},
    'acktr': {},
    'deepq': {},
    'ddpg': dict(layer_norm=True),
    'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
    'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
}


# Algorithms exercised against each kind of identity environment.
algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']


def _kwargs_for(alg):
    '''
    Return a fresh dict of learn-function keyword arguments for `alg`.

    The per-algorithm entry of learn_kwargs is copied before common_kwargs is
    merged in, so the module-level dictionaries are never mutated by a test.
    Raises KeyError if `alg` has no entry in learn_kwargs.
    '''
    merged = dict(learn_kwargs[alg])
    merged.update(common_kwargs)
    return merged


@mark_slow
@pytest.mark.parametrize("alg", algos_disc)
def test_discrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    '''

    kwargs = _kwargs_for(alg)

    # `alg` and `kwargs` are bound once per test call, so the closures are safe.
    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
    # NOTE(review): the third argument is presumably the pass threshold used by
    # simple_test -- confirm against baselines.common.tests.util.
    simple_test(env_fn, learn_fn, 0.9)


@mark_slow
@pytest.mark.parametrize("alg", algos_multidisc)
def test_multidiscrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    '''

    kwargs = _kwargs_for(alg)

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
    simple_test(env_fn, learn_fn, 0.9)


@mark_slow
@pytest.mark.parametrize("alg", algos_cont)
def test_continuous_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    to a required precision
    '''

    kwargs = _kwargs_for(alg)

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: BoxIdentityEnv((1,), episode_len=100)
    # NOTE(review): the threshold is negative here, unlike the discrete tests;
    # presumably the Box env's reward is a non-positive error term -- confirm.
    simple_test(env_fn, learn_fn, -0.1)


if __name__ == '__main__':
    # Manual entry point: run a single case without going through pytest.
    test_multidiscrete_identity('acktr')