| | import numpy as np |
| | from baselines.common.runners import AbstractEnvRunner |
| | from baselines.common.vec_env.vec_frame_stack import VecFrameStack |
| | from gym import spaces |
| |
|
| |
|
class Runner(AbstractEnvRunner):
    """Collect fixed-length rollouts from a frame-stacked vectorized env.

    Each call to :meth:`run` steps the environment ``nsteps`` times using
    ``model._step`` and returns the gathered observations, actions,
    rewards, behaviour-policy outputs (``mus``) and done flags.

    NOTE(review): ``self.nenv``, ``self.states`` and ``self.dones`` are not
    assigned in this class before they are read; they are presumably
    initialised by ``AbstractEnvRunner.__init__`` -- confirm against the
    base class.
    """

    def __init__(self, env, model, nsteps):
        """Set up rollout bookkeeping for ``env`` and ``model``.

        Args:
            env: Must be a ``VecFrameStack`` whose action space is
                ``spaces.Discrete`` (both enforced with ``assert``).
            model: Object exposing ``_step(obs, S=..., M=...)``.
            nsteps: Number of environment steps gathered per ``run()`` call.

        Raises:
            AssertionError: If the action space is not discrete or ``env``
                is not a ``VecFrameStack``.
        """
        super().__init__(env=env, model=model, nsteps=nsteps)
        assert isinstance(env.action_space, spaces.Discrete), (
            'This ACER implementation works only with discrete action spaces!'
        )
        assert isinstance(env, VecFrameStack)

        self.nact = env.action_space.n
        nenv = self.nenv
        self.nbatch = nenv * nsteps
        # nsteps + 1 because run() stores one extra observation after the
        # final step.
        self.batch_ob_shape = (nenv * (nsteps + 1),) + env.observation_space.shape

        self.obs = env.reset()
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        self.nstack = self.env.nstack
        # Size of one slice of the last observation axis once it is divided
        # evenly among the nstack stacked entries.
        self.nc = self.batch_ob_shape[-1] // self.nstack

    def run(self):
        """Step the environment ``nsteps`` times and return the rollout.

        Returns:
            A 7-tuple ``(enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus,
            mb_dones, mb_masks)`` of NumPy arrays, each with its first two
            axes swapped relative to collection order (presumably giving an
            env-major ``[nenv, step, ...]`` layout -- confirm with the
            consumer).

            * ``enc_obs``: the ``nstack`` slices of the env's current
              ``stackedobs`` followed by the last ``nc`` entries (last
              axis) of every new observation.
            * ``mb_obs``: ``nsteps + 1`` observations (the one before each
              step plus the final one).
            * ``mb_actions``, ``mb_rewards``, ``mb_mus``: one entry per step.
            * ``mb_masks``: ``nsteps + 1`` done flags (the flag held before
              each step plus the final one).
            * ``mb_dones``: ``mb_masks`` without its first column, i.e. the
              done flag produced by each step.
        """
        # Seed the encoded-observation list with the frames already held in
        # the env's stack, one list entry per stacked slice.
        enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
        mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
        for _ in range(self.nsteps):
            actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
            # Copy: the env may hand back a buffer it later mutates.
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_mus.append(mus)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            # Carry the model state and done flags over to the next _step call.
            self.states = states
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
            # Keep only the trailing nc entries of the stacked observation.
            enc_obs.append(obs[..., -self.nc:])
        # One trailing observation / done flag, so these two lists hold
        # nsteps + 1 entries (matches batch_ob_shape).
        mb_obs.append(np.copy(self.obs))
        mb_dones.append(self.dones)

        enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

        # Builtin ``bool`` rather than the ``np.bool`` alias: the alias was
        # deprecated in NumPy 1.20 and removed in 1.24, where it raises
        # AttributeError. The resulting dtype is identical.
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)

        # Full nsteps + 1 flags, including the one held before the first step.
        mb_masks = mb_dones
        # Drop the leading flag so there is exactly one flag per reward.
        mb_dones = mb_dones[:, 1:]

        return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks
| |
|
| |
|