Spaces:
Sleeping
Sleeping
import gym
import numpy as np
from easydict import EasyDict

from ding.envs import BaseEnvTimestep
from ding.utils import ENV_WRAPPER_REGISTRY
class LightZeroEnvWrapper(gym.Wrapper):
    """
    Overview:
        Package a classic_control / box2d gym environment into the format required by
        LightZero: every observation is wrapped into a dict with keys ``observation``,
        ``action_mask`` and ``to_play``.
    Interface:
        ``__init__``, ``reset``, ``step``
    Properties:
        - env (:obj:`gym.Env`): the environment to wrap.
    """

    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
        """
        Overview:
            Initialize the wrapper and cache the relevant config fields.
        Arguments:
            - env (:obj:`gym.Env`): the environment to wrap.
            - cfg (:obj:`EasyDict`): env config; must contain ``is_train``, ``env_name``
              and ``continuous``.
        """
        super().__init__(env)
        assert 'is_train' in cfg, '`is_train` flag must set in the config of env'
        self.is_train = cfg.is_train
        self.cfg = cfg
        self.env_name = cfg.env_name
        self.continuous = cfg.continuous

    def _action_mask(self):
        """
        Overview:
            Build the action mask for the current env: ``None`` for continuous action
            spaces (no mask applies), otherwise an all-ones int8 vector marking every
            discrete action as legal.
        """
        if self.continuous:
            return None
        return np.ones(self.env.action_space.n, 'int8')

    def _build_observation_space(self):
        """
        Overview:
            Declare the LightZero dict observation space matching the dicts produced by
            ``reset``/``step``. Stored as metadata only; nothing in this class samples
            from it.
        """
        if self.continuous:
            # NOTE(review): low=np.inf looks like it should be -np.inf, but it is kept
            # as-is since this placeholder is explicitly marked TODO upstream.
            action_mask_space = gym.spaces.Box(
                low=np.inf, high=np.inf, shape=(1, )
            )  # TODO: gym.spaces.Constant(None)
        elif isinstance(self.env.action_space, gym.spaces.Discrete):
            # One {0, 1} entry per discrete action.
            action_mask_space = gym.spaces.MultiDiscrete([2 for _ in range(self.env.action_space.n)])
        else:
            action_mask_space = gym.spaces.MultiDiscrete([2 for _ in range(self.env.action_space.shape[0])])
        return gym.spaces.Dict(
            {
                'observation': self._raw_observation_space,
                'action_mask': action_mask_space,
                'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )),  # TODO: gym.spaces.Constant(-1)
            }
        )

    def reset(self, **kwargs):
        """
        Overview:
            Reset the wrapped environment, reset the episode-return accumulator, and
            declare the dict observation space.
        Arguments:
            - kwargs (:obj:`Dict`): keyword arguments forwarded to ``env.reset``.
        Returns:
            - lightzero_obs_dict (:obj:`Dict`): dict with keys ``observation``
              (raw env obs), ``action_mask`` and ``to_play`` (always ``-1``:
              single-player env).
        """
        # The core original env reset. Assumes the old gym API where reset returns
        # only the observation (no info tuple).
        obs = self.env.reset(**kwargs)
        self._eval_episode_return = 0.
        self._raw_observation_space = self.env.observation_space
        self._observation_space = self._build_observation_space()
        return {'observation': obs, 'action_mask': self._action_mask(), 'to_play': -1}

    def step(self, action):
        """
        Overview:
            Step the wrapped environment with the given action, wrap the observation
            into the LightZero dict format, and accumulate the episode return.
        Arguments:
            - action (:obj:`Any`): the action to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): ``(obs, reward, done, info)`` where
              ``obs`` is the LightZero dict; on episode end ``info`` additionally
              carries ``eval_episode_return`` (sum of raw rewards this episode).
        """
        # The core original env step. Assumes the old gym 4-tuple API.
        obs, rew, done, info = self.env.step(action)
        lightzero_obs_dict = {'observation': obs, 'action_mask': self._action_mask(), 'to_play': -1}
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)

    def __repr__(self) -> str:
        return "LightZero Env."