Spaces:
Runtime error
Runtime error
| from typing import List, Dict, Optional | |
| import numpy as np | |
| import gym | |
| from gym.spaces import Box | |
| from robomimic.envs.env_robosuite import EnvRobosuite | |
class RobomimicLowdimWrapper(gym.Env):
    """Gym adapter around a robomimic ``EnvRobosuite`` with flat low-dim observations.

    Each observation is the concatenation (along axis 0) of the raw observation
    entries named in ``obs_keys``, in that order. Action and observation spaces
    are declared as ``Box`` spaces bounded by [-1, 1].
    """

    # Immutable default so instances never share (and mutate) one list object.
    _DEFAULT_OBS_KEYS = ("object", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos")

    def __init__(
        self,
        env: EnvRobosuite,
        obs_keys: Optional[List[str]] = None,
        init_state: Optional[np.ndarray] = None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        """Wrap ``env`` and build the action/observation spaces.

        Args:
            env: robomimic robosuite environment to wrap.
            obs_keys: raw observation keys to concatenate, in order. ``None``
                (the default) selects "object", "robot0_eef_pos",
                "robot0_eef_quat", "robot0_gripper_qpos". The list is copied.
            init_state: if given, every ``reset()`` restores this simulator
                state instead of sampling a new one.
            render_hw: (height, width) passed to ``env.render``.
            render_camera_name: camera name passed to ``env.render``.
        """
        self.env = env
        # Defensive copy: later mutation of the caller's list cannot affect us.
        self.obs_keys = list(self._DEFAULT_OBS_KEYS if obs_keys is None else obs_keys)
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        # seed -> simulator state reached by resetting under that seed.
        self.seed_state_map = dict()
        # Seed to apply on the next reset(); consumed (set back to None) there.
        self._seed = None

        # Setup spaces. Box converts the bounds to its own dtype.
        low = np.full(env.action_dimension, fill_value=-1)
        high = np.full(env.action_dimension, fill_value=1)
        self.action_space = Box(
            low=low,
            high=high,
        )

        # Probe the current observation to size the observation space.
        # NOTE(review): these [-1, 1] bounds are nominal; observations are
        # never clipped or checked against them.
        obs_example = self.get_observation()
        low = np.full_like(obs_example, fill_value=-1)
        high = np.full_like(obs_example, fill_value=1)
        self.observation_space = Box(
            low=low,
            high=high,
        )

    def _flatten_obs(self, raw_obs):
        """Concatenate ``raw_obs[key]`` for each key in ``self.obs_keys`` along axis 0."""
        return np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)

    def get_observation(self):
        """Return the wrapped env's current observation as one flat array."""
        return self._flatten_obs(self.env.get_observation())

    def seed(self, seed=None):
        """Seed numpy's global RNG and remember ``seed`` for the next ``reset()``.

        Returns:
            ``[seed]``, following the gym ``Env.seed`` convention.
        """
        np.random.seed(seed=seed)
        self._seed = seed
        return [seed]

    def reset(self):
        """Reset the env and return the flattened initial observation.

        Precedence:
          1. ``init_state`` set: always restore that exact state.
          2. A seed pending from ``seed()``: restore the cached state for that
             seed, or reset under it and cache the resulting state. The pending
             seed is then cleared, so later resets are random again.
          3. Otherwise: a plain random ``env.reset()``.
        """
        if self.init_state is not None:
            # always reset to the same state
            # to be compatible with gym
            self.env.reset_to({"states": self.init_state})
        elif self._seed is not None:
            # reset to a specific seed
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive, use cache
                self.env.reset_to({"states": self.seed_state_map[seed]})
            else:
                # robosuite's initializes all use numpy global random state
                np.random.seed(seed=seed)
                self.env.reset()
                state = self.env.get_state()["states"]
                self.seed_state_map[seed] = state
            self._seed = None
        else:
            # random reset
            self.env.reset()

        return self.get_observation()

    def step(self, action):
        """Step the env; return ``(flat_obs, reward, done, info)``."""
        raw_obs, reward, done, info = self.env.step(action)
        return self._flatten_obs(raw_obs), reward, done, info

    def render(self, mode="rgb_array"):
        """Render via the wrapped env at ``render_hw`` from ``render_camera_name``."""
        h, w = self.render_hw
        return self.env.render(mode=mode, height=h, width=w, camera_name=self.render_camera_name)

    def close(self):
        """Close the object held at ``self.env.env``.

        Presumably this is the underlying robosuite simulator; confirm against
        EnvRobosuite.
        """
        self.env.env.close()