File size: 31,636 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
# utility functions for manipulating MJCF XML models

import os
import xml.etree.ElementTree as ET
from collections.abc import Iterable
from copy import deepcopy
from pathlib import Path

import numpy as np
from PIL import Image

import robosuite

# Basic RGBA colors (r, g, b, a); RED is the default color used by new_site
RED = [1, 0, 0, 1]
GREEN = [0, 1, 0, 1]
BLUE = [0, 0, 1, 1]
CYAN = [0, 1, 1, 1]
# RGBA colors for collision geoms of each model category.
# NOTE(review): presumably passed as @rgba to recolor_collision_geoms by model classes — confirm at call sites.
# NOTE(review): MOUNT_COLLISION_COLOR and ENVIRONMENT_COLLISION_COLOR currently share the same value — confirm intended.
ROBOT_COLLISION_COLOR = [0, 0.5, 0, 1]
MOUNT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
GRIPPER_COLLISION_COLOR = [0, 0, 0.5, 1]
OBJECT_COLLISION_COLOR = [0.5, 0, 0, 1]
ENVIRONMENT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
# MuJoCo sensor element tags; _element_filter groups any element whose tag is in this set under "sensors"
SENSOR_TYPES = {
    "touch",
    "accelerometer",
    "velocimeter",
    "gyro",
    "force",
    "torque",
    "magnetometer",
    "rangefinder",
    "jointpos",
    "jointvel",
    "tendonpos",
    "tendonvel",
    "actuatorpos",
    "actuatorvel",
    "actuatorfrc",
    "ballangvel",
    "jointlimitpos",
    "jointlimitvel",
    "jointlimitfrc",
    "tendonlimitpos",
    "tendonlimitvel",
    "tendonlimitfrc",
    "framepos",
    "framequat",
    "framexaxis",
    "frameyaxis",
    "framezaxis",
    "framelinvel",
    "frameangvel",
    "framelinacc",
    "frameangacc",
    "subtreecom",
    "subtreelinvel",
    "subtreeangmom",
    "user",
}

# Attribute names whose values refer to named MJCF entities. add_prefix uses this set as its default
# @attribs, so every one of these attributes gets the naming prefix.
MUJOCO_NAMED_ATTRIBUTES = {
    "class",
    "childclass",
    "name",
    "objname",
    "material",
    "texture",
    "joint",
    "joint1",
    "joint2",
    "jointinparent",
    "geom",
    "geom1",
    "geom2",
    "mesh",
    "fixed",
    "actuator",
    "objname",  # duplicate of the entry above; harmless because this is a set
    "tendon",
    "tendon1",
    "tendon2",
    "slidesite",
    "cranksite",
    "body",
    "body1",
    "body2",
    "hfield",
    "target",
    "prefix",
    "site",
}

# Maps an image convention name to an integer flag (+1 for "opengl", -1 for "opencv").
# NOTE(review): presumably used by callers as a vertical-flip sign for rendered images — confirm at call sites.
IMAGE_CONVENTION_MAPPING = {
    "opengl": 1,
    "opencv": -1,
}

# Texture names accepted by CustomMaterial -> PNG file name inside the "textures" asset folder
TEXTURE_FILES = {
    "WoodRed": "red-wood.png",
    "WoodGreen": "green-wood.png",
    "WoodBlue": "blue-wood.png",
    "WoodLight": "light-wood.png",
    "WoodDark": "dark-wood.png",
    "WoodTiles": "wood-tiles.png",
    "WoodPanels": "wood-varnished-panels.png",
    "WoodgrainGray": "gray-woodgrain.png",
    "PlasterCream": "cream-plaster.png",
    "PlasterPink": "pink-plaster.png",
    "PlasterYellow": "yellow-plaster.png",
    "PlasterGray": "gray-plaster.png",
    "PlasterWhite": "white-plaster.png",
    "BricksWhite": "white-bricks.png",
    "Metal": "metal.png",
    "SteelBrushed": "steel-brushed.png",
    "SteelScratched": "steel-scratched.png",
    "Brass": "brass-ambra.png",
    "Bread": "bread.png",
    "Can": "can.png",
    "Ceramic": "ceramic.png",
    "Cereal": "cereal.png",
    "Clay": "clay.png",
    "Dirt": "dirt.png",
    "Glass": "glass.png",
    "FeltGray": "gray-felt.png",
    "Lemon": "lemon.png",
}

# Texture name -> path relative to the package assets root ("textures/<file>");
# CustomMaterial turns these into full paths via xml_path_completion.
TEXTURES = {
    texture_name: os.path.join("textures", texture_file) for (texture_name, texture_file) in TEXTURE_FILES.items()
}

# All valid texture names (a live dict keys view of TEXTURES, not a list)
ALL_TEXTURES = TEXTURES.keys()


class CustomMaterial(object):
    """
    Simple class to instantiate the necessary parameters to define an appropriate texture / material combo

    Instantiates a nested dict holding necessary components for procedurally generating a texture / material combo

    Please see http://www.mujoco.org/book/XMLreference.html#asset for specific details on
        attributes expected for Mujoco texture / material tags, respectively

    Note that the values in @tex_attrib and @mat_attrib can be in string or array / numerical form. They are
    converted to MuJoCo XML strings as follows:
        - strings (including np.str_) are kept as-is
        - booleans (Python or NumPy) become "true" / "false"
        - other iterables become space-separated strings
        - anything else is passed through str()

    Args:
        texture (None or str or 4-array): Name of texture file to be imported. If a string, should be part of
            ALL_TEXTURES. If texture is a 4-array, then this argument will be interpreted as an rgba tuple value and
            a template png will be procedurally generated during object instantiation, with any additional
            texture / material attributes specified. If None, no file will be linked and no rgba value will be set
            Note, if specified, the RGBA values are expected to be floats between 0 and 1

        tex_name (str): Name to reference the imported texture

        mat_name (str): Name to reference the imported material

        tex_attrib (dict): Any other optional mujoco texture specifications. Not modified; a copy is stored.

        mat_attrib (dict): Any other optional mujoco material specifications. Not modified; a copy is stored.

        shared (bool): If True, this material should not have any naming prefixes added to all names

    Attributes:
        name (str): Material name (add_material may later prepend a naming prefix to it)
        shared (bool): Copy of @shared
        tex_attrib (dict): Stringified texture attributes, always containing "name" (plus "file" if a texture
            file was linked or generated)
        mat_attrib (dict): Stringified material attributes, always containing "name" and "texture"

    Raises:
        AssertionError: [Invalid texture]
    """

    def __init__(
        self,
        texture,
        tex_name,
        mat_name,
        tex_attrib=None,
        mat_attrib=None,
        shared=False,
    ):
        # A string (including np.str_) names a packaged texture file; anything else is None or an rgba value
        if isinstance(texture, str):
            default = False
            # Verify that requested texture is valid
            assert texture in ALL_TEXTURES, "Error: Requested invalid texture. Got {}. Valid options are:\n{}".format(
                texture, ALL_TEXTURES
            )
        else:
            default = True
            # If specified, this is an rgba value and a default texture is desired; make sure length of rgba array is 4
            if texture is not None:
                assert len(texture) == 4, (
                    "Error: Requested default texture. Got array of length {}. "
                    "Expected rgba array of length 4.".format(len(texture))
                )

        # Copy the caller's dicts so that adding names / stringifying never mutates their input
        self.tex_attrib = {} if tex_attrib is None else tex_attrib.copy()
        self.mat_attrib = {} if mat_attrib is None else mat_attrib.copy()

        # Add in name values
        self.name = mat_name
        self.shared = shared
        self.tex_attrib["name"] = tex_name
        self.mat_attrib["name"] = mat_name
        self.mat_attrib["texture"] = tex_name

        # Convert every value into its MuJoCo XML string form (only values change, so in-place update is safe)
        for attrib in (self.tex_attrib, self.mat_attrib):
            for k, v in attrib.items():
                attrib[k] = self._to_mujoco_string(v)

        # Handle default and non-default cases separately for linking texture patch file locations
        if not default:
            # Add in the filepath to texture patch
            self.tex_attrib["file"] = xml_path_completion(TEXTURES[texture])
        elif texture is not None:
            # Create a solid-color 100x100 texture patch; .tolist() yields plain Python ints for PIL
            color = tuple((np.array(texture) * 255).astype("int").tolist())
            tex = Image.new("RGBA", (100, 100), color)
            # Create temp directory if it does not exist
            save_dir = "/tmp/robosuite_temp_tex"
            Path(save_dir).mkdir(parents=True, exist_ok=True)
            # Save this texture patch to the temp directory on disk (MacOS / Linux)
            fpath = save_dir + "/{}.png".format(tex_name)
            tex.save(fpath, "PNG")
            # Link this texture file to the default texture dict
            self.tex_attrib["file"] = fpath

    @staticmethod
    def _to_mujoco_string(value):
        """
        Convert a single texture / material attribute value into a MuJoCo XML attribute string.

        Args:
            value (any): Attribute value (string, bool, number, or iterable of numbers)

        Returns:
            str: String form of @value
        """
        # isinstance (not an exact type check) so np.str_ is not treated as an iterable of characters
        if isinstance(value, str):
            return str(value)
        # Lowercase booleans, matching convert_to_string and MuJoCo's "true" / "false" keywords
        if isinstance(value, (bool, np.bool_)):
            return "true" if value else "false"
        if isinstance(value, Iterable):
            return array_to_string(value)
        return str(value)


def xml_path_completion(xml_path):
    """
    Takes in a local xml path and returns a full path.
        if @xml_path is absolute, return it unchanged
        if @xml_path is not absolute, resolve it against the assets directory shipped with the package

    Args:
        xml_path (str or os.PathLike): local xml path

    Returns:
        str: Full (absolute) xml path
    """
    # os.fspath accepts pathlib.Path as well as str; os.path.isabs is the portable absolute-path test
    # (a plain startswith("/") check misses e.g. Windows drive paths)
    xml_path = os.fspath(xml_path)
    if os.path.isabs(xml_path):
        return xml_path
    return os.path.join(robosuite.models.assets_root, xml_path)


def array_to_string(array):
    """
    Render a sequence of values as the space-separated string used by MuJoCo XML attributes.

    Examples:
        [0, 1, 2] => "0 1 2"

    Args:
        array (n-array): Iterable of values to render

    Returns:
        str: Each element formatted with the default "{}" format spec, joined by single spaces
    """
    render = "{}".format
    return " ".join(map(render, array))


def string_to_array(string):
    """
    Converts an array string in mujoco xml to np.array.

    Values may be separated by any run of whitespace (spaces, tabs, newlines), which is common in
    hand-written MJCF files. An empty or whitespace-only string yields an empty float array.

    Examples:
        "0 1 2" => [0, 1, 2]
        "0  1\t2" => [0, 1, 2]

    Args:
        string (str): String to convert to an array

    Returns:
        np.array: Numerical (float) array equivalent of @string
    """
    # str.split() with no argument collapses whitespace runs; split(" ") would yield "" tokens and crash float()
    return np.array([float(x) for x in string.split()], dtype=float)


def convert_to_string(inp):
    """
    Converts any type of {bool, int, float, list, tuple, array, string, np.str_} into an mujoco-xml compatible string.
        NumPy scalars (np.bool_, np.integer, np.floating) are accepted as well.
        Note that an input string / np.str_ results in a no-op action.

    Args:
        inp: Input to convert to string

    Returns:
        str: String equivalent of @inp

    Raises:
        ValueError: [Unsupported input type]
    """
    # Sequences are rendered element-wise as a space-separated string
    if isinstance(inp, (list, tuple, np.ndarray)):
        return array_to_string(inp)
    # Strings (np.str_ subclasses str) are already in XML form
    if isinstance(inp, str):
        return inp
    # Python and NumPy scalars; lower() turns booleans into MuJoCo's "true" / "false"
    if isinstance(inp, (bool, int, float, np.bool_, np.integer, np.floating)):
        return str(inp).lower()
    raise ValueError("Unsupported type received: got {}".format(type(inp)))


def set_alpha(node, alpha=0.1):
    """
    Sets the a(lpha) field of the rgba attribute to be @alpha
    for @node itself and all of its subnodes that define an rgba attribute.
    Used for managing display.

    Args:
        node (ET.Element): Specific node element within XML tree
        alpha (float): Value to set alpha value of rgba tuple
    """
    # node.iter() yields @node itself followed by all descendants, so the top-level node is updated too
    # (an ".//*" XPath query would only match descendants)
    for element in node.iter():
        rgba = element.get("rgba")
        if rgba is None:
            continue
        rgb = list(string_to_array(rgba)[0:3])
        element.set("rgba", array_to_string(rgb + [alpha]))


def new_element(tag, name, **kwargs):
    """
    Build an XML element of type @tag from keyword attributes.

    Attributes whose value is None are omitted; every other value is converted to its XML string form with
    convert_to_string.

    Args:
        tag (str): Type of element to create
        name (None or str): Name for this element. Should only be None for elements that do not have an explicit
            name attribute (e.g.: inertial elements)
        **kwargs: Specified attributes for the new element

    Returns:
        ET.Element: new specified xml element
    """
    if name is not None:
        kwargs["name"] = name
    attributes = {key: convert_to_string(value) for key, value in kwargs.items() if value is not None}
    return ET.Element(tag, attrib=attributes)


def new_joint(name, **kwargs):
    """
    Build a <joint> element; a thin wrapper around new_element.

    Args:
        name (str): Name for this joint
        **kwargs: Additional attributes for the joint (None values are dropped)

    Returns:
        ET.Element: new joint xml element
    """
    joint_tag = "joint"
    return new_element(joint_tag, name, **kwargs)


def new_actuator(name, joint, act_type="actuator", **kwargs):
    """
    Build an actuator element with tag @act_type that acts on @joint.

    Args:
        name (str): Name for this actuator
        joint (str): Value written to the element's "joint" attribute, i.e. the joint this actuator transmits to.
            See actuator transmission types here: http://mujoco.org/book/modeling.html#actuator
        act_type (str): XML tag used for the actuator element. Defaults to "actuator"
        **kwargs: Any additional specified attributes for the new actuator

    Returns:
        ET.Element: new actuator xml element
    """
    actuator = new_element(tag=act_type, name=name, **kwargs)
    # Set directly (not via convert_to_string), so @joint is expected to already be a string
    actuator.set("joint", joint)
    return actuator


def new_site(name, rgba=RED, pos=(0, 0, 0), size=(0.005,), **kwargs):
    """
    Build a <site> element.

    NOTE: Any attribute whose value is None (including @rgba, @pos and @size) is omitted from the resulting
        XML element, and @name is only set when it is not None.

    Args:
        name (str): Name for this site
        rgba (4-array): (r,g,b,a) color and transparency. Defaults to solid red.
        pos (3-array): (x,y,z) 3d position of the site.
        size (n-array of float): site size (sites are spherical by default).
        **kwargs: Any additional specified attributes for the new site

    Returns:
        ET.Element: new site xml element
    """
    kwargs.update(pos=pos, size=size, rgba=rgba)
    return new_element(tag="site", name=name, **kwargs)


def new_geom(name, type, size, pos=(0, 0, 0), group=0, **kwargs):
    """
    Build a <geom> element.

    NOTE: Any attribute whose value is None (including @type, @size, @pos and @group) is omitted from the
        resulting XML element, and @name is only set when it is not None.

    Args:
        name (str): Name for this geom
        type (str): type of the geom.
            see all types here: http://mujoco.org/book/modeling.html#geom
        size (n-array of float): geom size parameters.
        pos (3-array): (x,y,z) 3d position of the geom.
        group (int): the integer group that the geom belongs to. useful for
            separating visual and physical elements.
        **kwargs: Any additional specified attributes for the new geom

    Returns:
        ET.Element: new geom xml element
    """
    kwargs.update(type=type, size=size, pos=pos, group=group)
    return new_element(tag="geom", name=name, **kwargs)


def new_body(name, pos=(0, 0, 0), **kwargs):
    """
    Build a <body> element.

    Args:
        name (str): Name for this body
        pos (3-array): (x,y,z) 3d position of the body frame.
        **kwargs: Any additional specified attributes for the new body

    Returns:
        ET.Element: new body xml element
    """
    body_attributes = {**kwargs, "pos": pos}
    return new_element(tag="body", name=name, **body_attributes)


def new_inertial(pos=(0, 0, 0), mass=None, **kwargs):
    """
    Build an unnamed <inertial> element.

    Args:
        pos (3-array): (x,y,z) 3d position of the inertial frame.
        mass (float): The mass of inertial; omitted from the element when None
        **kwargs: Any additional specified attributes for the new inertial element

    Returns:
        ET.Element: new inertial xml element
    """
    inertial_attributes = dict(kwargs, mass=mass, pos=pos)
    return new_element(tag="inertial", name=None, **inertial_attributes)


def get_size(size, size_max, size_min, default_max, default_min):
    """
    Return an explicit size, or sample one uniformly between per-dimension bounds.

    When @size is None, each dimension i is drawn from np.random.uniform(min_i, max_i), using @size_min /
    @size_max when given and the defaults otherwise. The number of dimensions is len(@default_max).

    Args:
        size (n-array): Array of numbers that explicitly define the size
        size_max (n-array): Array of numbers that define the custom max size from which to randomly sample
        size_min (n-array): Array of numbers that define the custom min size from which to randomly sample
        default_max (n-array): Array of numbers that define the default max size from which to randomly sample
        default_min (n-array): Array of numbers that define the default min size from which to randomly sample

    Returns:
        np.array: size generated

    Raises:
        ValueError: [Inconsistent array sizes, or @size given together with a custom range]
    """
    if len(default_max) != len(default_min):
        raise ValueError(
            "default_max = {} and default_min = {}".format(str(default_max), str(default_min))
            + " have different lengths"
        )

    if size is None:
        upper = default_max if size_max is None else size_max
        lower = default_min if size_min is None else size_min
        size = np.array([np.random.uniform(lower[idx], upper[idx]) for idx in range(len(default_max))])
    elif size_max is not None or size_min is not None:
        raise ValueError("size = {} overrides size_max = {}, size_min = {}".format(size, size_max, size_min))

    return np.array(size)


def add_to_dict(dic, fill_in_defaults=True, default_value=None, **kwargs):
    """
    Helper function to add key-values to dictionary @dic where each entry is its own array (list).

    Args:
        dic (dict): Dictionary to which new key / value pairs will be added. If the key already exists,
            will append the value to that key entry. Modified in place.
        fill_in_defaults (bool): If True, will automatically add @default_value to all dictionary entries that are
            not explicitly specified in @kwargs, and will left-pad brand-new entries with @default_value so that
            every entry keeps the same length
        default_value (any): Default value to fill (None by default)
        **kwargs: Key / value pairs to append

    Returns:
        dict: Modified dictionary (the same object as @dic)
    """
    # All entries are kept the same length, so the current row count is the length of any entry's list
    # (measuring the key itself would give the length of the key string instead)
    n = len(next(iter(dic.values()))) if dic else 0
    # Existing keys not mentioned in kwargs; these receive a default value afterwards
    missing = set(dic)
    for k, v in kwargs.items():
        if k in dic:
            dic[k].append(v)
            missing.discard(k)
        else:
            dic[k] = [default_value] * n + [v] if fill_in_defaults else [v]
    # If filling in defaults, fill in remaining default values
    if fill_in_defaults:
        for k in missing:
            dic[k].append(default_value)
    return dic


def add_prefix(
    root,
    prefix,
    tags="default",
    attribs="default",
    exclude=None,
):
    """
    Prepend @prefix to the requested attributes of every matching element in the tree rooted at @root.

    An attribute is only prefixed if it exists, does not already start with @prefix, and is not excluded.
    Excluding an element skips its own attributes but still visits its descendants.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through (included in the search).
        prefix (str): Prefix to add to all specified attributes
        tags (str or list of str or set): Tag(s) to search for in this ElementTree. "default" corresponds to all tags
        attribs (str or list of str or set): Element attribute(s) to prepend prefix to. "default" corresponds
            to all attributes that reference names
        exclude (None or function): Filtering function that should take in an ET.Element or a string (attribute) and
            return True if we should exclude the given element / attribute from having any prefixes added
    """
    match_every_tag = tags == "default"
    if not match_every_tag:
        tags = {tags} if type(tags) is str else set(tags)
    if attribs == "default":
        attribs = MUJOCO_NAMED_ATTRIBUTES
    attribs = {attribs} if type(attribs) is str else set(attribs)

    def _is_excluded(item):
        return exclude is not None and exclude(item)

    # Element.iter() walks @root and its descendants in the same depth-first pre-order as a recursive descent
    for element in root.iter():
        if not (match_every_tag or element.tag in tags):
            continue
        if _is_excluded(element):
            continue
        for attrib_name in attribs:
            current = element.get(attrib_name, None)
            if current is None or current.startswith(prefix) or _is_excluded(current):
                continue
            element.set(attrib_name, prefix + current)


def add_material(root, naming_prefix="", custom_material=None):
    """
    Iterates through all element(s) in @root (including @root) and adds a material / texture to all visual geoms
    (group "1") that don't already have a material specified.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        naming_prefix (str): Adds this prefix to all material and texture names (unless the material is shared)
        custom_material (None or CustomMaterial): If specified, will add this material to all visual geoms.
            Else, will add a default "no-change" material. Note that a provided material's names are prefixed
            in place.

    Returns:
        4-tuple: (ET.Element, ET.Element, CustomMaterial, bool) (tex_element, mat_element, material, used)
            corresponding to the added material and whether the material was actually used or not.
    """
    # First, make sure material is specified
    if custom_material is None:
        custom_material = CustomMaterial(
            texture=None,
            tex_name="default_tex",
            mat_name="default_mat",
            tex_attrib={
                "type": "cube",
                "builtin": "flat",
                "width": 100,
                "height": 100,
                "rgb1": np.ones(3),
                "rgb2": np.ones(3),
            },
        )
    # Make sure a non-shared material carries the naming prefix so its names stay unique per model
    if not custom_material.name.startswith(naming_prefix) and not custom_material.shared:
        custom_material.name = naming_prefix + custom_material.name
        custom_material.tex_attrib["name"] = naming_prefix + custom_material.tex_attrib["name"]
        custom_material.mat_attrib["name"] = naming_prefix + custom_material.mat_attrib["name"]
        custom_material.mat_attrib["texture"] = naming_prefix + custom_material.mat_attrib["texture"]

    # Tag geoms in a single pass, then build the texture / material elements exactly once
    # (rather than rebuilding and discarding them at every node of a recursive walk)
    used = _assign_material_to_visual_geoms(root, custom_material.name)
    tex_element = new_element(tag="texture", **custom_material.tex_attrib)
    mat_element = new_element(tag="material", **custom_material.mat_attrib)
    return tex_element, mat_element, custom_material, used


def _assign_material_to_visual_geoms(root, material_name):
    """
    Set material=@material_name on every group-"1" geom under @root (inclusive) that has no material yet.

    Args:
        root (ET.Element): Root of the xml element tree to search
        material_name (str): Material name to assign

    Returns:
        bool: True if at least one geom was assigned the material
    """
    used = False
    for element in root.iter():
        if element.tag == "geom" and element.get("group", None) == "1" and element.get("material", None) is None:
            element.set("material", material_name)
            used = True
    return used


def recolor_collision_geoms(root, rgba, exclude=None):
    """
    Recolor every collision geom (group attribute missing or "0") under @root, including @root itself.

    Each matching geom gets its rgba attribute set to @rgba and any material attribute removed.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through
        rgba (4-array): (R, G, B, A) values to assign to all geoms with this group.
        exclude (None or function): Filtering function that should take in an ET.Element and
            return True if we should exclude the given element / attribute from having its collision geom impacted.
    """
    # Element.iter() visits nodes in the same pre-order a recursive descent would
    for element in root.iter():
        is_collision_geom = element.tag == "geom" and element.get("group") in {None, "0"}
        if not is_collision_geom:
            continue
        if exclude is not None and exclude(element):
            continue
        element.set("rgba", array_to_string(rgba))
        element.attrib.pop("material", None)


def _element_filter(element, parent):
    """
    Default element filter used by sort_elements. Maps an element to one of these group keys:

        :`'root_body'`: Body element whose parent is missing or is not a body
        :`'bodies'`: Body element nested directly inside another body
        :`'joints'`: Joint elements that are not joint references (no "joint" / "joint1" attribute)
        :`'actuators'`: Any element whose parent is an <actuator> element
        :`'sites'`: Any site elements
        :`'sensors'`: Any element whose tag is listed in SENSOR_TYPES
        :`'contact_geoms'`: Geoms used for collision (group attribute missing or "0")
        :`'visual_geoms'`: Geoms used for visual rendering (group "1")

    Args:
        element (ET.Element): Current XML element that we are filtering
        parent (ET.Element): Parent XML element for the current element

    Returns:
        str or None: Assigned filter key for this element. None if no matching filter is found.
    """
    # The actuator check depends only on the parent, so it takes precedence over the element's own tag
    if parent is not None and parent.tag == "actuator":
        return "actuators"

    tag = element.tag
    if tag == "joint":
        # <joint> elements carrying "joint" or "joint1" reference other joints (e.g. inside tendons) and are skipped
        references_other_joint = element.get("joint") is not None or element.get("joint1") is not None
        return None if references_other_joint else "joints"
    if tag == "body":
        nested_in_body = parent is not None and parent.tag == "body"
        return "bodies" if nested_in_body else "root_body"
    if tag == "site":
        return "sites"
    if tag in SENSOR_TYPES:
        return "sensors"
    if tag == "geom":
        group = element.get("group")
        if group == "1":
            return "visual_geoms"
        if group in {None, "0"}:
            return "contact_geoms"
    return None


def sort_elements(root, parent=None, element_filter=None, _elements_dict=None):
    """
    Group every element of the tree rooted at @root (inclusive) by the key returned from @element_filter.

    Elements are visited depth-first in document order. Each element whose key is not None is appended to the
    list stored under that key in the returned dictionary.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through
        parent (ET.Element): Parent of the root node. Default is None (no parent node initially)
        element_filter (None or function): Function used to filter the incoming elements. Should take in two
            ET.Elements (current_element, parent_element) and return a string filter_key if the element
            should be added to the list of values sorted by filter_key, and return None if no value should be added.
            If no element_filter is specified, defaults to the module-level _element_filter.
        _elements_dict (dict): Accumulator shared with recursive calls; if given, it is updated in place and
            returned. Should not be passed by top-level callers.

    Returns:
        dict: Filtered key-specific lists of the corresponding elements
    """
    grouped = {} if _elements_dict is None else _elements_dict
    key_fn = _element_filter if element_filter is None else element_filter

    key = key_fn(root, parent)
    if key is not None:
        grouped.setdefault(key, []).append(root)

    # Descend into each child with the same accumulator so all results land in one dict
    for child in root:
        sort_elements(root=child, parent=root, element_filter=key_fn, _elements_dict=grouped)

    return grouped


def find_parent(root, child):
    """
    Return the element under @root that directly contains @child, searching depth-first.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        child (ET.Element): Child element whose parent is to be found

    Returns:
        None or ET.Element: Matching parent if found, else None (also None when @child is @root itself)
    """

    def _edges(parent):
        # Yield (parent, node) pairs depth-first; each node is checked before its own subtree is explored
        for node in parent:
            yield parent, node
            yield from _edges(node)

    for parent, node in _edges(root):
        if node == child:
            return parent
    return None


def find_elements(root, tags, attribs=None, return_first=True):
    """
    Find all element(s) matching the requested @tag and @attributes. If @return_first is True, then will return the
    first element found matching the criteria specified. Otherwise, will return a list of elements that match the
    criteria.

    The tree is searched in pre-order depth-first order: @root is checked first, then each child subtree in document
    order. The returned list (when @return_first is False) follows that same order.

    Args:
        root (ET.Element): Root of the xml element tree to start recursively searching through.
        tags (str or list of str or set): Tag(s) to search for in this ElementTree.
        attribs (None or dict of str): Element attribute(s) to check against for a filtered element. A match is
            considered found only if all attributes match. Each attribute key should have a corresponding value with
            which to compare against.
        return_first (bool): Whether to immediately return once the first matching element is found.

    Returns:
        None or ET.Element or list of ET.Element: Matching element(s) found. Returns None if there was no match
            (in both modes; an empty result is reported as None rather than []).
    """
    # Make sure tags is a collection so membership checks work uniformly for a single tag
    tags = [tags] if isinstance(tags, str) else tags

    def _is_match(element):
        # An element matches if its tag is requested and every requested attribute has the exact value
        if element.tag not in tags:
            return False
        return attribs is None or all(element.get(k) == v for k, v in attribs.items())

    if return_first:
        if _is_match(root):
            return root
        # Continue recursively searching through the element tree, stopping at the first hit
        for r in root:
            found = find_elements(root=r, tags=tags, attribs=attribs, return_first=True)
            if found is not None:
                return found
        return None

    elements = [root] if _is_match(root) else []
    # Continue recursively searching through the element tree, collecting every hit.
    # In list mode the recursive call returns either a non-empty list or None.
    for r in root:
        found = find_elements(root=r, tags=tags, attribs=attribs, return_first=False)
        if found:
            elements.extend(found)

    return elements if elements else None


def save_sim_model(sim, fname):
    """
    Saves the current model xml from @sim at file location @fname.

    The file at @fname is opened in text write mode (overwriting any existing content), and the actual
    serialization is delegated to sim.save(file=..., format="xml").

    Args:
        sim (MjSim): Active simulation object whose model should be saved. Must provide a
            save(file=..., format=...) method.
        fname (str): Absolute filepath to the location to save the file
    """
    with open(fname, "w") as f:
        sim.save(file=f, format="xml")


def get_ids(sim, elements, element_type="geom", inplace=False):
    """
    Grabs the mujoco IDs for each element in @elements, corresponding to the specified @element_type.

    Nested containers are handled recursively: strings are converted to IDs, dict values are converted while keys
    are kept, and any other iterable is converted into a list of converted entries.

    Args:
        sim (MjSim): Active mujoco simulation object
        elements (str or list or dict): Element(s) to convert into IDs. Note that the return type corresponds to
            @elements type, where each element name is replaced with the ID
        element_type (str): The type of element to grab ID for. Options are {geom, body, site}
        inplace (bool): If False, will create a copy of @elements to prevent overwriting the original data structure

    Returns:
        str or list or dict: IDs corresponding to @elements.

    Raises:
        AssertionError: If a name is looked up with an @element_type outside {geom, body, site}, or if @elements is
            not a str, dict, or iterable.
    """
    if not inplace:
        # Copy elements first so we don't write to the underlying object
        elements = deepcopy(elements)
    # Choose what to do based on elements type
    if isinstance(elements, str):
        # We simply return the value of this single element
        assert element_type in {
            "geom",
            "body",
            "site",
        }, f"element_type must be either geom, body, or site. Got: {element_type}"
        if element_type == "geom":
            elements = sim.model.geom_name2id(elements)
        elif element_type == "body":
            elements = sim.model.body_name2id(elements)
        else:  # site
            elements = sim.model.site_name2id(elements)
    elif isinstance(elements, dict):
        # Iterate over (key, value) pairs -- iterating the dict directly yields only keys, which cannot be unpacked.
        # The items are snapshotted so reassigning values never interferes with the ongoing iteration.
        for name, ele in list(elements.items()):
            elements[name] = get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True)
    else:  # We assume this is an iterable array
        assert isinstance(elements, Iterable), "Elements must be iterable for get_id!"
        elements = [get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True) for ele in elements]

    return elements