import pytest

# mujoco_py is optional. Importing it can fail for several reasons (e.g. the
# package is missing, the MuJoCo key is not present, or LD_LIBRARY_PATH does not
# point at the MuJoCo library). Record whether it is usable so that
# MuJoCo-dependent tests can be skipped instead of erroring at collection time.
try:
    import mujoco_py
    _mujoco_present = True
except Exception:
    # Catch Exception rather than BaseException so that KeyboardInterrupt and
    # SystemExit raised while importing still propagate to the test runner.
    mujoco_py = None
    _mujoco_present = False
| |
|
| |
|
@pytest.mark.skipif(
    not _mujoco_present,
    reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
    """Run a single Reacher-v2 episode with an LSTM-based policy.

    Builds a policy from ``models.lstm(128)`` via ``policies.build_policy`` on a
    one-environment ``DummyVecEnv``, steps it until the environment reports
    ``done``, and asserts the episode lasted more than 5 steps. The test fails
    (rather than hanging) if no episode end is seen within ``max_steps`` steps.
    """
    import tensorflow as tf
    from baselines.common import policies, models, cmd_util
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    # Safety cap on the episode length: a misbehaving environment should make
    # the test fail, not hang the whole test run forever.
    max_steps = 10000

    # Vectorized wrapper around a single seeded Reacher-v2 environment.
    venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])
    try:
        with tf.compat.v1.Session() as sess:
            # LSTM network with 128 units; nbatch=1 / nsteps=1 because the
            # policy is stepped one observation at a time below.
            policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)

            sess.run(tf.compat.v1.global_variables_initializer())

            ob = venv.reset()
            # The recurrent state returned by each step is fed back into the
            # next one. The previous `done` flags are passed as the mask M —
            # presumably so the LSTM can reset its state at episode boundaries;
            # confirm against baselines.common.models.lstm.
            state = policy.initial_state
            done = [False]
            step_counter = 0

            while step_counter < max_steps:
                action, _, state, _ = policy.step(ob, S=state, M=done)
                ob, _, done, _ = venv.step(action)
                step_counter += 1
                if done:
                    break
            else:
                # Loop exhausted without `break`: the episode never ended.
                pytest.fail('episode did not terminate within {} steps'.format(max_steps))
    finally:
        # Release the environment even if policy construction or stepping fails.
        venv.close()

    assert step_counter > 5
| |
|
| |
|
| |
|
| |
|
| |
|