Spaces:
Sleeping
Sleeping
| """ | |
| Derived from https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py | |
| Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, SupportsFloat | |
| import cv2 | |
| import numpy as np | |
| import gymnasium as gym | |
| from gymnasium.core import WrapperActType, WrapperObsType | |
| from gymnasium.spaces import Box | |
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Atari 2600 preprocessing wrapper that emits resized RGB observations.

    Behaviour implemented by this wrapper:

    * ``reset`` plays a random number of no-op actions (action ``0``) after
      resetting the wrapped environment, when ``noop_max > 0``.
    * ``step`` repeats the given action up to ``frame_skip`` times and sums
      the rewards, stopping early on termination or truncation.
    * The frames grabbed on the last two repeats are max-pooled (when
      ``frame_skip > 1``) and resized to ``screen_size x screen_size`` with
      ``cv2.INTER_AREA``.

    The ``info`` dict returned by ``reset`` and ``step`` additionally carries:

    * ``"life_loss"``: whether the ALE life counter decreased during the step
      (always ``False`` after ``reset``).
    * ``"original_obs"``: a copy of the pooled frame before resizing.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int,
        frame_skip: int,
        screen_size: int,
    ):
        """Wrap ``env`` with no-op reset, frame skipping and resizing.

        Args:
            env: Environment to wrap. Its observation space must be a ``Box``
                and action ``0`` of the unwrapped env must be ``"NOOP"``.
            noop_max: Upper bound passed to the no-op sampler on reset
                (``np_random.integers(1, noop_max + 1)``); ``0`` disables
                no-op resets.
            frame_skip: Number of times each action is repeated per ``step``.
            screen_size: Side length of the square output observation.

        Raises:
            AssertionError: If an argument is out of range, action ``0`` is
                not ``"NOOP"``, or the observation space is not a ``Box``.
            ValueError: If ``frame_skip > 1`` while the unwrapped env's
                ``_frameskip`` attribute is not ``1``.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
        )
        gym.Wrapper.__init__(self, env)

        assert frame_skip > 0, "frame_skip must be positive"
        assert screen_size > 0, "screen_size must be positive"
        assert noop_max >= 0, "noop_max must be non-negative"

        # Skipping frames both in the base env and here would compound.
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )

        self.noop_max = noop_max
        # reset() sends action 0 as the no-op, so it has to really be NOOP.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size

        # Buffer of the most recent two observations for max pooling.
        assert isinstance(env.observation_space, Box)
        self.obs_buffer = [
            np.empty(env.observation_space.shape, dtype=np.uint8),
            np.empty(env.observation_space.shape, dtype=np.uint8),
        ]

        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (0, 255, np.uint8)
        _shape = (screen_size, screen_size, 3)
        self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error.

        The handle is looked up on the unwrapped env at every access instead
        of being stored on the wrapper. It must be a property because the
        rest of the class uses it as ``self.ale.lives()`` and
        ``self.ale.getScreenRGB(...)``.
        """
        return self.env.unwrapped.ale

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Repeat ``action`` for up to ``frame_skip`` frames.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            ``(obs, total_reward, terminated, truncated, info)`` where
            ``obs`` is the pooled, resized frame, ``total_reward`` is the sum
            of the rewards of the executed frames, ``terminated`` /
            ``truncated`` / ``info`` come from the last executed frame, and
            ``info`` gains the ``"life_loss"`` and ``"original_obs"`` keys.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}
        life_loss = False

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            lives = self.ale.lives()
            if lives < self.lives:
                life_loss = True
            self.lives = lives

            # On early exit the frame buffers keep their previous contents.
            if terminated or truncated:
                break

            # Only the last two frames of the skip window are captured.
            if t == self.frame_skip - 2:
                self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self.ale.getScreenRGB(self.obs_buffer[0])

        info["life_loss"] = life_loss
        obs, original_obs = self._get_obs()
        info["original_obs"] = original_obs
        return obs, total_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing.

        Args:
            seed: Seed forwarded to the wrapped environment's ``reset``.
            options: Options forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(obs, info)`` where ``obs`` is the resized first frame and
            ``info`` is the reset info (updated with the info of every no-op
            step) plus ``"life_loss"`` (``False``) and ``"original_obs"``.
        """
        # NoopReset
        _, reset_info = self.env.reset(seed=seed, options=options)

        noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                # The episode ended during the no-ops: start over. This
                # replaces reset_info with a brand-new dict.
                _, reset_info = self.env.reset(seed=seed, options=options)

        # Set after the loop so the key survives a re-reset inside it.
        reset_info["life_loss"] = False

        # A fresh episode is never "over"; clear the flag left by the last step.
        self.game_over = False
        self.lives = self.ale.lives()

        self.ale.getScreenRGB(self.obs_buffer[0])
        # Zeroed second buffer makes the max-pool in _get_obs a no-op.
        self.obs_buffer[1].fill(0)

        obs, original_obs = self._get_obs()
        reset_info["original_obs"] = original_obs
        return obs, reset_info

    def _get_obs(self) -> tuple[np.ndarray, np.ndarray]:
        """Build the observation from the frame buffers.

        Returns:
            ``(obs, original_obs)``: the frame resized to
            ``(screen_size, screen_size)`` as ``uint8``, and a copy of the
            pooled frame at its native resolution.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        # Copy: obs_buffer[0] is overwritten in place by the next step/reset,
        # so handing out the buffer itself would let frames kept by the caller
        # (via info["original_obs"]) change underneath them.
        original_obs = self.obs_buffer[0].copy()

        obs = cv2.resize(
            original_obs,
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )
        obs = np.asarray(obs, dtype=np.uint8)
        return obs, original_obs