Spaces:
Sleeping
Sleeping
File size: 7,290 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""
This file implements a wrapper for visualizing important sites in a given environment.
By default, this visualizes all sites possible for the environment. Visualization options
for a given environment can be found by calling `get_visualization_settings()`, and can
be set individually by calling `set_visualization_setting(setting, visible)`.
"""
import xml.etree.ElementTree as ET
from copy import deepcopy
import numpy as np
from robosuite.utils.mjcf_utils import new_body, new_geom, new_site
from robosuite.wrappers import Wrapper
# Site attributes used for an indicator specified as the string 'default':
# a red sphere (rgba alpha 0.5) with size 0.03. A per-indicator "name" is added at use time.
DEFAULT_INDICATOR_SITE_CONFIG = dict(
    type="sphere",
    size=[0.03],
    rgba=[1, 0, 0, 0.5],
)
class VisualizationWrapper(Wrapper):
    """
    Wrapper that optionally injects indicator sites into the environment's MJCF model and re-applies
    the wrapper's visualization settings (via `env.visualize`) after every reset and step.
    """

    def __init__(self, env, indicator_configs=None):
        """
        Initializes the visualization wrapper. Note that this automatically conducts a (hard) reset initially to make
        sure indicators are properly added to the sim model.

        Args:
            env (MujocoEnv): The environment to visualize

            indicator_configs (None or str or dict or list): Configurations to use for indicator objects.

                If None, no indicator objects will be used

                If a string, this should be `'default'`, which corresponds to single default spherical indicator

                If a dict, should specify a single indicator object config

                If a list, should specify specific indicator object configs to use for multiple indicators (which in
                turn can either be `'default'` or a dict)

                As each indicator object is essentially a site element, each dict should map site attribute keywords to
                values. Note that, at the very minimum, the `'name'` attribute MUST be specified for each indicator. See
                http://www.mujoco.org/book/XMLreference.html#site for specific site attributes that can be specified.
        """
        super().__init__(env)

        # Make sure that the environment is NOT using segmentation sensors, since we cannot use segmentation masks
        # with visualization sites simultaneously
        assert all(
            seg is None for seg in env.camera_segmentations
        ), "Cannot use camera segmentations with visualization wrapper!"

        # Standardize indicator configs (None, or a list of site-attribute dicts that each carry a "name")
        self.indicator_configs = self._standardize_indicator_configs(indicator_configs)

        # Create internal dict to store visualization settings (set to True by default)
        self._vis_settings = {vis: True for vis in self.env._visualizations}

        # Add the post-processor to make sure indicator objects get added to model before it's actually loaded in sim
        self.env.set_xml_processor(processor=self._add_indicators_to_model)

        # Conduct a (hard) reset to make sure visualization changes propagate. The env's original hard_reset
        # flag is restored even if the reset raises, so the wrapped env is never left in a modified state.
        reset_mode = self.env.hard_reset
        self.env.hard_reset = True
        try:
            self.reset()
        finally:
            self.env.hard_reset = reset_mode

    @staticmethod
    def _standardize_indicator_configs(indicator_configs):
        """
        Converts the user-facing @indicator_configs argument into a list of independent site-attribute dicts.

        Args:
            indicator_configs (None or str or dict or list): See `__init__` for accepted forms.

        Returns:
            None or list of dict: None if @indicator_configs is None, otherwise one dict per indicator.

        Raises:
            AssertionError: If any resulting config is not a dict with a `'name'` entry.
        """
        if indicator_configs is None:
            return None

        # A single string or dict (including dict subclasses such as OrderedDict) is a one-element list;
        # otherwise a mapping would be iterated key-by-key.
        if isinstance(indicator_configs, (str, dict)):
            indicator_configs = [indicator_configs]

        standardized = []
        for i, indicator_config in enumerate(indicator_configs):
            if isinstance(indicator_config, str) and indicator_config == "default":
                indicator_config = deepcopy(DEFAULT_INDICATOR_SITE_CONFIG)
                indicator_config["name"] = f"indicator{i}"
            else:
                # Store a private copy so later mutation of the caller's object cannot silently change the model
                indicator_config = deepcopy(indicator_config)
            # Make sure name attribute is specified. Requiring a dict prevents a non-"default" string from passing
            # via a substring test on "name".
            assert isinstance(indicator_config, dict) and (
                "name" in indicator_config
            ), "Name must be specified for all indicator object configurations!"
            standardized.append(indicator_config)
        return standardized

    def get_indicator_names(self):
        """
        Gets all indicator object names for this environment.

        Returns:
            list: Indicator names for this environment (empty if no indicators were configured).
        """
        if self.indicator_configs is None:
            return []
        return [ind_config["name"] for ind_config in self.indicator_configs]

    def set_indicator_pos(self, indicator, pos):
        """
        Sets the specified @indicator to the desired position @pos

        Args:
            indicator (str): Name of the indicator to set
            pos (3-array): (x, y, z) Cartesian world coordinates to set the specified indicator to

        Raises:
            AssertionError: If @indicator is not one of this wrapper's indicator names.
        """
        # Make sure indicator is valid
        indicator_names = set(self.get_indicator_names())
        assert indicator in indicator_names, "Invalid indicator name specified. Valid options are {}, got {}".format(
            indicator_names, indicator
        )
        # Each indicator site lives inside its own body named "<name>_body" (see _add_indicators_to_model), so
        # moving that body moves the indicator.
        self.env.sim.model.body_pos[self.env.sim.model.body_name2id(indicator + "_body")] = np.array(pos)

    def get_visualization_settings(self):
        """
        Gets all settings for visualizing this environment

        Returns:
            dict_keys: Visualization keywords for this environment (a live view of the internal settings keys).
        """
        return self._vis_settings.keys()

    def set_visualization_setting(self, setting, visible):
        """
        Sets the specified @setting to have visibility = @visible.

        Args:
            setting (str): Visualization keyword to set
            visible (bool): True if setting should be visualized.

        Raises:
            AssertionError: If @setting is not a known visualization keyword.
        """
        assert (
            setting in self._vis_settings
        ), "Invalid visualization setting specified. Valid options are {}, got {}".format(
            self._vis_settings.keys(), setting
        )
        self._vis_settings[setting] = visible

    def reset(self):
        """
        Extends vanilla reset() function call to accommodate visualization

        Returns:
            OrderedDict: Environment observation space after reset occurs
        """
        ret = super().reset()
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def step(self, action):
        """
        Extends vanilla step() function call to accommodate visualization

        Args:
            action (np.array): Action to take in environment

        Returns:
            4-tuple:

                - (OrderedDict) observations from the environment
                - (float) reward from the environment
                - (bool) whether the current episode is completed or not
                - (dict) misc information
        """
        ret = super().step(action)
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def _add_indicators_to_model(self, xml):
        """
        Adds indicators to the mujoco simulation model

        Args:
            xml (string): MJCF model in xml format, for the current simulation to be loaded

        Returns:
            str: The (possibly modified) MJCF xml. Unchanged if no indicators are configured; otherwise each
                indicator is appended to <worldbody> as a body named "<name>_body" containing the indicator site.
        """
        if self.indicator_configs is None:
            return xml

        root = ET.fromstring(xml)
        worldbody = root.find("worldbody")
        for indicator_config in self.indicator_configs:
            config = deepcopy(indicator_config)
            # "pos" belongs to the wrapping body (so set_indicator_pos can move it); the site sits at the body origin
            indicator_body = new_body(name=config["name"] + "_body", pos=config.pop("pos", (0, 0, 0)))
            indicator_body.append(new_site(**config))
            worldbody.append(indicator_body)
        return ET.tostring(root, encoding="utf8").decode("utf8")
|