File size: 2,590 Bytes
5ce8761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from scipy.spatial.transform import Rotation as R

from pyrep.objects.shape import Shape
from rlbench.environment import Environment
from rlbench.backend.scene import Scene
from rlbench.sim2real.domain_randomization_scene import DomainRandomizationScene


class CustomizedScene(Scene):
    """Scene whose observations additionally carry posed mesh vertices.

    ``get_observation`` takes the parent's observation and attaches a
    ``mesh_points`` attribute: a dict mapping an object label to that
    object's mesh vertices after applying the object's current pose.
    """

    # Name of the table shape that is always reported next to the arm visuals.
    _TABLE_SHAPE_NAME = 'diningTable_visible'

    @staticmethod
    def _posed_vertices(shape):
        """Return the mesh vertices of ``shape`` moved by its current pose.

        Args:
            shape: Object exposing ``get_mesh_data()`` and ``get_pose()``.

        Returns:
            The vertices rotated by the pose's quaternion and then shifted by
            the pose's position.
        """
        # Only the vertices are needed; the other two returned items are
        # ignored.
        vertices, _, _ = shape.get_mesh_data()
        pose = shape.get_pose()
        # The pose is read as 3 position values followed by 4 quaternion
        # values.
        position = pose[:3]
        quaternion = pose[3:7]
        # NOTE(review): scipy's from_quat defaults to scalar-last (x, y, z, w);
        # this assumes get_pose() uses the same ordering -- confirm.
        rotation_matrix = R.from_quat(quaternion).as_matrix()
        # NOTE(review): assumes the vertices are expressed in the shape's own
        # frame, so rotate-then-translate yields world coordinates -- confirm.
        rotated = (rotation_matrix @ vertices.T).T
        return rotated + position.reshape(1, 3)

    def get_observation(self):
        """Return the parent observation extended with ``mesh_points``.

        Keys of ``mesh_points``:
            * ``<name>`` for every arm visual and for the table shape;
            * ``<name>_<color>`` for every ``Shape`` listed in the task's
              ``_initial_objs_in_scene`` (color is ``str(get_color())``).

        Returns:
            The observation, passed through ``self.task.decorate_observation``
            after ``mesh_points`` has been attached.
        """
        obs = super().get_observation()

        mesh_points = {}

        # Build a new list instead of appending to the one handed back by
        # get_visuals(), so that list is never modified by this method.
        robot_and_table = [
            *self.robot.arm.get_visuals(),
            Shape(self._TABLE_SHAPE_NAME),
        ]
        for shape in robot_and_table:
            label = shape.get_name()
            mesh_points[label] = self._posed_vertices(shape)

        for scene_obj, _obj_type in self.task._initial_objs_in_scene:
            # Only Shape instances are exported; other entries are skipped.
            if not isinstance(scene_obj, Shape):
                continue
            obj_name = scene_obj.get_name()
            obj_color = str(scene_obj.get_color())
            mesh_points[obj_name + "_" + obj_color] = (
                self._posed_vertices(scene_obj))

        obs.mesh_points = mesh_points
        # NOTE(review): the parent's get_observation() may already decorate
        # the observation; this call runs after mesh_points is attached.
        # Confirm the repeated decoration is intended.
        obs = self.task.decorate_observation(obs)
        return obs


class CustomizedDomainRandomizationScene(CustomizedScene, DomainRandomizationScene):
    """Combination of ``CustomizedScene`` and ``DomainRandomizationScene``.

    Defines no members of its own. ``CustomizedScene`` is listed first among
    the bases, so its ``get_observation`` comes before
    ``DomainRandomizationScene``'s in the method resolution order.
    """


class CustomizedEnvironment(Environment):
    """Environment that installs one of the customized scene classes."""

    def launch(self):
        """Run the parent's launch, then assign ``self._scene``.

        ``CustomizedScene`` is used when ``self._randomize_every`` is None;
        otherwise ``CustomizedDomainRandomizationScene`` is used and receives
        the randomization settings as additional arguments.
        """
        super().launch()

        # Arguments shared by both scene constructors, in positional order.
        base_args = (
            self._pyrep, self._robot, self._obs_config, self._robot_setup)

        if self._randomize_every is not None:
            self._scene = CustomizedDomainRandomizationScene(
                *base_args,
                self._randomize_every,
                self._frequency,
                self._visual_randomization_config,
                self._dynamics_randomization_config)
            return

        self._scene = CustomizedScene(*base_args)