File size: 1,371 Bytes
5960497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pytest
# Optional-dependency guard: probe for mujoco_py once at import time and
# record the outcome in `_mujoco_present`, which the skipif marker on the
# test below reads to decide whether the test can run.
try:
    import mujoco_py
    _mujoco_present = True
except BaseException:
    # NOTE(review): BaseException is much broader than ImportError. It is
    # presumably deliberate, so that any failure while loading mujoco
    # (missing library, missing key, bad LD_LIBRARY_PATH - see the skip
    # reason below) results in a skip rather than a collection error.
    # Confirm before narrowing.
    mujoco_py = None
    _mujoco_present = False


@pytest.mark.skipif(
    not _mujoco_present,
    reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
    """Roll out one Reacher-v2 episode with an untrained LSTM policy.

    Builds a single-environment ``DummyVecEnv``, constructs a policy on
    top of a 128-unit LSTM network (``nbatch=1, nsteps=1``), initialises
    the TensorFlow variables and then steps the environment with the
    policy's actions until the environment reports ``done``.

    The test passes when the episode lasted more than 5 steps, i.e. the
    recurrent policy can be stepped repeatedly while its state is threaded
    from one call to the next.

    The vectorized environment is always closed on exit, including when
    the rollout or the final assertion fails, so the underlying simulator
    resources are not leaked into other tests.
    """
    # Heavy / optional dependencies are imported lazily so that merely
    # collecting this module does not require them.
    import tensorflow as tf
    from baselines.common import policies, models, cmd_util
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    # create vectorized environment (a single Reacher-v2 instance, fixed seed)
    venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])

    try:
        with tf.compat.v1.Session() as sess:
            # build policy based on lstm network with 128 units
            policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)

            # initialize tensorflow variables
            sess.run(tf.compat.v1.global_variables_initializer())

            # prepare environment variables
            ob = venv.reset()
            state = policy.initial_state
            # One flag per environment; starts False so the loop body runs
            # at least once, exactly like the original `while True` form.
            done = [False]
            step_counter = 0

            # run a single episode until the end (i.e. until done)
            while not done[0]:
                # `S` carries the recurrent state between calls; `M` is
                # presumably the per-env done mask used to reset that
                # state at episode boundaries - confirm against
                # policy.step.
                action, _, state, _ = policy.step(ob, S=state, M=done)
                ob, _, done, _ = venv.step(action)
                step_counter += 1

            assert step_counter > 5
    finally:
        # Release the wrapped environment(s) regardless of the outcome.
        venv.close()