from __future__ import annotations from typing import Any, Dict, Optional, Tuple import ale_py import gymnasium from gymnasium.vector import AsyncVectorEnv import numpy as np import torch from torch import Tensor from .atari_preprocessing import AtariPreprocessing def make_atari_env( id: str, num_envs: int, device: torch.device, done_on_life_loss: bool, size: int, max_episode_steps: Optional[int], ) -> TorchEnv: def env_fn(): env = gymnasium.make( id, full_action_space=False, frameskip=1, render_mode="rgb_array", max_episode_steps=max_episode_steps, ) env = AtariPreprocessing( env=env, noop_max=30, frame_skip=4, screen_size=size, ) return env env = AsyncVectorEnv([env_fn for _ in range(num_envs)]) # The AsyncVectorEnv resets the env on termination, which means that it will # reset the environment if we use the default AtariPreprocessing of gymnasium with # terminate_on_life_loss=True (which means that we will only see the first life). # Hence a separate wrapper for life_loss, coming after the AsyncVectorEnv. if done_on_life_loss: env = DoneOnLifeLoss(env) env = TorchEnv(env, device) return env class DoneOnLifeLoss(gymnasium.Wrapper): def __init__(self, env: AsyncVectorEnv) -> None: super().__init__(env) def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]: obs, rew, end, trunc, info = self.env.step(actions) life_loss = info["life_loss"] if life_loss.any(): end[life_loss] = True info["final_observation"] = obs return obs, rew, end, trunc, info class TorchEnv(gymnasium.Wrapper): def __init__(self, env: gymnasium.Env, device: torch.device) -> None: super().__init__(env) self.device = device self.num_envs = env.observation_space.shape[0] self.num_actions = env.unwrapped.single_action_space.n b, h, w, c = env.observation_space.shape self.observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(b, c, h, w)) def reset(self, *args, **kwargs) -> Tuple[Tensor, Dict[str, Any]]: obs, info = self.env.reset(*args, **kwargs) return self._to_tensor(obs), info def step(self, actions: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]: obs, rew, end, trunc, info = self.env.step(actions.cpu().numpy()) dead = np.logical_or(end, trunc) if dead.any(): info["final_observation"] = self._to_tensor(np.stack(info["final_observation"][dead])) obs, rew, end, trunc = (self._to_tensor(x) for x in (obs, rew, end, trunc)) return obs, rew, end, trunc, info def _to_tensor(self, x: Tensor) -> Tensor: if x.ndim == 4: return torch.tensor(x, device=self.device).div(255).mul(2).sub(1).permute(0, 3, 1, 2).contiguous() elif x.dtype is np.dtype("bool"): return torch.tensor(x, dtype=torch.uint8, device=self.device) else: return torch.tensor(x, dtype=torch.float32, device=self.device)