leonepson's picture
Upload 254 files
5960497 verified
"""
Tests for asynchronous vectorized environments.
"""
import gym
import numpy as np
import pytest
from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv
from baselines.common.tests.test_with_mpi import with_mpi
def assert_venvs_equal(venv1, venv2, num_steps):
"""
Compare two environments over num_steps steps and make sure
that the observations produced by each are the same when given
the same actions.
"""
assert venv1.num_envs == venv2.num_envs
assert venv1.observation_space.shape == venv2.observation_space.shape
assert venv1.observation_space.dtype == venv2.observation_space.dtype
assert venv1.action_space.shape == venv2.action_space.shape
assert venv1.action_space.dtype == venv2.action_space.dtype
try:
obs1, obs2 = venv1.reset(), venv2.reset()
assert np.array(obs1).shape == np.array(obs2).shape
assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
assert np.allclose(obs1, obs2)
venv1.action_space.seed(1337)
for _ in range(num_steps):
actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
for venv in [venv1, venv2]:
venv.step_async(actions)
outs1 = venv1.step_wait()
outs2 = venv2.step_wait()
for out1, out2 in zip(outs1[:3], outs2[:3]):
assert np.array(out1).shape == np.array(out2).shape
assert np.allclose(out1, out2)
assert list(outs1[3]) == list(outs2[3])
finally:
venv1.close()
venv2.close()
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
def test_vec_env(klass, dtype): # pylint: disable=R0914
"""
Test that a vectorized environment is equivalent to
DummyVecEnv, since DummyVecEnv is less likely to be
error prone.
"""
num_envs = 3
num_steps = 100
shape = (3, 8)
def make_fn(seed):
"""
Get an environment constructor with a seed.
"""
return lambda: SimpleEnv(seed, shape, dtype)
fns = [make_fn(i) for i in range(num_envs)]
env1 = DummyVecEnv(fns)
env2 = klass(fns)
assert_venvs_equal(env1, env2, num_steps=num_steps)
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
def test_sync_sampling(dtype, num_envs_in_series):
"""
Test that a SubprocVecEnv running with envs in series
outputs the same as DummyVecEnv.
"""
num_envs = 12
num_steps = 100
shape = (3, 8)
def make_fn(seed):
"""
Get an environment constructor with a seed.
"""
return lambda: SimpleEnv(seed, shape, dtype)
fns = [make_fn(i) for i in range(num_envs)]
env1 = DummyVecEnv(fns)
env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
assert_venvs_equal(env1, env2, num_steps=num_steps)
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
def test_sync_sampling_sanity(dtype, num_envs_in_series):
"""
Test that a SubprocVecEnv running with envs in series
outputs the same as SubprocVecEnv without running in series.
"""
num_envs = 12
num_steps = 100
shape = (3, 8)
def make_fn(seed):
"""
Get an environment constructor with a seed.
"""
return lambda: SimpleEnv(seed, shape, dtype)
fns = [make_fn(i) for i in range(num_envs)]
env1 = SubprocVecEnv(fns)
env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
assert_venvs_equal(env1, env2, num_steps=num_steps)
class SimpleEnv(gym.Env):
"""
An environment with a pre-determined observation space
and RNG seed.
"""
def __init__(self, seed, shape, dtype):
np.random.seed(seed)
self._dtype = dtype
self._start_obs = np.array(np.random.randint(0, 0x100, size=shape),
dtype=dtype)
self._max_steps = seed + 1
self._cur_obs = None
self._cur_step = 0
# this is 0xFF instead of 0x100 because the Box space includes
# the high end, while randint does not
self.action_space = gym.spaces.Box(low=0, high=0xFF, shape=shape, dtype=dtype)
self.observation_space = self.action_space
def step(self, action):
self._cur_obs += np.array(action, dtype=self._dtype)
self._cur_step += 1
done = self._cur_step >= self._max_steps
reward = self._cur_step / self._max_steps
return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)}
def reset(self):
self._cur_obs = self._start_obs
self._cur_step = 0
return self._cur_obs
def render(self, mode=None):
raise NotImplementedError
@with_mpi()
def test_mpi_with_subprocvecenv():
shape = (2,3,4)
nenv = 1
venv = SubprocVecEnv([lambda: SimpleEnv(0, shape, 'float32')] * nenv)
ob = venv.reset()
venv.close()
assert ob.shape == (nenv,) + shape