# PIWM/src/envs/atari_preprocessing.py
# Commit c64c726 by musictimer: "Initial Diamond CSGO AI deployment"
"""
Derived from https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py
Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018.
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import cv2
import numpy as np
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Atari 2600 preprocessing: no-op resets, frame skipping, max-pooling and resizing.

    Observations are RGB frames resized to ``(screen_size, screen_size, 3)`` uint8.
    Every ``info`` dict returned by :meth:`step` and :meth:`reset` additionally carries:

    * ``"life_loss"`` -- ``True`` if the ALE life counter decreased during the step
      (always ``False`` on reset).
    * ``"original_obs"`` -- the full-resolution RGB screen before resizing (max-pooled
      over the last two frames when ``frame_skip > 1``). It is a copy, so it stays
      valid after later steps.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int,
        frame_skip: int,
        screen_size: int,
    ):
        """Wrap ``env``.

        Args:
            env: ALE-backed env whose first action meaning is ``"NOOP"`` and whose
                observation space is a ``Box``.
            noop_max: Maximum number of no-op actions executed after a reset (0 disables).
            frame_skip: Number of frames each action is repeated for (must be > 0).
            screen_size: Side length of the square output observation (must be > 0).

        Raises:
            AssertionError: If an argument is out of range or ``env`` does not meet the
                requirements above.
            ValueError: If ``frame_skip > 1`` while the wrapped env does not report
                ``_frameskip == 1``.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
        )
        gym.Wrapper.__init__(self, env)
        assert frame_skip > 0, f"frame_skip must be positive, got {frame_skip}"
        assert screen_size > 0, f"screen_size must be positive, got {screen_size}"
        assert noop_max >= 0, f"noop_max must be non-negative, got {noop_max}"
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )
        self.noop_max = noop_max
        # Action 0 is used as the no-op in reset(), so it must really be NOOP.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
        self.frame_skip = frame_skip
        self.screen_size = screen_size
        # Buffers for the two most recent screens, max-pooled in _get_obs; they are
        # written in place by ALE.getScreenRGB.
        assert isinstance(env.observation_space, Box)
        self.obs_buffer = [
            np.empty(env.observation_space.shape, dtype=np.uint8),
            np.empty(env.observation_space.shape, dtype=np.uint8),
        ]
        self.lives = 0  # ALE life count seen at the last step/reset
        self.game_over = False  # `terminated` flag of the most recent inner env step
        self.observation_space = Box(low=0, high=255, shape=(screen_size, screen_size, 3), dtype=np.uint8)

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(self, action: WrapperActType) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Repeat ``action`` for up to ``frame_skip`` frames and return the pooled observation.

        Rewards are summed over the executed frames. The loop stops early on
        termination or truncation. ``info`` is the wrapped env's info from the last
        executed frame, extended with ``"life_loss"`` and ``"original_obs"``.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}
        life_loss = False
        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated
            lives = self.ale.lives()
            if lives < self.lives:
                life_loss = True
            self.lives = lives
            if terminated or truncated:
                # NOTE: same as the upstream wrapper -- breaking here skips the screen
                # capture below, so the returned observation may come from an earlier
                # frame than the terminal one.
                break
            # Capture the last two frames of the skip window; _get_obs max-pools them.
            if t == self.frame_skip - 2:
                self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self.ale.getScreenRGB(self.obs_buffer[0])
        info["life_loss"] = life_loss
        obs, original_obs = self._get_obs()
        info["original_obs"] = original_obs
        return obs, total_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing.

        After the wrapped env is reset, between 1 and ``noop_max`` (inclusive) no-op
        actions are executed. If one of them ends the episode, the env is reset again.
        The returned ``info`` always contains ``"life_loss": False`` and
        ``"original_obs"``.
        """
        # NoopReset
        _, reset_info = self.env.reset(seed=seed, options=options)
        noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                # NOTE: re-resets with the same seed/options, as the upstream wrapper does.
                _, reset_info = self.env.reset(seed=seed, options=options)
        # Set after the no-op loop: a mid-loop re-reset replaces reset_info and would
        # otherwise drop the key.
        reset_info["life_loss"] = False
        self.game_over = False
        self.lives = self.ale.lives()
        self.ale.getScreenRGB(self.obs_buffer[0])
        # Zeroing the second buffer makes the max-pool in _get_obs a no-op on reset.
        self.obs_buffer[1].fill(0)
        obs, original_obs = self._get_obs()
        reset_info["original_obs"] = original_obs
        return obs, reset_info

    def _get_obs(self):
        """Return ``(resized_obs, original_obs)`` built from the screen buffers.

        When ``frame_skip > 1`` the two buffers are max-pooled in place into
        ``obs_buffer[0]``. ``original_obs`` is a copy of that full-resolution frame,
        so callers may keep it even though the buffer is overwritten on the next
        step or reset.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
        # Copy: obs_buffer[0] is reused and written in place by every step/reset.
        original_obs = self.obs_buffer[0].copy()
        obs = cv2.resize(
            original_obs,
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )
        obs = np.asarray(obs, dtype=np.uint8)
        return obs, original_obs