"""
Derived from https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py
Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018.
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import cv2
import numpy as np
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Atari 2600 preprocessing following the guidelines of Machado et al., 2018.

    Applies NoopReset on reset, frame skipping with max pooling over the last
    two raw frames, and resizing of the RGB screen to a square of side
    ``screen_size``. The full-resolution pooled frame is also exposed through
    ``info["original_obs"]``.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int,
        frame_skip: int,
        screen_size: int,
    ):
        """Wrap an Atari environment.

        Args:
            env: The environment to wrap. Its built-in frame-skipping must be
                disabled (``_frameskip == 1``) when ``frame_skip > 1`` here.
            noop_max: Maximum number of no-op actions sampled on reset;
                0 disables the NoopReset behaviour.
            frame_skip: Number of emulator steps aggregated per ``step`` call.
            screen_size: Side length of the resized square observation.

        Raises:
            ValueError: If the wrapped env applies its own frame-skipping
                while ``frame_skip > 1``.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
        )
        gym.Wrapper.__init__(self, env)
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        # Guard against double frame-skipping, which would silently drop
        # rewards and frames.
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )
        self.noop_max = noop_max
        # NoopReset assumes action 0 is the no-op action.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
        self.frame_skip = frame_skip
        self.screen_size = screen_size
        # Buffer of the two most recent raw screens, used for max pooling.
        assert isinstance(env.observation_space, Box)
        self.obs_buffer = [
            np.empty(env.observation_space.shape, dtype=np.uint8),
            np.empty(env.observation_space.shape, dtype=np.uint8),
        ]
        self.lives = 0
        self.game_over = False
        _low, _high, _obs_dtype = (0, 255, np.uint8)
        _shape = (screen_size, screen_size, 3)
        self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(self, action: WrapperActType) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Apply ``action`` for ``frame_skip`` emulator steps.

        Returns:
            Tuple of the max-pooled, resized observation, the reward summed
            over the skipped frames, the terminated/truncated flags of the
            last inner step, and an info dict carrying ``"life_loss"`` (True
            if a life was lost during the skip) and ``"original_obs"`` (the
            full-resolution pooled frame).
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}
        life_loss = False
        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated
            if self.ale.lives() < self.lives:
                life_loss = True
            self.lives = self.ale.lives()
            if terminated or truncated:
                break
            # Capture only the final two screens of the skip for max pooling.
            if t == self.frame_skip - 2:
                self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                self.ale.getScreenRGB(self.obs_buffer[0])
        info["life_loss"] = life_loss
        obs, original_obs = self._get_obs()
        info["original_obs"] = original_obs
        return obs, total_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing (NoopReset)."""
        # NoopReset: step a random number of no-ops after the raw reset.
        _, reset_info = self.env.reset(seed=seed, options=options)
        reset_info["life_loss"] = False
        noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                _, reset_info = self.env.reset(seed=seed, options=options)
                # BUG FIX: the fresh reset_info replaced the dict built above
                # and dropped the "life_loss" key; restore it so the info dict
                # returned by reset() always carries the same keys.
                reset_info["life_loss"] = False
        self.lives = self.ale.lives()
        # Seed the pooling buffers: real screen in slot 0, zeros in slot 1 so
        # the first max-pool is a no-op.
        self.ale.getScreenRGB(self.obs_buffer[0])
        self.obs_buffer[1].fill(0)
        obs, original_obs = self._get_obs()
        reset_info["original_obs"] = original_obs
        return obs, reset_info

    def _get_obs(self):
        """Max-pool the two buffered screens (in place) and resize.

        Returns:
            Tuple of the resized uint8 observation and the full-resolution
            pooled frame (a view into ``obs_buffer[0]``).
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
        original_obs = self.obs_buffer[0]
        obs = cv2.resize(
            original_obs,
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )
        obs = np.asarray(obs, dtype=np.uint8)
        return obs, original_obs
|