# xfu314's picture
# Add phantom project with submodules and dependencies
# 96da58e
"""
This file implements a wrapper for visualizing important sites in a given environment.
By default, this visualizes all sites possible for the environment. Visualization options
for a given environment can be found by calling `get_visualization_settings()`, and can
be set individually by calling `set_visualization_setting(setting, visible)`.
"""
import xml.etree.ElementTree as ET
from copy import deepcopy
import numpy as np
from robosuite.utils.mjcf_utils import new_body, new_geom, new_site
from robosuite.wrappers import Wrapper
# Site attributes used when an indicator config is given as the string "default":
# a sphere of size 0.03 with rgba (1, 0, 0, 0.5). VisualizationWrapper.__init__ deep-copies
# this dict and adds a per-indicator "name" entry, so the template itself is never mutated.
DEFAULT_INDICATOR_SITE_CONFIG = {"type": "sphere", "size": [0.03], "rgba": [1, 0, 0, 0.5]}
class VisualizationWrapper(Wrapper):
    """
    Environment wrapper that controls which visualizations are shown and optionally adds indicator sites.

    The wrapper keeps a per-setting visibility flag (all True initially) for every keyword in the wrapped
    environment's `_visualizations`, and re-applies those flags after every `reset()` and `step()`.
    Indicator objects, if requested, are injected into the MJCF model as one body (named `<name>_body`)
    holding one site each, via the environment's xml post-processor hook.
    """

    def __init__(self, env, indicator_configs=None):
        """
        Initializes the visualization wrapper. Note that this automatically conducts a (hard) reset initially to
        make sure indicators are properly added to the sim model.

        Args:
            env (MujocoEnv): The environment to visualize

            indicator_configs (None or str or dict or list): Configurations to use for indicator objects.

                If None, no indicator objects will be used

                If a string, this should be `'default'`, which corresponds to single default spherical indicator

                If a dict, should specify a single indicator object config

                If a list, should specify specific indicator object configs to use for multiple indicators (which in
                turn can either be `'default'` or a dict)

                As each indicator object is essentially a site element, each dict should map site attribute keywords
                to values. Note that, at the very minimum, the `'name'` attribute MUST be specified for each
                indicator. See http://www.mujoco.org/book/XMLreference.html#site for specific site attributes that
                can be specified.

        Raises:
            AssertionError: If the environment uses camera segmentations, if a string config is anything other
                than `'default'`, or if an indicator config lacks a `'name'` entry.
        """
        super().__init__(env)

        # Make sure that the environment is NOT using segmentation sensors, since we cannot use segmentation masks
        # with visualization sites simultaneously
        assert all(
            seg is None for seg in env.camera_segmentations
        ), "Cannot use camera segmentations with visualization wrapper!"

        # Standardize indicator configs
        self.indicator_configs = None
        if indicator_configs is not None:
            self.indicator_configs = []
            # isinstance (rather than an exact type match) so that dict subclasses such as OrderedDict are
            # treated as a single config instead of being iterated key-by-key
            if isinstance(indicator_configs, (str, dict)):
                indicator_configs = [indicator_configs]
            for i, indicator_config in enumerate(indicator_configs):
                if isinstance(indicator_config, str) and indicator_config == "default":
                    indicator_config = deepcopy(DEFAULT_INDICATOR_SITE_CONFIG)
                    indicator_config["name"] = f"indicator{i}"
                # Any string left at this point is invalid. Reject it explicitly: the `in` check below would
                # otherwise do a substring match and let strings containing "name" slip through.
                assert not isinstance(
                    indicator_config, str
                ), "String indicator configurations must be 'default', got '{}'".format(indicator_config)
                # Make sure name attribute is specified
                assert "name" in indicator_config, "Name must be specified for all indicator object configurations!"
                # Add this configuration to the internal array
                self.indicator_configs.append(indicator_config)

        # Create internal dict to store visualization settings (set to True by default)
        self._vis_settings = {vis: True for vis in self.env._visualizations}

        # Add the post-processor to make sure indicator objects get added to model before it's actually loaded in sim
        self.env.set_xml_processor(processor=self._add_indicators_to_model)

        # Conduct a (hard) reset to make sure visualization changes propagate. The environment's original
        # hard_reset flag is restored even if the reset raises, so a failed construction does not leave the
        # wrapped environment stuck in hard-reset mode.
        reset_mode = self.env.hard_reset
        self.env.hard_reset = True
        try:
            self.reset()
        finally:
            self.env.hard_reset = reset_mode

    def get_indicator_names(self):
        """
        Gets all indicator object names for this environment.

        Returns:
            list: Indicator names for this environment (empty if no indicators were configured).
        """
        if self.indicator_configs is None:
            return []
        return [ind_config["name"] for ind_config in self.indicator_configs]

    def set_indicator_pos(self, indicator, pos):
        """
        Sets the specified @indicator to the desired position @pos

        Args:
            indicator (str): Name of the indicator to set
            pos (3-array): (x, y, z) Cartesian world coordinates to set the specified indicator to

        Raises:
            AssertionError: If @indicator is not one of the configured indicator names.
        """
        # Make sure indicator is valid
        indicator_names = set(self.get_indicator_names())
        assert indicator in indicator_names, "Invalid indicator name specified. Valid options are {}, got {}".format(
            indicator_names, indicator
        )
        # Set the specified indicator. Each indicator site lives in its own body named "<indicator>_body"
        # (see _add_indicators_to_model), so the indicator is moved by writing that body's position.
        self.env.sim.model.body_pos[self.env.sim.model.body_name2id(indicator + "_body")] = np.array(pos)

    def get_visualization_settings(self):
        """
        Gets all settings for visualizing this environment

        Returns:
            dict_keys: Visualization keywords for this environment (a view over the internal settings dict).
        """
        return self._vis_settings.keys()

    def set_visualization_setting(self, setting, visible):
        """
        Sets the specified @setting to have visibility = @visible.

        The new value is stored immediately and is passed to the environment's `visualize()` on the next
        `reset()` or `step()` call.

        Args:
            setting (str): Visualization keyword to set
            visible (bool): True if setting should be visualized.

        Raises:
            AssertionError: If @setting is not a known visualization keyword.
        """
        assert (
            setting in self._vis_settings
        ), "Invalid visualization setting specified. Valid options are {}, got {}".format(
            self._vis_settings.keys(), setting
        )
        self._vis_settings[setting] = visible

    def reset(self):
        """
        Extends vanilla reset() function call to accommodate visualization

        Returns:
            OrderedDict: Environment observation space after reset occurs
        """
        ret = super().reset()
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def step(self, action):
        """
        Extends vanilla step() function call to accommodate visualization

        Args:
            action (np.array): Action to take in environment

        Returns:
            4-tuple:

                - (OrderedDict) observations from the environment
                - (float) reward from the environment
                - (bool) whether the current episode is completed or not
                - (dict) misc information
        """
        ret = super().step(action)
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def _add_indicators_to_model(self, xml):
        """
        Adds indicators to the mujoco simulation model

        Args:
            xml (string): MJCF model in xml format, for the current simulation to be loaded

        Returns:
            string: The MJCF xml with one body + site appended to the worldbody per configured indicator, or
                the input @xml unchanged if no indicators are configured.
        """
        if self.indicator_configs is None:
            return xml

        root = ET.fromstring(xml)
        worldbody = root.find("worldbody")

        for indicator_config in self.indicator_configs:
            # Work on a copy so that popping "pos" does not alter the stored config; this processor may be
            # invoked again with the same configs
            config = deepcopy(indicator_config)
            # "pos" is applied to the wrapping body (default origin); all remaining entries become site attributes
            indicator_body = new_body(name=config["name"] + "_body", pos=config.pop("pos", (0, 0, 0)))
            indicator_body.append(new_site(**config))
            worldbody.append(indicator_body)

        return ET.tostring(root, encoding="utf8").decode("utf8")