| import pytest |
|
|
| |
| from baselines.common.tests.envs.mnist_env import MnistEnv |
| from baselines.common.tests.util import simple_test |
| from baselines.run import get_learn_function |
| from baselines.common.tests import mark_slow |
|
|
| |
| |
# Keyword arguments passed to every algorithm's learn() call. test_mnist
# merges these on top of the per-algorithm entry in learn_args, so on a key
# collision the value here wins.
# NOTE(review): 'pad' is presumably the padding mode for the 'cnn' network's
# conv layers -- confirm against the network builder.
common_kwargs = dict(
    seed=0,
    network='cnn',
    gamma=0.9,
    pad='SAME',
)
|
|
# Per-algorithm keyword arguments for learn(), keyed by the algorithm name
# that get_learn_function() accepts. The key order here is the order in which
# the parametrized test_mnist cases are generated.
learn_args = dict(
    a2c=dict(total_timesteps=50000),
    acer=dict(total_timesteps=20000),
    deepq=dict(total_timesteps=5000),
    acktr=dict(total_timesteps=30000),
    ppo2=dict(
        total_timesteps=50000,
        lr=1e-3,
        nsteps=128,
        ent_coef=0.0,
    ),
    trpo_mpi=dict(
        total_timesteps=80000,
        timesteps_per_batch=100,
        cg_iters=10,
        lam=1.0,
        max_kl=0.001,
    ),
)
|
|
|
|
| |
| |
@pytest.mark.skip
@mark_slow
@pytest.mark.parametrize("alg", learn_args.keys())
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses CNN policy.

    :param alg: algorithm name; one of the keys of ``learn_args``, also
        passed to ``get_learn_function`` to look up its ``learn`` entry point.
    '''
    # Build a fresh dict for this run. Calling .update() directly on
    # learn_args[alg] would permanently merge common_kwargs into the shared
    # module-level config. Precedence is unchanged: common_kwargs overrides
    # any per-algorithm key with the same name.
    learn_kwargs = dict(learn_args[alg])
    learn_kwargs.update(common_kwargs)

    learn = get_learn_function(alg)

    def learn_fn(e):
        # Train `alg` on the environment that simple_test hands us.
        return learn(env=e, **learn_kwargs)

    def env_fn():
        # Fresh MNIST classification env with 100-step episodes.
        return MnistEnv(episode_len=100)

    # NOTE(review): 0.6 is presumably the minimum reward (fraction) that
    # simple_test requires the trained model to reach -- confirm there.
    simple_test(env_fn, learn_fn, 0.6)
|
|
if __name__ == "__main__":
    # Manual entry point: run only the ACER case by calling the test
    # function directly, outside pytest collection.
    test_mnist(alg="acer")
|
|