File size: 10,530 Bytes
fc36e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
Base Case Handler Template
Abstract base class for all simulation case handlers.
"""

from abc import ABC, abstractmethod
import numpy as np
import torch
import gstaichi as ti
import genesis as gs
import sys

# Maps a case name (as passed to get_case_handler) to the class registered
# for it. Entries are added by the @register_case decorator below.
CASE_REGISTRY: dict[str, type] = {}

def register_case(case_name: str):
    """
    Return a class decorator that records the decorated class in CASE_REGISTRY.

    Args:
        case_name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class under ``case_name``, prints a
        confirmation line, and hands the class back unchanged.

    Raises:
        ValueError: Raised by the returned decorator when ``case_name`` is
            already present in CASE_REGISTRY.
    """
    def _register(handler_cls):
        already_taken = case_name in CASE_REGISTRY
        if already_taken:
            raise ValueError(f"Case name '{case_name}' already registered!")
        CASE_REGISTRY[case_name] = handler_cls
        print(f"Registered Case: '{case_name}' -> {handler_cls.__name__}")
        # The class itself is not modified or wrapped.
        return handler_cls

    return _register

class CaseHandler(ABC):
    """
    Abstract base class for handling case-specific simulation logic.
    Each simulation case should inherit from this class.

    The base class supplies default behavior for bounds bookkeeping, object
    creation, and a ground plane. Every other hook (force fields, robots,
    custom setup, emitters, robot pose, particle fixing, per-step hooks) is a
    no-op that subclasses may override.

    ``set_simulation_bounds`` must run before ``before_scene_building``,
    because ``detect_ground_plane`` reads ``all_obj_occupied_lower_bound``.
    The rest of the call order is decided by the driver, which is not in
    this file.
    """

    def __init__(self, config, all_obj_info: list[dict], device: torch.device):
        """
        Args:
            config: Case configuration. This class only uses membership tests
                (``key in config``) and item access (``config[key]``) on it.
            all_obj_info: One dict per object. The keys read here are
                ``'center'``, ``'size'`` and ``'mesh_path'``.
            device: Torch device on which returned tensors are placed.
        """
        self.config = config
        self.all_obj_info = all_obj_info
        self.device = device

    def _config_value(self, key, default):
        """Return ``self.config[key]`` if the key is present, otherwise ``default``."""
        return self.config[key] if key in self.config else default

    def set_simulation_bounds(self, all_obj_occupied_lower_bound, all_obj_occupied_upper_bound, margin_factor=3):
        """
        Store the objects' occupied bounding box and derive the simulation domain.

        The domain is the occupied box grown on every side by ``margin_factor``
        times the box's own extent.

        Args:
            all_obj_occupied_lower_bound: Lower corner of the occupied box, as a
                torch tensor (``get_simulation_bounds`` calls ``.cpu().numpy()``).
            all_obj_occupied_upper_bound: Upper corner of the occupied box.
            margin_factor: Multiple of the occupied size added on each side.
                Defaults to 3, the value that was previously hard-coded.
        """
        self.all_obj_occupied_lower_bound = all_obj_occupied_lower_bound
        self.all_obj_occupied_upper_bound = all_obj_occupied_upper_bound
        self.all_obj_occupied_size = all_obj_occupied_upper_bound - all_obj_occupied_lower_bound
        margin = margin_factor * self.all_obj_occupied_size
        self.simulation_lower_bound = all_obj_occupied_lower_bound - margin
        self.simulation_upper_bound = all_obj_occupied_upper_bound + margin

    def get_simulation_bounds(self):
        """
        Return the ``(lower, upper)`` simulation bounds as NumPy arrays.

        Must be called after ``set_simulation_bounds``.
        """
        return self.simulation_lower_bound.cpu().numpy(), self.simulation_upper_bound.cpu().numpy()

    def _random_color_surface(self, idx):
        """Return a default surface with a random opaque RGBA color and object ``idx``'s vis mode."""
        return gs.surfaces.Default(
            color=tuple(np.random.rand(3).tolist() + [1.0]),
            vis_mode=self.obj_vis_modes[idx],
        )

    def _add_primitive_box(self, idx, per_obj_info, fixed):
        """Add object ``idx`` as a Box morph (pos=center, size=size) and return the new entity."""
        box_morph = gs.morphs.Box(
            pos=per_obj_info['center'].cpu().numpy().astype(np.float64),
            size=per_obj_info['size'].cpu().numpy().astype(np.float64),
            visualization=True,
            collision=True,
            fixed=fixed,
        )
        return self.scene.add_entity(
            material=self.obj_materials[idx],
            morph=box_morph,
            surface=self._random_color_surface(idx),
        )

    def add_entities_to_scene(self, scene, obj_materials, obj_vis_modes):
        """
        Add one entity per entry of ``all_obj_info`` to ``scene``.

        If ``config['use_primitive']`` is truthy, every object is added as a
        box built from its ``center``/``size``. Otherwise its ``mesh_path`` is
        loaded. If that raises, the error is printed and a box is added
        instead. Both paths honour ``config['is_obj_fixed']``, which defaults
        to all-False.

        Args:
            scene: Scene exposing ``add_entity``.
            obj_materials: Per-object materials, indexed like ``all_obj_info``.
            obj_vis_modes: Per-object surface ``vis_mode``, indexed like ``all_obj_info``.

        Returns:
            list: The created entities in ``all_obj_info`` order. The list is
            also stored on ``self.objs``.

        Raises:
            IndexError: If ``config['is_obj_fixed']`` has fewer entries than there are objects.
        """
        self.obj_materials = obj_materials
        self.obj_vis_modes = obj_vis_modes
        self.scene = scene
        self.objs = []
        is_obj_fixed = self._config_value('is_obj_fixed', [False] * len(self.all_obj_info))
        use_primitive = bool(self._config_value('use_primitive', False))

        for idx, per_obj_info in enumerate(self.all_obj_info):
            fixed = is_obj_fixed[idx]
            if use_primitive:
                per_obj = self._add_primitive_box(idx, per_obj_info, fixed)
            else:
                try:
                    morph = gs.morphs.Mesh(
                        file=per_obj_info['mesh_path'],
                        scale=1.0,
                        pos=tuple(per_obj_info['center'].cpu().numpy().astype(np.float64)),
                        euler=(0.0, 0.0, 0.0),
                        fixed=fixed,
                    )
                    per_obj = self.scene.add_entity(
                        material=self.obj_materials[idx],
                        morph=morph,
                        surface=self._random_color_surface(idx),
                    )
                except Exception as e:
                    # Deliberate best effort: a mesh that fails to load is replaced
                    # by its bounding box instead of aborting the whole case.
                    print(e)
                    print("trying to add primitive mesh for object", idx)
                    per_obj = self._add_primitive_box(idx, per_obj_info, fixed)
            self.objs.append(per_obj)

        return self.objs

    def before_scene_building(self, scene, all_objs, ground_plane):
        """Store scene and objects, then run the pre-build hooks: ground plane, force fields, robots, custom setup, emitters."""
        self.scene = scene
        self.all_objs = all_objs
        self.detect_ground_plane(ground_plane)
        self.create_force_fields()
        self.add_robots()
        self.custom_setup()
        self.add_emitters()

    def after_scene_building(self):
        """Run the post-build hooks: robot pose initialisation, then particle fixing."""
        self.init_robots_pose()
        self.fix_particles()

    def custom_simulation(self, sid):
        """Per-step hook; no-op by default. ``sid`` is presumably the step index (confirm with the caller)."""
        pass

    def after_simulation_step(self, svr):
        """Hook called after a simulation step; no-op by default."""
        pass

    def add_emitters(self):
        """Add emitters if needed for this case."""
        pass

    ## before scene building
    def detect_ground_plane(self, ground_plane):
        """
        Add a rigid ground plane with normal +z at the objects' lower bound.

        The material parameters come from the config keys ``plane_rho``,
        ``plane_friction``, ``plane_coup_friction`` and ``plane_coup_softness``
        when present. Otherwise the defaults below are used.

        Args:
            ground_plane: Not used by this base implementation. Presumably
                kept so that subclasses can use it.
        """
        # Copy: Tensor.numpy() shares memory with a CPU tensor, so editing
        # ground_anchor in place would otherwise write into the bounds tensor.
        self.ground_anchor = self.all_obj_occupied_lower_bound.cpu().numpy().copy()
        self.normal = np.array([0, 0, 1])
        self.scene.add_entity(
            material=gs.materials.Rigid(
                rho=self._config_value('plane_rho', 1000.0),
                friction=self._config_value('plane_friction', 5),
                coup_friction=self._config_value('plane_coup_friction', 5.0),
                coup_softness=self._config_value('plane_coup_softness', 0.002),
            ),
            morph=gs.morphs.Plane(
                pos=(self.ground_anchor[0], self.ground_anchor[1], self.ground_anchor[2]),
                normal=self.normal,
            ),
        )

    def create_force_fields(self):
        """Create case-specific force fields."""
        pass

    def custom_setup(self):
        """Custom setup for this case."""
        pass

    def add_robots(self):
        """Setup robots if needed for this case."""
        pass

    ## after scene building
    def init_robots_pose(self):
        """Initialize robots pose if needed for this case."""
        pass

    def fix_particles(self):
        """Fix particles if needed for this case."""
        pass

    def extract_franka_mesh_data_combined(self, target_franka):
        """
        Extract and combine all visual-mesh data of a robot entity into single tensors.

        Each vgeom's vertices are transformed by its entry in
        ``target_franka.solver._vgeoms_render_T`` (first batch entry). Faces are
        offset so they index into the combined vertex array. Every vertex gets
        its vgeom's color: the diffuse-texture color, else the surface color,
        else mid-grey.

        Args:
            target_franka: Entity exposing ``vgeoms`` and ``solver._vgeoms_render_T``.

        Returns:
            vertices: float32 tensor (V, 3) of transformed vertices on ``self.device``.
            faces: int32 tensor of faces re-indexed into ``vertices``.
            colors: float32 tensor with one color row per vertex.
            When the entity has no vgeoms, all three are empty tensors with zero rows.
        """
        all_vertices = []
        all_faces = []
        all_colors = []
        vertex_offset = 0
        sim_vgeoms_render_T = target_franka.solver._vgeoms_render_T

        for vgeom in target_franka.vgeoms:
            verts = vgeom.vmesh.verts  # (N, 3)
            faces = vgeom.vmesh.faces

            # [0] drops the leading batch index. This assumes render_T is laid out as
            # (n_vgeoms, n_envs, 4, 4) — TODO confirm.
            cur_render_T = sim_vgeoms_render_T[vgeom.idx][0]

            # Homogeneous coordinates (N, 4) @ T^T, then back to (N, 3).
            verts_homogeneous = np.concatenate([verts, np.ones((len(verts), 1))], axis=1)
            verts_transformed = (verts_homogeneous @ cur_render_T.T)[:, :3]

            surface = vgeom.vmesh.surface
            texture = getattr(surface, 'diffuse_texture', None)
            if texture is not None:
                color = texture.color
            elif surface.color is not None:
                color = surface.color
            else:
                color = (0.5, 0.5, 0.5)
            # NOTE(review): np.vstack below needs every color to have the same length.
            # The RGB fallback mixed with RGBA surface colors would fail — confirm.

            all_vertices.append(verts_transformed)
            all_faces.append(faces + vertex_offset)
            all_colors.append(np.tile(color, (len(verts), 1)))
            vertex_offset += len(verts)

        if not all_vertices:
            # np.vstack([]) raises, so return empty tensors explicitly.
            return (
                torch.zeros((0, 3), dtype=torch.float32, device=self.device),
                torch.zeros((0, 3), dtype=torch.int32, device=self.device),
                torch.zeros((0, 3), dtype=torch.float32, device=self.device),
            )

        vertices = torch.from_numpy(np.vstack(all_vertices)).to(self.device, dtype=torch.float32)
        faces = torch.from_numpy(np.vstack(all_faces)).to(self.device, dtype=torch.int32)
        colors = torch.from_numpy(np.vstack(all_colors)).to(self.device, dtype=torch.float32)

        return vertices, faces, colors

def get_case_handler(case_name: str, config, all_obj_info, device) -> CaseHandler:
    """
    Build an instance of the class registered under ``case_name``.

    Args:
        case_name: Key looked up in CASE_REGISTRY.
        config: Forwarded as the first constructor argument.
        all_obj_info: Forwarded as the second constructor argument.
        device: Forwarded as the third constructor argument.

    Returns:
        ``CASE_REGISTRY[case_name](config, all_obj_info, device)``.

    Raises:
        ValueError: If ``case_name`` is not registered. The message lists the
            registered case names.
    """
    try:
        handler_cls = CASE_REGISTRY[case_name]
    except KeyError:
        raise ValueError(f"Unknown case name: '{case_name}'. Available cases: {list(CASE_REGISTRY.keys())}") from None

    return handler_cls(config, all_obj_info, device)