# Spaces:
# Sleeping
# Sleeping
from functools import partial as bind

import embodied
import numpy as np
class TestDriver:
    """Checks of the transitions that ``embodied.Driver`` reports via ``on_step``.

    Each test drives a single ``dummy.Dummy('disc', ...)`` environment and
    records the transition dict handed to the ``on_step`` callback. In the
    ordinary cases the assertions expect ``length + 1`` recorded transitions
    per episode (11 for the default length of 10, 6 for a length of 5).
    """

    def test_episode_length(self):
        """One episode at the default length records 11 transitions."""
        agent = self._make_agent()
        driver, seq = self._make_recording_driver(agent)
        driver(agent.policy, episodes=1)
        assert len(seq) == 11

    def test_first_step(self):
        """``is_first`` is set only on the opening transition of each episode."""
        agent = self._make_agent()
        driver, seq = self._make_recording_driver(agent)
        driver(agent.policy, episodes=2)
        # Indices 0 and 11 open the first and second episode respectively.
        for index in [0, 11]:
            assert seq[index]['is_first'].item() is True
            assert seq[index]['is_last'].item() is False
        for index in [1, 10, 12]:
            assert seq[index]['is_first'].item() is False

    def test_last_step(self):
        """``is_last`` is set only on the closing transition of each episode."""
        agent = self._make_agent()
        driver, seq = self._make_recording_driver(agent)
        driver(agent.policy, episodes=2)
        # Indices 10 and 21 close the first and second episode respectively.
        for index in [10, 21]:
            assert seq[index]['is_last'].item() is True
            assert seq[index]['is_first'].item() is False
        for index in [0, 1, 9, 11, 20]:
            assert seq[index]['is_last'].item() is False

    def test_env_reset(self):
        """Two length-5 episodes yield the expected flag and action sequences.

        The policy always emits ``act_disc == 1``, yet the transitions recorded
        at the end of each episode (indices 5 and 11) are expected to carry
        ``reset == 1`` and ``act_disc == 0``.
        """
        agent = self._make_agent()
        driver, seq = self._make_recording_driver(
            agent, bind(self._make_env, length=5))
        action = {
            'act_disc': np.ones(1, int),
            'act_cont': np.zeros((1, 6), float),
        }

        def policy(carry, obs):
            # Constant policy: pass the carry through and repeat one action.
            return carry, action, {}

        driver(policy, episodes=2)
        assert len(seq) == 12
        # Stack the per-step values so each key can be compared as one array.
        stacked = {k: np.array([tran[k] for tran in seq]) for k in seq[0]}
        assert (stacked['is_first'] == [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]).all()
        assert (stacked['is_last'] == [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]).all()
        assert (stacked['reset'] == [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]).all()
        assert (stacked['act_disc'] == [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]).all()

    def test_agent_inputs(self):
        """The policy receives the expected carry and observation flags.

        The first call is expected to see the carry ``()``; every later call
        sees whatever carry the policy returned on the previous call (here the
        string ``'carry'``).
        """
        agent = self._make_agent()
        driver, seq = self._make_recording_driver(agent)
        inputs = []
        states = []

        def policy(carry, obs, mode='train'):
            # Record what the driver passed in, then delegate action selection
            # to the random agent while returning a recognisable carry value.
            inputs.append(obs)
            states.append(carry)
            _, act, _ = agent.policy(carry, obs, mode)
            return 'carry', act, {}

        driver(policy, episodes=2)
        assert len(seq) == 22
        assert states == ([()] + ['carry'] * 21)
        for index in [0, 11]:
            assert inputs[index]['is_first'].item() is True
        for index in [1, 10, 12, 21]:
            assert inputs[index]['is_first'].item() is False
        for index in [10, 21]:
            assert inputs[index]['is_last'].item() is True
        for index in [0, 1, 9, 11, 20]:
            assert inputs[index]['is_last'].item() is False

    def test_unexpected_reset(self):
        """An ``is_first`` that arrives mid-episode does not end the episode.

        The wrapper below forces ``action['reset']`` to ones on its fourth
        ``step`` call. The test expects 8 recorded transitions with ``is_first``
        at indices 0 and 3, and ``is_last`` / ``reset`` only at index 7.
        """

        class UnexpectedReset(embodied.Wrapper):
            """Send is_first without preceding is_last."""

            def __init__(self, env, when):
                super().__init__(env)
                self._when = when
                self._step = 0

            def step(self, action):
                if self._step == self._when:
                    # Copy first so the caller's action dict is not mutated.
                    action = action.copy()
                    action['reset'] = np.ones_like(action['reset'])
                self._step += 1
                return self.env.step(action)

        env = self._make_env(length=4)
        env = UnexpectedReset(env, when=3)
        agent = self._make_agent()
        # The constructor hands back the already-wrapped instance built above.
        driver, steps = self._make_recording_driver(agent, lambda: env)
        driver(agent.policy, episodes=1)
        assert len(steps) == 8
        stacked = {k: np.array([x[k] for x in steps]) for k in steps[0]}
        assert (stacked['reset'] == [0, 0, 0, 0, 0, 0, 0, 1]).all()
        assert (stacked['is_first'] == [1, 0, 0, 1, 0, 0, 0, 0]).all()
        assert (stacked['is_last'] == [0, 0, 0, 0, 0, 0, 0, 1]).all()

    def _make_recording_driver(self, agent, make_env=None):
        """Build a single-env driver that records every ``on_step`` transition.

        Args:
            agent: Object whose ``init_policy`` is passed to ``driver.reset``.
            make_env: Zero-argument environment constructor. Defaults to
                ``self._make_env``.

        Returns:
            A ``(driver, transitions)`` pair, where ``transitions`` is the list
            that the registered ``on_step`` callback appends to.
        """
        if make_env is None:
            make_env = self._make_env
        driver = embodied.Driver([make_env])
        driver.reset(agent.init_policy)
        transitions = []
        driver.on_step(lambda tran, _: transitions.append(tran))
        return driver, transitions

    def _make_env(self, length=10):
        """Return a ``dummy.Dummy('disc')`` environment of the given length."""
        # Imported lazily, as in the original, so the dummy env module is only
        # loaded when an environment is actually built.
        from embodied.envs import dummy
        return dummy.Dummy('disc', length=length)

    def _make_agent(self):
        """Return a ``RandomAgent`` built from a throwaway env's spaces."""
        env = self._make_env()
        agent = embodied.RandomAgent(env.obs_space, env.act_space)
        # The env is only needed for its observation and action spaces.
        env.close()
        return agent