| | import pytest |
| | import gym |
| |
|
| | from baselines.run import get_learn_function |
| | from baselines.common.tests.util import reward_per_episode_test |
| | from baselines.common.tests import mark_slow |
| |
|
# Skip every test in this module when the ``mujoco_py`` bindings cannot be
# imported (presumably because the FetchReach-v1 environment used by the
# test depends on them).
pytest.importorskip("mujoco_py")
| |
|
# Keyword arguments passed to the learn function of every algorithm tested.
common_kwargs = {
    'network': 'mlp',
    'seed': 0,
}

# Per-algorithm keyword arguments, layered over ``common_kwargs`` inside the
# test. The keys of this dict are the algorithm names the test runs for.
learn_kwargs = {
    'her': {'total_timesteps': 2000},
}
| |
|
@mark_slow
@pytest.mark.parametrize("alg", list(learn_kwargs))
def test_fetchreach(alg):
    """Test whether ``alg`` (with an mlp policy) can learn the FetchReach task.

    The keyword arguments for training are ``common_kwargs`` overlaid with
    the ``alg``-specific entry of ``learn_kwargs``. A seeded environment
    factory, the training function and a reward threshold of -15 are handed
    to ``reward_per_episode_test``. Presumably that helper asserts the
    learned policy reaches that per-episode reward; see the helper for the
    exact criterion.

    Args:
        alg: Algorithm name; must be a key of ``learn_kwargs``.
    """
    kwargs = common_kwargs.copy()
    kwargs.update(learn_kwargs[alg])

    # A named function rather than a lambda bound to a name (PEP 8 / E731).
    def learn_fn(env):
        """Resolve the learn entry point for ``alg`` and train it on ``env``."""
        return get_learn_function(alg)(env=env, **kwargs)

    def env_fn():
        """Build a fresh FetchReach-v1 environment with a fixed seed of 0."""
        env = gym.make('FetchReach-v1')
        env.seed(0)
        return env

    reward_per_episode_test(env_fn, learn_fn, -15)
| |
|
if __name__ == "__main__":
    # Allow a direct `python <file>` run of the HER case, outside pytest.
    test_fetchreach("her")
| |
|