Spaces:
Sleeping
Sleeping
| # utility functions for manipulating MJCF XML models | |
| import os | |
| import xml.etree.ElementTree as ET | |
| from collections.abc import Iterable | |
| from copy import deepcopy | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import robosuite | |
# Basic colors, written as [r, g, b, a] lists
RED = [1, 0, 0, 1]
GREEN = [0, 1, 0, 1]
BLUE = [0, 0, 1, 1]
CYAN = [0, 1, 1, 1]

# Colors named after the kind of collision geometry they are meant for.
# NOTE: the mount and environment colors currently hold the same value.
ROBOT_COLLISION_COLOR = [0, 0.5, 0, 1]
MOUNT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
GRIPPER_COLLISION_COLOR = [0, 0, 0.5, 1]
OBJECT_COLLISION_COLOR = [0.5, 0, 0, 1]
ENVIRONMENT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
# XML tags that are treated as sensor elements (see _element_filter, which maps them to "sensors")
SENSOR_TYPES = {
    # site- / body-attached measurements
    "touch",
    "accelerometer",
    "velocimeter",
    "gyro",
    "force",
    "torque",
    "magnetometer",
    "rangefinder",
    # joint / tendon / actuator state
    "jointpos",
    "jointvel",
    "tendonpos",
    "tendonvel",
    "actuatorpos",
    "actuatorvel",
    "actuatorfrc",
    "ballangvel",
    # joint / tendon limits
    "jointlimitpos",
    "jointlimitvel",
    "jointlimitfrc",
    "tendonlimitpos",
    "tendonlimitvel",
    "tendonlimitfrc",
    # frame quantities
    "framepos",
    "framequat",
    "framexaxis",
    "frameyaxis",
    "framezaxis",
    "framelinvel",
    "frameangvel",
    "framelinacc",
    "frameangacc",
    # subtree quantities
    "subtreecom",
    "subtreelinvel",
    "subtreeangmom",
    # user-defined
    "user",
}
# Attribute names that add_prefix() treats as name references when attribs="default".
# Each entry is listed exactly once (the original literal repeated "objname").
MUJOCO_NAMED_ATTRIBUTES = {
    "class",
    "childclass",
    "name",
    "objname",
    "material",
    "texture",
    "joint",
    "joint1",
    "joint2",
    "jointinparent",
    "geom",
    "geom1",
    "geom2",
    "mesh",
    "fixed",
    "actuator",
    "tendon",
    "tendon1",
    "tendon2",
    "slidesite",
    "cranksite",
    "body",
    "body1",
    "body2",
    "hfield",
    "target",
    "prefix",
    "site",
}
# Maps an image convention name to a sign (+1 / -1).
# NOTE(review): presumably used by callers to flip rendered images vertically - confirm at the call sites.
IMAGE_CONVENTION_MAPPING = {"opengl": 1, "opencv": -1}
# Texture name -> png file name
TEXTURE_FILES = {
    # Woods
    "WoodRed": "red-wood.png",
    "WoodGreen": "green-wood.png",
    "WoodBlue": "blue-wood.png",
    "WoodLight": "light-wood.png",
    "WoodDark": "dark-wood.png",
    "WoodTiles": "wood-tiles.png",
    "WoodPanels": "wood-varnished-panels.png",
    "WoodgrainGray": "gray-woodgrain.png",
    # Plasters / bricks
    "PlasterCream": "cream-plaster.png",
    "PlasterPink": "pink-plaster.png",
    "PlasterYellow": "yellow-plaster.png",
    "PlasterGray": "gray-plaster.png",
    "PlasterWhite": "white-plaster.png",
    "BricksWhite": "white-bricks.png",
    # Metals
    "Metal": "metal.png",
    "SteelBrushed": "steel-brushed.png",
    "SteelScratched": "steel-scratched.png",
    "Brass": "brass-ambra.png",
    # Miscellaneous
    "Bread": "bread.png",
    "Can": "can.png",
    "Ceramic": "ceramic.png",
    "Cereal": "cereal.png",
    "Clay": "clay.png",
    "Dirt": "dirt.png",
    "Glass": "glass.png",
    "FeltGray": "gray-felt.png",
    "Lemon": "lemon.png",
}

# Texture name -> path relative to the assets root ("textures/<file>")
TEXTURES = {}
for _tex_name, _tex_file in TEXTURE_FILES.items():
    TEXTURES[_tex_name] = os.path.join("textures", _tex_file)

# Live view over every valid texture name
ALL_TEXTURES = TEXTURES.keys()
class CustomMaterial(object):
    """
    Container for the parameters that define a texture / material pair.

    Builds two attribute dicts (@tex_attrib and @mat_attrib) holding the components needed to procedurally
    generate a texture / material combo. Please see http://www.mujoco.org/book/XMLreference.html#asset for
    specific details on attributes expected for Mujoco texture / material tags, respectively.

    Note that the values in @tex_attrib and @mat_attrib can be in string or array / numerical form; all
    non-string values are converted to strings during construction.

    Args:
        texture (None or str or 4-array): Name of texture file to be imported. If a string, should be part of
            ALL_TEXTURES. If texture is a 4-array, then this argument will be interpreted as an rgba tuple value
            and a template png will be procedurally generated during object instantiation, with any additional
            texture / material attributes specified. If None, no file will be linked and no rgba value will be
            set. Note, if specified, the RGBA values are expected to be floats between 0 and 1
        tex_name (str): Name to reference the imported texture
        mat_name (str): Name to reference the imported material
        tex_attrib (dict): Any other optional mujoco texture specifications.
        mat_attrib (dict): Any other optional mujoco material specifications.
        shared (bool): If True, this material should not have any naming prefixes added to all names

    Raises:
        AssertionError: [Invalid texture]
    """

    def __init__(
        self,
        texture,
        tex_name,
        mat_name,
        tex_attrib=None,
        mat_attrib=None,
        shared=False,
    ):
        # A string selects a bundled texture file; anything else (rgba array or None) is the "default" case
        from_file = type(texture) is str
        if from_file:
            # The requested texture must be one of the known names
            assert texture in ALL_TEXTURES, "Error: Requested invalid texture. Got {}. Valid options are:\n{}".format(
                texture, ALL_TEXTURES
            )
        elif texture is not None:
            # An explicit non-string texture is interpreted as rgba, so it must have exactly 4 entries
            assert len(texture) == 4, (
                "Error: Requested default texture. Got array of length {}."
                "Expected rgba array of length 4.".format(len(texture))
            )

        # Work on copies so the caller's dicts are never modified
        self.tex_attrib = tex_attrib.copy() if tex_attrib is not None else {}
        self.mat_attrib = mat_attrib.copy() if mat_attrib is not None else {}

        # Record naming information and cross-link the material to its texture
        self.name = mat_name
        self.shared = shared
        self.tex_attrib["name"] = tex_name
        self.mat_attrib["name"] = mat_name
        self.mat_attrib["texture"] = tex_name

        # Every attribute value must end up as a string: iterables are space-joined, the rest go through str()
        for spec in (self.tex_attrib, self.mat_attrib):
            for key in list(spec):
                value = spec[key]
                if type(value) is str:
                    continue
                spec[key] = array_to_string(value) if isinstance(value, Iterable) else str(value)

        # Link the texture file: either a bundled png, or a freshly generated single-color patch
        if from_file:
            self.tex_attrib["file"] = xml_path_completion(TEXTURES[texture])
        elif texture is not None:
            # Scale the rgba values to 0-255 integers and create a 100 x 100 patch of that color
            rgba_255 = tuple((np.array(texture) * 255).astype("int"))
            patch = Image.new("RGBA", (100, 100), rgba_255)
            # Make sure the temp directory exists, then write the patch there (MacOS / Linux path)
            save_dir = "/tmp/robosuite_temp_tex"
            Path(save_dir).mkdir(parents=True, exist_ok=True)
            fpath = save_dir + "/{}.png".format(tex_name)
            patch.save(fpath, "PNG")
            self.tex_attrib["file"] = fpath
def xml_path_completion(xml_path):
    """
    Takes in a local xml path and returns a full path.
        if @xml_path is absolute, do nothing
        if @xml_path is not absolute, load xml that is shipped by the package

    Args:
        xml_path (str): local xml path

    Returns:
        str: Full (absolute) xml path
    """
    # os.path.isabs also recognizes platform-specific absolute forms (e.g. drive-letter paths), which a plain
    # "/" prefix test would wrongly join onto the assets root. The explicit "/" test is kept so that
    # POSIX-style paths are still passed through unchanged on every platform.
    if os.path.isabs(xml_path) or xml_path.startswith("/"):
        return xml_path
    return os.path.join(robosuite.models.assets_root, xml_path)
def array_to_string(array):
    """
    Joins the entries of @array into a single space-separated string (the array format used in mujoco xml).

    Examples:
        [0, 1, 2] => "0 1 2"

    Args:
        array (n-array): Array to convert to a string

    Returns:
        str: String equivalent of @array
    """
    pieces = []
    for entry in array:
        pieces.append(f"{entry}")
    return " ".join(pieces)
def string_to_array(string):
    """
    Converts an array string in mujoco xml to np.array.

    Entries may be separated by any amount of whitespace (repeated spaces, tabs, newlines).

    Examples:
        "0 1 2" => [0, 1, 2]

    Args:
        string (str): String to convert to an array

    Returns:
        np.array: Numerical (float) array equivalent of @string

    Raises:
        ValueError: [An entry cannot be parsed as a float]
    """
    # split() with no argument collapses runs of whitespace, so "0  1" does not produce an empty token
    return np.array([float(token) for token in string.split()], dtype=float)
def convert_to_string(inp):
    """
    Converts any type of {bool, int, float, list, tuple, array, string, np.str_} into an mujoco-xml compatible
    string. NumPy scalar values (np.bool_, np.integer, np.floating) and subclasses of the listed types are
    accepted as well.

    Note that an input string / np.str_ results in a no-op action.

    Args:
        inp: Input to convert to string

    Returns:
        str: String equivalent of @inp

    Raises:
        ValueError: [Unsupported input type]
    """
    # Strings are checked first so they are returned untouched (np.str_ is a str subclass)
    if isinstance(inp, str):
        return inp
    if isinstance(inp, (list, tuple, np.ndarray)):
        return array_to_string(inp)
    # Lower-casing turns True / False into "true" / "false"
    if isinstance(inp, (bool, int, float, np.bool_, np.integer, np.floating)):
        return str(inp).lower()
    raise ValueError("Unsupported type received: got {}".format(type(inp)))
def set_alpha(node, alpha=0.1):
    """
    Overwrites the a(lpha) entry of the rgba attribute with @alpha for every element below @node that has an
    rgba attribute. Used for managing display.

    NOTE: the search pattern only matches descendants, so an rgba attribute on @node itself is left unchanged.

    Args:
        node (ET.Element): Specific node element within XML tree
        alpha (float): Value to set alpha value of rgba tuple
    """
    colored_nodes = node.findall(".//*[@rgba]")
    for colored in colored_nodes:
        current_rgba = string_to_array(colored.get("rgba"))
        # Keep the rgb entries and replace only the alpha entry
        updated_rgba = list(current_rgba[0:3]) + [alpha]
        colored.set("rgba", array_to_string(updated_rgba))
def new_element(tag, name, **kwargs):
    """
    Builds a new @tag element whose attributes are given by @**kwargs.

    Attributes whose value is None are dropped; all other values are converted to strings.

    Args:
        tag (str): Type of element to create
        name (None or str): Name for this element. Should only be None for elements that do not have an explicit
            name attribute (e.g.: inertial elements)
        **kwargs: Specified attributes for the new element

    Returns:
        ET.Element: new specified xml element
    """
    # The name (when given) is stored like any other attribute, after the explicitly passed ones
    if name is not None:
        kwargs["name"] = name

    # Skip unset attributes and stringify the remaining ones
    attributes = {key: convert_to_string(value) for key, value in kwargs.items() if value is not None}
    return ET.Element(tag, attrib=attributes)
def new_joint(name, **kwargs):
    """
    Builds a "joint" element with the attributes given by @**kwargs.

    Args:
        name (str): Name for this joint
        **kwargs: Specified attributes for the new joint

    Returns:
        ET.Element: new joint xml element
    """
    joint_element = new_element(tag="joint", name=name, **kwargs)
    return joint_element
def new_actuator(name, joint, act_type="actuator", **kwargs):
    """
    Builds an actuator element with the attributes given by @**kwargs.

    Args:
        name (str): Name for this actuator
        joint (str): Value stored in the "joint" attribute of the new element.
            see actuator transmission details here: http://mujoco.org/book/modeling.html#actuator
        act_type (str): actuator type, used as the tag of the new element. Defaults to "actuator"
        **kwargs: Any additional specified attributes for the new actuator

    Returns:
        ET.Element: new actuator xml element
    """
    actuator = new_element(tag=act_type, name=name, **kwargs)
    # The joint reference is written directly onto the element, after all other attributes
    actuator.set("joint", joint)
    return actuator
def new_site(name, rgba=RED, pos=(0, 0, 0), size=(0.005,), **kwargs):
    """
    Builds a "site" element with the attributes given by @**kwargs.

    NOTE: Any attribute whose value is None (including @rgba) is dropped before the XML element is created.
        @name, @pos and @size are expected to be specified.

    Args:
        name (str): Name for this site
        rgba (4-array): (r,g,b,a) color and transparency. Defaults to solid red.
        pos (3-array): (x,y,z) 3d position of the site.
        size (n-array of float): site size (sites are spherical by default).
        **kwargs: Any additional specified attributes for the new site

    Returns:
        ET.Element: new site xml element
    """
    # Explicit arguments are appended after any extra attributes, in this order
    kwargs.update(pos=pos, size=size, rgba=rgba)
    return new_element(tag="site", name=name, **kwargs)
def new_geom(name, type, size, pos=(0, 0, 0), group=0, **kwargs):
    """
    Builds a "geom" element with the attributes given by @**kwargs.

    NOTE: Any attribute whose value is None (including @group) is dropped before the XML element is created.
        @type, @size and @pos are expected to be specified.

    Args:
        name (str): Name for this geom
        type (str): type of the geom.
            see all types here: http://mujoco.org/book/modeling.html#geom
        size (n-array of float): geom size parameters.
        pos (3-array): (x,y,z) 3d position of the geom.
        group (int): the integer group that the geom belongs to. useful for
            separating visual and physical elements.
        **kwargs: Any additional specified attributes for the new geom

    Returns:
        ET.Element: new geom xml element
    """
    # Explicit arguments are appended after any extra attributes, in this order
    kwargs.update(type=type, size=size, pos=pos, group=group)
    return new_element(tag="geom", name=name, **kwargs)
def new_body(name, pos=(0, 0, 0), **kwargs):
    """
    Builds a "body" element with the attributes given by @**kwargs.

    Args:
        name (str): Name for this body
        pos (3-array): (x,y,z) 3d position of the body frame.
        **kwargs: Any additional specified attributes for the new body

    Returns:
        ET.Element: new body xml element
    """
    # Position goes after any extra attributes
    attributes = dict(kwargs, pos=pos)
    return new_element(tag="body", name=name, **attributes)
def new_inertial(pos=(0, 0, 0), mass=None, **kwargs):
    """
    Builds an "inertial" element (which carries no name) with the attributes given by @**kwargs.

    Args:
        pos (3-array): (x,y,z) 3d position of the inertial frame.
        mass (float): The mass of inertial. If None, no mass attribute is written
        **kwargs: Any additional specified attributes for the new inertial element

    Returns:
        ET.Element: new inertial xml element
    """
    # Explicit arguments are appended after any extra attributes, mass first
    kwargs.update(mass=mass, pos=pos)
    return new_element(tag="inertial", name=None, **kwargs)
def get_size(size, size_max, size_min, default_max, default_min):
    """
    Helper that returns either an explicitly given size, or a size sampled uniformly from a range.

    Args:
        size (n-array): Array of numbers that explicitly define the size
        size_max (n-array): Array of numbers that define the custom max size from which to randomly sample
        size_min (n-array): Array of numbers that define the custom min size from which to randomly sample
        default_max (n-array): Array of numbers that define the default max size from which to randomly sample
        default_min (n-array): Array of numbers that define the default min size from which to randomly sample

    Returns:
        np.array: size generated

    Raises:
        ValueError: [Inconsistent array sizes]
    """
    num_dims = len(default_max)
    if num_dims != len(default_min):
        raise ValueError(
            "default_max = {} and default_min = {}".format(str(default_max), str(default_min))
            + " have different lengths"
        )

    # No explicit size: sample each dimension from the (custom or default) range
    if size is None:
        upper = default_max if size_max is None else size_max
        lower = default_min if size_min is None else size_min
        sampled = [np.random.uniform(lower[dim], upper[dim]) for dim in range(num_dims)]
        return np.array(sampled)

    # An explicit size must not be combined with a custom range
    if not (size_max is None and size_min is None):
        raise ValueError("size = {} overrides size_max = {}, size_min = {}".format(size, size_max, size_min))
    return np.array(size)
def add_to_dict(dic, fill_in_defaults=True, default_value=None, **kwargs):
    """
    Helper function to add key-values to dictionary @dic where each entry is its own array (list).

    Args:
        dic (dict): Dictionary to which new key / value pairs will be added. If the key already exists,
            will append the value to that key entry
        fill_in_defaults (bool): If True, will automatically add @default_value to all dictionary entries that are
            not explicitly specified in @kwargs. New keys are additionally padded in front with @default_value so
            that they end up as long as the existing entries
        default_value (any): Default value to fill (None by default)
        **kwargs: Key / value pairs to add to @dic

    Returns:
        dict: Modified dictionary (the same object as @dic)
    """
    # Existing keys that have not received a value in this call yet
    pending = set(dic)
    # Current length of the stored lists, taken from an arbitrary existing entry (all entries are expected to
    # share one length). This must be the length of the LIST, not of the key string.
    n = len(next(iter(dic.values()))) if dic else 0

    for k, v in kwargs.items():
        if k in dic:
            dic[k].append(v)
            pending.discard(k)
        elif fill_in_defaults:
            # New key: pad with defaults for all previous rows, then add the new value
            dic[k] = [default_value] * n + [v]
        else:
            dic[k] = [v]

    # If filling in defaults, every entry that was not specified gets the default value
    if fill_in_defaults:
        for k in pending:
            dic[k].append(default_value)
    return dic
def add_prefix(
    root,
    prefix,
    tags="default",
    attribs="default",
    exclude=None,
):
    """
    Walks @root and all of its descendants and, for every element matching the requested @tags, prepends @prefix
    to each of the requested @attribs that exists on the element.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        prefix (str): Prefix to add to all specified attributes
        tags (str or list of str or set): Tag(s) to search for in this ElementTree. "default" corresponds to all
            tags
        attribs (str or list of str or set): Element attribute(s) to add the prefix to. "default" corresponds
            to all attributes that reference names
        exclude (None or function): Filtering function that should take in an ET.Element or a string (attribute)
            and return True if we should exclude the given element / attribute from having any prefixes added
    """
    # Normalize the tag and attribute selections into sets once, up front
    match_all_tags = tags == "default"
    if not match_all_tags:
        tag_set = {tags} if type(tags) is str else set(tags)
    if attribs == "default":
        attribs = MUJOCO_NAMED_ATTRIBUTES
    attrib_set = {attribs} if type(attribs) is str else set(attribs)

    # iter() visits @root first and then every descendant in document order
    for element in root.iter():
        if not (match_all_tags or element.tag in tag_set):
            continue
        if exclude is not None and exclude(element):
            continue
        for attrib in attrib_set:
            value = element.get(attrib, None)
            # Skip missing attributes and values that already carry the prefix
            if value is None or value.startswith(prefix):
                continue
            if exclude is None or not exclude(value):
                element.set(attrib, prefix + value)
def add_material(root, naming_prefix="", custom_material=None):
    """
    Goes through @root and all of its descendants and assigns a material to every visual geom (group "1") that
    does not already have a material specified.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        naming_prefix (str): Adds this prefix to all material and texture names
        custom_material (None or CustomMaterial): If specified, will add this material to all visual geoms.
            Else, will add a default "no-change" material.

    Returns:
        4-tuple: (ET.Element, ET.Element, CustomMaterial, bool) (tex_element, mat_element, material, used)
            corresponding to the added material and whether the material was actually used or not.
    """
    # Fall back to a plain white builtin texture when no material is given
    material = custom_material
    if material is None:
        material = CustomMaterial(
            texture=None,
            tex_name="default_tex",
            mat_name="default_mat",
            tex_attrib={
                "type": "cube",
                "builtin": "flat",
                "width": 100,
                "height": 100,
                "rgb1": np.ones(3),
                "rgb2": np.ones(3),
            },
        )

    # Non-shared materials that do not carry the prefix yet get it added to every name they hold
    if not (material.name.startswith(naming_prefix) or material.shared):
        material.name = naming_prefix + material.name
        for attrib_dict, key in (
            (material.tex_attrib, "name"),
            (material.mat_attrib, "name"),
            (material.mat_attrib, "texture"),
        ):
            attrib_dict[key] = naming_prefix + attrib_dict[key]

    # Assign the material to every visual geom without one; remember whether any geom received it
    used = False
    for element in root.iter():
        if element.tag != "geom":
            continue
        if element.get("group", None) == "1" and element.get("material", None) is None:
            element.set("material", material.name)
            used = True

    # Lastly, build the texture and material elements describing this material
    tex_element = new_element(tag="texture", **material.tex_attrib)
    mat_element = new_element(tag="material", **material.mat_attrib)
    return tex_element, mat_element, material, used
def recolor_collision_geoms(root, rgba, exclude=None):
    """
    Visits @root and all of its descendants to find every geom belonging to group 0 (or having no group
    attribute) and sets its rgba value to the specified @rgba argument. Note: also removes any material values
    for these elements.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through
        rgba (4-array): (R, G, B, A) values to assign to all geoms with this group.
        exclude (None or function): Filtering function that should take in an ET.Element and
            return True if we should exclude the given element / attribute from having its collision geom
            impacted.
    """
    collision_groups = {None, "0"}
    # iter() visits @root first and then every descendant in document order
    for element in root.iter():
        if element.tag != "geom" or element.get("group") not in collision_groups:
            continue
        if exclude is not None and exclude(element):
            continue
        element.set("rgba", array_to_string(rgba))
        # Drop the material attribute, if any, together with the recoloring
        element.attrib.pop("material", None)
def _element_filter(element, parent):
    """
    Default element filter to be used in sort_elements. This will filter for the following groups:

        :`'root_body'`: Top-level body element
        :`'bodies'`: Any body elements
        :`'joints'`: Any joint elements
        :`'actuators'`: Any actuator elements
        :`'sites'`: Any site elements
        :`'sensors'`: Any sensor elements
        :`'contact_geoms'`: Any geoms used for collision (as specified by group 0 (default group) geoms)
        :`'visual_geoms'`: Any geoms used for visual rendering (as specified by group 1 geoms)

    Args:
        element (ET.Element): Current XML element that we are filtering
        parent (ET.Element): Parent XML element for the current element

    Returns:
        str or None: Assigned filter key for this element. None if no matching filter is found.
    """
    # Actuators are identified purely by their parent element, so this check comes first
    if parent is not None and parent.tag == "actuator":
        return "actuators"

    tag = element.tag
    if tag == "joint":
        # Elements carrying a "joint" or "joint1" attribute are references to joints (e.g. inside tendons),
        # not joint definitions, and are not assigned to any group
        is_reference = element.get("joint") is not None or element.get("joint1") is not None
        return None if is_reference else "joints"
    if tag == "body":
        # A body whose parent is not itself a body is the top-level body element
        has_body_parent = parent is not None and parent.tag == "body"
        return "bodies" if has_body_parent else "root_body"
    if tag == "site":
        return "sites"
    if tag in SENSOR_TYPES:
        return "sensors"
    if tag == "geom":
        # Only collision (group 0 / unspecified) and visual (group 1) geoms are assigned to a group
        group = element.get("group")
        if group == "1":
            return "visual_geoms"
        if group in {None, "0"}:
            return "contact_geoms"
    # No condition met
    return None
def sort_elements(root, parent=None, element_filter=None, _elements_dict=None):
    """
    Utility method to recursively group all elements of an XML ElementTree. Every element for which
    @element_filter returns the same (non-None) key is collected in one list entry of the returned dictionary.

    Args:
        root (ET.Element): Root of the xml element tree to start recursively searching through
        parent (ET.Element): Parent of the root node. Default is None (no parent node initially)
        element_filter (None or function): Function used to filter the incoming elements. Should take in two
            ET.Elements (current_element, parent_element) and return a string filter_key if the element
            should be added to the list of values sorted by filter_key, and return None if no value should be
            added. If no element_filter is specified, defaults to _element_filter.
        _elements_dict (dict): Dictionary that gets passed to recursive calls. Should not be modified externally
            by top-level call.

    Returns:
        dict: Filtered key-specific lists of the corresponding elements
    """
    # Top-level calls start with a fresh dict and fall back to the default filter
    grouped = {} if _elements_dict is None else _elements_dict
    classify = _element_filter if element_filter is None else element_filter

    # File the current element under its key (if it has one)
    key = classify(root, parent)
    if key is not None:
        grouped.setdefault(key, []).append(root)

    # Recurse into the children; they all add to the same dictionary
    for child in root:
        sort_elements(root=child, parent=root, element_filter=classify, _elements_dict=grouped)
    return grouped
def find_parent(root, child):
    """
    Locates the parent element of the specified @child node by recursively (depth-first) searching below @root.

    Args:
        root (ET.Element): Root of the xml element tree to start recursively searching through.
        child (ET.Element): Child element whose parent is to be found

    Returns:
        None or ET.Element: Matching parent if found, else None
    """
    for node in root:
        # Either @root is the direct parent, or the parent may be somewhere below this node
        match = root if node == child else find_parent(root=node, child=child)
        if match is not None:
            return match
    # Nothing found in this subtree
    return None
def find_elements(root, tags, attribs=None, return_first=True):
    """
    Find all element(s) matching the requested @tags and @attribs. If @return_first is True, then will return the
    first element found matching the criteria specified. Otherwise, will return a list of elements that match the
    criteria.

    @root itself and all of its descendants are checked, in document (depth-first, pre-order) order.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        tags (str or list of str or set): Tag(s) to search for in this ElementTree.
        attribs (None or dict of str): Element attribute(s) to check against for a filtered element. A match is
            considered found only if all attributes match. Each attribute key should have a corresponding value
            with which to compare against.
        return_first (bool): Whether to immediately return once the first matching element is found.

    Returns:
        None or ET.Element or list of ET.Element: Matching element(s) found. Returns None if there was no match.
    """
    # A single tag name is wrapped so that the membership test below compares whole tag names
    tags = [tags] if isinstance(tags, str) else tags

    def _matches(element):
        # An element matches if its tag is requested and every requested attribute has the requested value
        if element.tag not in tags:
            return False
        return attribs is None or all(element.get(k) == v for k, v in attribs.items())

    # Lazy generator, so the search stops at the first hit when only one element is requested
    matches = (element for element in root.iter() if _matches(element))
    if return_first:
        return next(matches, None)
    found = list(matches)
    return found if found else None
def save_sim_model(sim, fname):
    """
    Writes the current model xml from @sim to the file location @fname.

    Args:
        sim (MjSim): Simulation object whose model is saved; its save() method is called with the open file
        fname (str): Absolute filepath to the location to save the file
    """
    with open(fname, "w") as xml_file:
        sim.save(file=xml_file, format="xml")
def get_ids(sim, elements, element_type="geom", inplace=False):
    """
    Grabs the mujoco IDs for each element in @elements, corresponding to the specified @element_type.

    Args:
        sim (MjSim): Active mujoco simulation object
        elements (str or list or dict): Element(s) to convert into IDs. Note that the return type corresponds to
            @elements type, where each element name is replaced with the ID. Dicts and lists may be nested
        element_type (str): The type of element to grab ID for. Options are {geom, body, site}
        inplace (bool): If False, will create a copy of @elements to prevent overwriting the original data
            structure. Note that only dicts are actually modified in place; lists are always rebuilt

    Returns:
        int or list or dict: IDs corresponding to @elements.

    Raises:
        AssertionError: [Invalid element_type, or non-iterable elements]
    """
    if not inplace:
        # Copy elements first so we don't write to the underlying object
        elements = deepcopy(elements)

    # Choose what to do based on elements type
    if isinstance(elements, str):
        # A single name: look its ID up directly
        assert element_type in {
            "geom",
            "body",
            "site",
        }, f"element_type must be either geom, body, or site. Got: {element_type}"
        if element_type == "geom":
            elements = sim.model.geom_name2id(elements)
        elif element_type == "body":
            elements = sim.model.body_name2id(elements)
        else:  # site
            elements = sim.model.site_name2id(elements)
    elif isinstance(elements, dict):
        # Convert every value recursively. Iterating over items() is required to get (key, value) pairs;
        # only existing keys are reassigned, so the dict does not change size during iteration
        for name, ele in elements.items():
            elements[name] = get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True)
    else:  # We assume this is an iterable array
        assert isinstance(elements, Iterable), "Elements must be iterable for get_id!"
        elements = [get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True) for ele in elements]
    return elements