from collections import deque

import gym
import numpy as np


class DelayedEnv(gym.Wrapper):
    """Gym wrapper that delays observations/rewards/dones and/or actions.

    Feedback delay: the observation, reward and done flag produced by the
    wrapped environment are pushed through FIFO buffers of length
    ``obs_delayed_steps`` before being returned from :meth:`step`.

    Action delay: an action passed to :meth:`step` is pushed through a FIFO
    buffer of length ``act_delayed_steps`` before it is applied to the wrapped
    environment.

    The wrapped environment is driven through the 5-tuple ``step`` /
    2-tuple ``reset`` API, while this wrapper itself returns a bare
    observation from :meth:`reset` and a 4-tuple
    ``(obs, reward, done, info)`` from :meth:`step`.

    Args:
        env: Environment to wrap.
        seed: Seed passed to ``env.action_space.seed``.
        obs_delayed_steps: Number of steps by which observations, rewards and
            done flags are delayed (0 disables feedback delay).
        act_delayed_steps: Number of steps by which actions are delayed
            (0 disables action delay).

    Raises:
        AssertionError: If both delays sum to zero or less.
    """

    def __init__(self, env, seed, obs_delayed_steps, act_delayed_steps):
        super().__init__(env)
        assert obs_delayed_steps + act_delayed_steps > 0, (
            "obs_delayed_steps + act_delayed_steps must be positive"
        )
        self.env.action_space.seed(seed)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self._max_episode_steps = self._lookup_max_episode_steps()

        # Each deque is bounded by its delay, so refilling it with exactly
        # that many items in reset() discards everything from a prior episode.
        self.obs_buffer = deque(maxlen=obs_delayed_steps)
        self.reward_buffer = deque(maxlen=obs_delayed_steps)
        self.done_buffer = deque(maxlen=obs_delayed_steps)
        self.action_buffer = deque(maxlen=act_delayed_steps)

        self.obs_delayed_steps = obs_delayed_steps
        self.act_delayed_steps = act_delayed_steps

        # Latched once the wrapped env reports the end of the episode, so
        # that step() stops driving an already-finished environment.
        self._episode_over = False
        self._final_obs = None

    def _lookup_max_episode_steps(self):
        """Return the wrapped env's step limit, or ``None`` if it has none.

        Prefers the ``_max_episode_steps`` attribute of the wrapped env and
        falls back to ``env.spec.max_episode_steps`` so that construction does
        not fail for environments that do not expose the private attribute.
        """
        limit = getattr(self.env, "_max_episode_steps", None)
        if limit is None:
            spec = getattr(self.env, "spec", None)
            limit = getattr(spec, "max_episode_steps", None)
        return limit

    def reset(self):
        """Reset the wrapped env and pre-fill all delay buffers.

        The action buffer is filled with all-zero actions and the feedback
        buffers with the initial observation, reward ``0`` and done ``False``.

        Returns:
            The initial observation of the wrapped environment (the ``info``
            value returned by the wrapped ``reset`` is discarded).
        """
        for _ in range(self.act_delayed_steps):
            # sample() serves only as a template for zeros_like; it is called
            # once per slot so the seeded sampling sequence stays as before.
            self.action_buffer.append(
                np.zeros_like(self.env.action_space.sample())
            )

        init_state, _ = self.env.reset()
        self._episode_over = False
        self._final_obs = None

        for _ in range(self.obs_delayed_steps):
            self.obs_buffer.append(init_state)
            self.reward_buffer.append(0)
            self.done_buffer.append(False)
        return init_state

    def step(self, action):
        """Advance one step, applying the configured delays.

        Args:
            action: Action chosen now; with action delay it is executed
                ``act_delayed_steps`` calls later.

        Returns:
            Tuple ``(delayed_obs, delayed_reward, delayed_done, info)`` where
            ``info`` carries the undelayed values under the keys
            ``'current_obs'``, ``'current_reward'`` and ``'current_done'``.
        """
        if self.act_delayed_steps > 0:
            delayed_action = self.action_buffer.popleft()
            self.action_buffer.append(action)
        else:
            delayed_action = action

        if self._episode_over:
            # The wrapped episode has already ended but its done flag is still
            # travelling through the delay buffers. Stepping a finished env
            # gives undefined results, so pad with the terminal transition.
            current_obs = self._final_obs
            current_reward = 0.0
            current_done = True
        else:
            current_obs, current_reward, terminated, truncated, _ = (
                self.env.step(delayed_action)
            )
            current_done = bool(terminated or truncated)
            if current_done:
                self._episode_over = True
                self._final_obs = current_obs

        if self.obs_delayed_steps > 0:
            delayed_obs = self.obs_buffer.popleft()
            delayed_reward = self.reward_buffer.popleft()
            delayed_done = self.done_buffer.popleft()

            self.obs_buffer.append(current_obs)
            self.reward_buffer.append(current_reward)
            self.done_buffer.append(current_done)
        else:
            delayed_obs = current_obs
            delayed_reward = current_reward
            delayed_done = current_done

        return delayed_obs, delayed_reward, delayed_done, {
            'current_obs': current_obs,
            'current_reward': current_reward,
            'current_done': current_done,
        }