File size: 860 Bytes
5960497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import pytest
import gym

from baselines.run import get_learn_function
from baselines.common.tests.util import reward_per_episode_test
from baselines.common.tests import mark_slow

# Skip every test in this module when mujoco_py cannot be imported.
# NOTE(review): presumably the FetchReach env used below needs MuJoCo -- confirm.
pytest.importorskip('mujoco_py')

# Keyword arguments handed to the learn function for every algorithm.
common_kwargs = {'network': 'mlp', 'seed': 0}

# Per-algorithm keyword arguments, keyed by algorithm name. The keys also
# drive the parametrization of test_fetchreach.
learn_kwargs = {'her': {'total_timesteps': 2000}}

@mark_slow
@pytest.mark.parametrize("alg", list(learn_kwargs))
def test_fetchreach(alg):
    '''
    Run the learn function for ``alg`` (with an mlp network) on the
    FetchReach task and hand the result to reward_per_episode_test.

    alg: key into learn_kwargs selecting the algorithm to exercise.
    '''
    # Shared settings first, then the algorithm-specific ones on top.
    # Built eagerly so an unknown ``alg`` fails here, before any env is made.
    merged_kwargs = {**common_kwargs, **learn_kwargs[alg]}

    def make_env():
        # Build the FetchReach environment and seed it with a fixed value.
        fetch_env = gym.make('FetchReach-v1')
        fetch_env.seed(0)
        return fetch_env

    def run_learn(e):
        # Resolve the learn function by name and call it on the given env.
        return get_learn_function(alg)(env=e, **merged_kwargs)

    # NOTE(review): -15 is presumably the minimum acceptable per-episode
    # reward -- confirm against reward_per_episode_test.
    reward_per_episode_test(make_env, run_learn, -15)

# Direct-run entry point: execute the 'her' case without going through pytest.
if __name__ == '__main__':
    test_fetchreach('her')