Spaces:
Sleeping
Sleeping
| # Adapted from openai baselines: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py | |
| from datetime import datetime | |
| from typing import Optional | |
| import cv2 | |
| import gymnasium | |
| import gym | |
| import numpy as np | |
| from ding.envs import NoopResetWrapper, MaxAndSkipWrapper, EpisodicLifeWrapper, FireResetWrapper, WarpFrameWrapper, \ | |
| ScaledFloatFrameWrapper, \ | |
| ClipRewardWrapper, FrameStackWrapper | |
| from ding.utils.compression_helper import jpeg_data_compressor | |
| from easydict import EasyDict | |
| from gymnasium.wrappers import RecordVideo | |
| # only for reference now | |
def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
    """
    Overview:
        Build a DeepMind-style Atari environment by creating ``env_id`` with ``gym.make`` and stacking \
        the preprocessing wrappers around it in a fixed order.
    Arguments:
        - env_id (:obj:`str`): The Atari environment id. It must contain ``'NoFrameskip'``.
        - episode_life (:obj:`bool`): If True, apply ``EpisodicLifeWrapper``.
        - clip_rewards (:obj:`bool`): If True, apply ``ClipRewardWrapper``.
        - frame_stack (:obj:`int`): Number of frames passed to ``FrameStackWrapper``; a falsy value (e.g. 0) \
            skips frame stacking.
        - scale (:obj:`bool`): If True, apply ``ScaledFloatFrameWrapper``.
        - warp_frame (:obj:`bool`): If True, apply ``WarpFrameWrapper`` (grayscale + resize).
    Returns:
        - env: The wrapped Atari environment.
    """
    assert 'NoFrameskip' in env_id

    # Mandatory stages: random no-ops on reset, then max-pooling frame skip.
    env = MaxAndSkipWrapper(NoopResetWrapper(gym.make(env_id), noop_max=30), skip=4)

    if episode_life:
        env = EpisodicLifeWrapper(env)
    # Only games that expose a FIRE action get the fire-on-reset wrapper.
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetWrapper(env)

    # Optional stages, applied in this exact order whenever their switch is truthy.
    optional_stages = (
        (warp_frame, lambda wrapped: WarpFrameWrapper(wrapped)),
        (scale, lambda wrapped: ScaledFloatFrameWrapper(wrapped)),
        (clip_rewards, lambda wrapped: ClipRewardWrapper(wrapped)),
        (frame_stack, lambda wrapped: FrameStackWrapper(wrapped, frame_stack)),
    )
    for enabled, apply_stage in optional_stages:
        if enabled:
            env = apply_stage(env)
    return env
| # only for reference now | |
def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
    """
    Overview:
        Build a DeepMind-style Atari environment for Montezuma's Revenge. The wrapper pipeline is the same \
        as in ``wrap_deepmind``; only the check on ``env_id`` differs.
    Arguments:
        - env_id (:obj:`str`): The Atari environment id. It must contain ``'MontezumaRevenge'``.
        - episode_life (:obj:`bool`): If True, apply ``EpisodicLifeWrapper``.
        - clip_rewards (:obj:`bool`): If True, apply ``ClipRewardWrapper``.
        - frame_stack (:obj:`int`): Number of frames passed to ``FrameStackWrapper``; a falsy value (e.g. 0) \
            skips frame stacking.
        - scale (:obj:`bool`): If True, apply ``ScaledFloatFrameWrapper``.
        - warp_frame (:obj:`bool`): If True, apply ``WarpFrameWrapper`` (grayscale + resize).
    Returns:
        - env: The wrapped Atari environment.
    """
    assert 'MontezumaRevenge' in env_id

    base_env = gym.make(env_id)
    noop_env = NoopResetWrapper(base_env, noop_max=30)
    env = MaxAndSkipWrapper(noop_env, skip=4)

    if episode_life:
        env = EpisodicLifeWrapper(env)
    # Only games that expose a FIRE action get the fire-on-reset wrapper.
    action_meanings = env.unwrapped.get_action_meanings()
    if 'FIRE' in action_meanings:
        env = FireResetWrapper(env)

    # Observation and reward post-processing, each stage behind its own switch.
    env = WarpFrameWrapper(env) if warp_frame else env
    env = ScaledFloatFrameWrapper(env) if scale else env
    env = ClipRewardWrapper(env) if clip_rewards else env
    env = FrameStackWrapper(env, frame_stack) if frame_stack else env
    return env
def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> gym.Env:
    """
    Overview:
        Configure environment for MuZero-style Atari. The observation is
        channel-first: (c, h, w) instead of (h, w, c).
    Arguments:
        - config (:obj:`Dict`): Dict containing configuration parameters for the environment. The fields read \
            here are ``env_name``, ``render_mode_human``, ``save_replay``, ``replay_path``, ``frame_skip``, \
            ``max_episode_steps``, ``warp_frame``, ``obs_shape``, ``gray_scale``, ``scale``, \
            ``transform2string`` and ``game_wrapper``.
        - episode_life (:obj:`bool`): If True, apply ``EpisodicLifeWrapper``.
        - clip_rewards (:obj:`bool`): If True, apply ``ClipRewardWrapper``.
    Return:
        - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations.
    """
    render_mode = 'human' if config.render_mode_human else 'rgb_array'
    env = gymnasium.make(config.env_name, render_mode=render_mode)
    assert 'NoFrameskip' in env.spec.id

    if config.save_replay:
        # Timestamp the prefix so that successive runs do not overwrite each other's videos.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        video_name = f'{env.spec.id}-video-{timestamp}'
        env = RecordVideo(
            env,
            video_folder=config.replay_path,
            episode_trigger=lambda episode_id: True,  # record every episode
            name_prefix=video_name
        )

    # Everything below this point is written against the gym (not gymnasium) wrapper interface.
    env = GymnasiumToGymWrapper(env)
    env = NoopResetWrapper(env, noop_max=30)
    env = MaxAndSkipWrapper(env, skip=config.frame_skip)
    if episode_life:
        env = EpisodicLifeWrapper(env)
    env = TimeLimit(env, max_episode_steps=config.max_episode_steps)
    if config.warp_frame:
        # WarpFrame asserts a uint8 input space, so it must be applied before ScaledFloatFrameWrapper.
        # obs_shape is channel-first (c, h, w): index 1 is the height and index 2 is the width.
        # NOTE(review): assumes obs_shape follows the (c, h, w) layout stated in the docstring - confirm
        # against the config definitions if non-square shapes are used.
        env = WarpFrame(env, width=config.obs_shape[2], height=config.obs_shape[1], grayscale=config.gray_scale)
    if config.scale:
        env = ScaledFloatFrameWrapper(env)
    if clip_rewards:
        env = ClipRewardWrapper(env)

    env = JpegWrapper(env, transform2string=config.transform2string)
    if config.game_wrapper:
        env = GameWrapper(env)
    return env
class TimeLimit(gym.Wrapper):
    """
    Overview:
        A wrapper that limits the maximum number of steps in an episode.
    Interface:
        ``__init__``, ``step``, ``reset``
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - max_episode_steps (:obj:`Optional[int]`): Maximum number of steps per episode. If None, no limit \
                is applied.
        """
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        """
        Overview:
            Step the wrapped environment and end the episode once the step budget is used up.
        Arguments:
            - ac: The action forwarded to the wrapped environment.
        Returns:
            - observation, reward, done, info: The wrapped environment's 4-tuple. When the limit is reached, \
                ``done`` is forced to True and ``info['TimeLimit.truncated']`` is set to True.
        """
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        # A limit of None means "unlimited"; comparing an int against None would raise TypeError.
        if self._max_episode_steps is not None and self._elapsed_steps >= self._max_episode_steps:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """
        Overview:
            Reset the step counter and the wrapped environment.
        Arguments:
            - kwargs: Keyword arguments forwarded unchanged to the wrapped environment's ``reset``.
        Returns:
            - The value returned by the wrapped environment's ``reset``.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
class WarpFrame(gym.ObservationWrapper):
    """
    Overview:
        An observation wrapper that resizes frames to ``(height, width)`` (84x84 by default) and optionally \
        converts them to grayscale.
    Interface:
        ``__init__``, ``observation``
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84, grayscale: bool = True,
                 dict_space_key: Optional[str] = None):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - width (:obj:`int`): Target frame width.
            - height (:obj:`int`): Target frame height.
            - grayscale (:obj:`bool`): If True, frames are converted to a single grayscale channel.
            - dict_space_key (:obj:`Optional[str]`): For dict observation spaces, the key of the entry to warp. \
                If None, the whole observation is treated as the frame.
        """
        super().__init__(env)
        self._width, self._height = width, height
        self._grayscale = grayscale
        self._key = dict_space_key

        channels = 1 if self._grayscale else 3
        warped_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._height, self._width, channels),
            dtype=np.uint8,
        )

        # Remember the space that is being replaced so that it can be validated afterwards.
        if self._key is None:
            replaced_space = self.observation_space
            self.observation_space = warped_space
        else:
            replaced_space = self.observation_space.spaces[self._key]
            self.observation_space.spaces[self._key] = warped_space
        # Only 3-dimensional uint8 image spaces are supported as input.
        assert replaced_space.dtype == np.uint8 and len(replaced_space.shape) == 3

    def observation(self, obs):
        """
        Overview:
            Warp a single observation.
        Arguments:
            - obs: The raw observation, or a dict-like observation holding the frame under the configured key.
        Returns:
            - obs: The warped frame, or a copy of the dict-like observation with the warped frame substituted.
        """
        use_whole_obs = self._key is None
        frame = obs if use_whole_obs else obs[self._key]

        if self._grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self._width, self._height), interpolation=cv2.INTER_AREA)
        if self._grayscale:
            # Restore the trailing channel axis so the frame matches the declared (height, width, 1) space.
            frame = np.expand_dims(frame, -1)

        if use_whole_obs:
            return frame
        # Copy first so that the caller's observation is not modified in place.
        warped_obs = obs.copy()
        warped_obs[self._key] = frame
        return warped_obs
class JpegWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper that optionally passes every observation through ``jpeg_data_compressor``, turning it into \
        a string to save memory.
    Interface:
        ``__init__``, ``step``, ``reset``
    """

    def __init__(self, env: gym.Env, transform2string: bool = True):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - transform2string (:obj:`bool`): If True, observations are compressed; otherwise they are \
                returned unchanged.
        """
        super().__init__(env)
        self.transform2string = transform2string

    def step(self, action):
        """
        Overview:
            Step the wrapped environment and compress the observation when enabled.
        Arguments:
            - action: The action forwarded to the wrapped environment.
        Returns:
            - observation, reward, done, info: The wrapped environment's 4-tuple, with the observation \
                compressed if ``transform2string`` is True.
        """
        observation, reward, done, info = self.env.step(action)
        if not self.transform2string:
            return observation, reward, done, info
        return jpeg_data_compressor(observation), reward, done, info

    def reset(self, **kwargs):
        """
        Overview:
            Reset the wrapped environment and compress the initial observation when enabled.
        Arguments:
            - kwargs: Keyword arguments forwarded unchanged to the wrapped environment's ``reset``.
        Returns:
            - observation: The initial observation, compressed if ``transform2string`` is True.
        """
        first_obs = self.env.reset(**kwargs)
        if not self.transform2string:
            return first_obs
        return jpeg_data_compressor(first_obs)
class GameWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper that adapts the environment to the game interface by exposing ``legal_actions``.
    Interface:
        ``__init__``, ``legal_actions``
    """

    def __init__(self, env: gym.Env):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
        """
        super().__init__(env)

    def legal_actions(self):
        """
        Overview:
            List every action index of the wrapped environment's discrete action space.
        Returns:
            - legal_actions (:obj:`list`): The integers ``0 .. action_space.n - 1``.
        """
        num_actions = self.env.action_space.n
        return list(range(num_actions))
class GymnasiumToGymWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper class that adapts a Gymnasium environment to the Gym interface: ``seed`` stores a seed \
        for later use and ``reset`` returns only the observation.
    Interface:
        ``__init__``, ``reset``, ``seed``
    Properties:
        - _seed (:obj:`int` or None): The seed passed to the wrapped environment on reset, if any.
    """

    def __init__(self, env):
        """
        Overview:
            Initializes the GymnasiumToGymWrapper.
        Arguments:
            - env (:obj:`gymnasium.Env`): The Gymnasium environment to be wrapped.
        """
        assert isinstance(env, gymnasium.Env), type(env)
        super().__init__(env)
        self._seed = None

    def seed(self, seed):
        """
        Overview:
            Store the seed; it is handed to the wrapped environment on every subsequent ``reset``.
        Arguments:
            - seed (:obj:`int`): The seed value to use for random number generation.
        """
        self._seed = seed

    def reset(self):
        """
        Overview:
            Resets the environment and returns the initial observation. The second element returned by the \
            Gymnasium ``reset`` is discarded.
        Returns:
            - observation (:obj:`Any`): The initial observation of the environment.
        """
        # Forward the seed only when one has been set through ``seed``.
        reset_kwargs = {} if self._seed is None else {'seed': self._seed}
        obs, _ = self.env.reset(**reset_kwargs)
        return obs