File size: 7,290 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
This file implements a wrapper for visualizing important sites in a given environment.

By default, this visualizes all sites possible for the environment. Visualization options
for a given environment can be found by calling `get_visualization_settings()`, and can
be set individually by calling `set_visualization_setting(setting, visible)`.
"""
import xml.etree.ElementTree as ET
from copy import deepcopy

import numpy as np

from robosuite.utils.mjcf_utils import new_body, new_geom, new_site
from robosuite.wrappers import Wrapper

# Site attributes used when an indicator is requested as `'default'`: a red, half-transparent sphere.
# The `name` attribute is intentionally absent here; it is filled in per indicator by the wrapper.
DEFAULT_INDICATOR_SITE_CONFIG = dict(
    type="sphere",
    size=[0.03],
    rgba=[1, 0, 0, 0.5],
)


class VisualizationWrapper(Wrapper):
    """
    Wrapper that keeps an environment's visualizations up to date and optionally adds indicator objects to it.

    After every `reset()` and `step()` the wrapped environment's `visualize()` is called with the current
    visualization settings. Indicator objects are injected into the MJCF model as a body (named
    `<indicator name>_body`) containing a single site, and can be moved with `set_indicator_pos()`.
    """

    def __init__(self, env, indicator_configs=None):
        """
        Initializes the visualization wrapper. Note that this automatically conducts a (hard) reset initially to make
        sure indicators are properly added to the sim model.

        Args:
            env (MujocoEnv): The environment to visualize

            indicator_configs (None or str or dict or list): Configurations to use for indicator objects.

                If None, no indicator objects will be used

                If a string, this should be `'default'`, which corresponds to single default spherical indicator

                If a dict, should specify a single indicator object config

                If a list, should specify specific indicator object configs to use for multiple indicators (which in
                turn can either be `'default'` or a dict)

                As each indicator object is essentially a site element, each dict should map site attribute keywords to
                values. Note that, at the very minimum, the `'name'` attribute MUST be specified for each indicator. See
                http://www.mujoco.org/book/XMLreference.html#site for specific site attributes that can be specified.

        Raises:
            AssertionError: If the environment uses camera segmentations, if an indicator config is a string other
                than `'default'`, or if an indicator config does not specify `'name'`
        """
        super().__init__(env)

        # Make sure that the environment is NOT using segmentation sensors, since we cannot use segmentation masks
        # with visualization sites simultaneously
        assert all(
            seg is None for seg in env.camera_segmentations
        ), "Cannot use camera segmentations with visualization wrapper!"

        # Standardize indicator configs
        self.indicator_configs = None
        if indicator_configs is not None:
            self.indicator_configs = []
            # isinstance (rather than an exact type match) so that dict subclasses such as OrderedDict are treated
            # as a single config instead of being iterated key by key
            if isinstance(indicator_configs, (str, dict)):
                indicator_configs = [indicator_configs]
            for i, indicator_config in enumerate(indicator_configs):
                if indicator_config == "default":
                    indicator_config = deepcopy(DEFAULT_INDICATOR_SITE_CONFIG)
                    indicator_config["name"] = f"indicator{i}"
                # Any remaining string is invalid; reject it explicitly, because the `"name" in ...` check below
                # would otherwise perform a substring test on it
                assert not isinstance(
                    indicator_config, str
                ), "Indicator config must be 'default' or a dict of site attributes, got {!r}".format(indicator_config)
                # Make sure name attribute is specified
                assert "name" in indicator_config, "Name must be specified for all indicator object configurations!"
                # Add this configuration to the internal array
                self.indicator_configs.append(indicator_config)

        # Create internal dict to store visualization settings (set to True by default)
        self._vis_settings = {vis: True for vis in self.env._visualizations}

        # Add the post-processor to make sure indicator objects get added to model before it's actually loaded in sim
        self.env.set_xml_processor(processor=self._add_indicators_to_model)

        # Conduct a (hard) reset to make sure visualization changes propagate. The environment's original
        # hard_reset flag is restored even if the reset fails.
        reset_mode = self.env.hard_reset
        self.env.hard_reset = True
        try:
            self.reset()
        finally:
            self.env.hard_reset = reset_mode

    def get_indicator_names(self):
        """
        Gets all indicator object names for this environment.

        Returns:
            list: Indicator names for this environment (empty if no indicators were configured).
        """
        if self.indicator_configs is None:
            return []
        return [ind_config["name"] for ind_config in self.indicator_configs]

    def set_indicator_pos(self, indicator, pos):
        """
        Sets the specified @indicator to the desired position @pos

        Args:
            indicator (str): Name of the indicator to set
            pos (3-array): (x, y, z) Cartesian world coordinates to set the specified indicator to

        Raises:
            AssertionError: If @indicator is not the name of a configured indicator
        """
        # Make sure indicator is valid
        indicator_names = set(self.get_indicator_names())
        assert indicator in indicator_names, "Invalid indicator name specified. Valid options are {}, got {}".format(
            indicator_names, indicator
        )
        # The indicator site lives inside a body named "<indicator>_body"; moving that body moves the indicator
        model = self.env.sim.model
        model.body_pos[model.body_name2id(indicator + "_body")] = np.array(pos)

    def get_visualization_settings(self):
        """
        Gets all settings for visualizing this environment

        Returns:
            dict_keys: View over the visualization keywords for this environment.
        """
        return self._vis_settings.keys()

    def set_visualization_setting(self, setting, visible):
        """
        Sets the specified @setting to have visibility = @visible.

        Args:
            setting (str): Visualization keyword to set
            visible (bool): True if setting should be visualized.

        Raises:
            AssertionError: If @setting is not a known visualization keyword
        """
        assert (
            setting in self._vis_settings
        ), "Invalid visualization setting specified. Valid options are {}, got {}".format(
            self._vis_settings.keys(), setting
        )
        self._vis_settings[setting] = visible

    def reset(self):
        """
        Extends vanilla reset() function call to accommodate visualization

        Returns:
            OrderedDict: Environment observation space after reset occurs
        """
        ret = super().reset()
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def step(self, action):
        """
        Extends vanilla step() function call to accommodate visualization

        Args:
            action (np.array): Action to take in environment

        Returns:
            4-tuple:

                - (OrderedDict) observations from the environment
                - (float) reward from the environment
                - (bool) whether the current episode is completed or not
                - (dict) misc information
        """
        ret = super().step(action)

        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)

        return ret

    def _add_indicators_to_model(self, xml):
        """
        Adds indicators to the mujoco simulation model

        Each configured indicator becomes a body named `<name>_body` appended to the `worldbody` element, holding
        one site built from the indicator's config. An optional `pos` entry in the config is used as the body
        position (default (0, 0, 0)) rather than being passed to the site.

        Args:
            xml (string): MJCF model in xml format, for the current simulation to be loaded

        Returns:
            string: The MJCF model with indicators added, or @xml unchanged if no indicators are configured
        """
        if self.indicator_configs is not None:
            root = ET.fromstring(xml)
            worldbody = root.find("worldbody")

            for indicator_config in self.indicator_configs:
                # Work on a copy so that popping `pos` does not alter the stored config
                config = deepcopy(indicator_config)
                indicator_body = new_body(name=config["name"] + "_body", pos=config.pop("pos", (0, 0, 0)))
                indicator_body.append(new_site(**config))
                worldbody.append(indicator_body)

            xml = ET.tostring(root, encoding="utf8").decode("utf8")

        return xml