"""
Base Case Handler Template

Abstract base class for all simulation case handlers, together with a
name -> class registry (``register_case`` / ``get_case_handler``) used to
look handlers up by case name.
"""

from abc import ABC, abstractmethod
import numpy as np
import torch
import gstaichi as ti
import genesis as gs
import sys

# Maps a case-name string to the CaseHandler subclass registered under it.
CASE_REGISTRY = {}


def register_case(case_name: str):
    """
    Return a class decorator that registers the decorated class in
    CASE_REGISTRY under ``case_name``.

    The decorated class is returned unmodified. A message is printed on
    every successful registration.

    Raises:
        ValueError: (when the decorator is applied) if ``case_name`` is
            already present in CASE_REGISTRY.
    """
    def decorator(cls):
        if case_name in CASE_REGISTRY:
            raise ValueError(f"Case name '{case_name}' already registered!")

        # Register: map the string case_name to the actual class object.
        CASE_REGISTRY[case_name] = cls
        print(f"Registered Case: '{case_name}' -> {cls.__name__}")
        return cls  # Return the unmodified class

    return decorator


class CaseHandler(ABC):
    """
    Abstract base class for handling case-specific simulation logic.
    Each simulation case should inherit from this class.

    Most hooks (``add_emitters``, ``create_force_fields``, ``custom_setup``,
    ``add_robots``, ``init_robots_pose``, ``fix_particles``,
    ``custom_simulation``, ``after_simulation_step``) are no-ops here and
    exist to be overridden by subclasses.
    """

    def __init__(self, config, all_obj_info: list[dict], device: torch.device):
        """
        Store the inputs on the instance.

        Args:
            config: Mapping-like configuration; it is only accessed through
                ``key in config`` and ``config[key]``.
            all_obj_info: One dict per object. This class reads the keys
                ``'center'``, ``'size'`` and ``'mesh_path'``.
            device: Device on which tensors returned by
                ``extract_franka_mesh_data_combined`` are placed.
        """
        self.config = config
        self.all_obj_info = all_obj_info
        self.device = device

    def _config_or_default(self, key, default):
        """Return ``self.config[key]`` if ``key in self.config``, else ``default``."""
        return self.config[key] if key in self.config else default

    def set_simulation_bounds(self, all_obj_occupied_lower_bound, all_obj_occupied_upper_bound):
        """
        Record the objects' occupied bounds and derive the simulation bounds.

        The simulation bounds are the occupied bounds extended on each side
        by three times the occupied size (upper - lower).
        """
        self.all_obj_occupied_lower_bound = all_obj_occupied_lower_bound
        self.all_obj_occupied_upper_bound = all_obj_occupied_upper_bound
        self.all_obj_occupied_size = (
            self.all_obj_occupied_upper_bound - self.all_obj_occupied_lower_bound
        )
        self.simulation_lower_bound = (
            self.all_obj_occupied_lower_bound - 3 * self.all_obj_occupied_size
        )
        self.simulation_upper_bound = (
            self.all_obj_occupied_upper_bound + 3 * self.all_obj_occupied_size
        )

    def get_simulation_bounds(self):
        """
        Return ``(lower, upper)`` simulation bounds as NumPy arrays.

        Requires ``set_simulation_bounds`` to have been called first; the
        stored bounds are converted with ``.cpu().numpy()``.
        """
        return (
            self.simulation_lower_bound.cpu().numpy(),
            self.simulation_upper_bound.cpu().numpy(),
        )

    def _random_color_surface(self, idx):
        """Build a default surface with a random opaque RGB color for object ``idx``."""
        return gs.surfaces.Default(
            color=tuple(np.random.rand(3).tolist() + [1.0]),
            vis_mode=self.obj_vis_modes[idx],
        )

    def _add_primitive_box(self, idx):
        """Add object ``idx`` to the scene as a box built from its 'center' and 'size'."""
        obj_info = self.all_obj_info[idx]
        primitive_morph = gs.morphs.Box(
            pos=obj_info['center'].cpu().numpy().astype(np.float64),
            size=obj_info['size'].cpu().numpy().astype(np.float64),
            visualization=True,
            collision=True,
            # NOTE(review): the primitive path never honors config['is_obj_fixed']
            # (the mesh path does) -- confirm whether this is intentional.
            fixed=False,
        )
        return self.scene.add_entity(
            material=self.obj_materials[idx],
            morph=primitive_morph,
            surface=self._random_color_surface(idx),
        )

    def add_entities_to_scene(self, scene, obj_materials, obj_vis_modes):
        """
        Add one entity per entry of ``self.all_obj_info`` to ``scene``.

        If ``config['use_primitive']`` is present and truthy, every object is
        added as a box primitive. Otherwise each object is added as a mesh
        loaded from its ``'mesh_path'``; if that raises, the exception is
        printed and the object is added as a box primitive instead
        (best-effort fallback).

        Args:
            scene: Scene object providing ``add_entity``.
            obj_materials: Per-object materials, indexed like ``all_obj_info``.
            obj_vis_modes: Per-object visualization modes, indexed likewise.

        Returns:
            The list of added entities (also stored as ``self.objs``).
        """
        self.obj_materials = obj_materials
        self.obj_vis_modes = obj_vis_modes
        self.scene = scene
        self.objs = []

        if 'is_obj_fixed' in self.config:
            is_obj_fixed = self.config['is_obj_fixed']
        else:
            is_obj_fixed = [False] * len(self.all_obj_info)

        # Loop-invariant: decide once whether primitives are forced.
        use_primitive = bool(self._config_or_default('use_primitive', False))

        for idx, per_obj_info in enumerate(self.all_obj_info):
            if use_primitive:
                per_obj = self._add_primitive_box(idx)
            else:
                try:
                    # The 'decimate' / 'convexify' config options are
                    # currently not forwarded to the mesh morph.
                    morph = gs.morphs.Mesh(
                        file=per_obj_info['mesh_path'],
                        scale=1.0,
                        pos=tuple(per_obj_info['center'].cpu().numpy().astype(np.float64)),
                        euler=(0.0, 0.0, 0.0),
                        fixed=is_obj_fixed[idx],
                    )
                    per_obj = self.scene.add_entity(
                        material=self.obj_materials[idx],
                        morph=morph,
                        surface=self._random_color_surface(idx),
                    )
                except Exception as e:
                    # Deliberate best-effort: any failure while loading or
                    # adding the mesh falls back to a box primitive.
                    print(e)
                    print("trying to add primitive mesh for object", idx)
                    per_obj = self._add_primitive_box(idx)
            self.objs.append(per_obj)

        return self.objs

    def before_scene_building(self, scene, all_objs, ground_plane):
        """
        Run the pre-build hooks in a fixed order.

        Order: ``detect_ground_plane``, ``create_force_fields``,
        ``add_robots``, ``custom_setup``, ``add_emitters``.
        """
        self.scene = scene
        self.all_objs = all_objs
        self.detect_ground_plane(ground_plane)
        self.create_force_fields()
        self.add_robots()
        self.custom_setup()
        self.add_emitters()

    def after_scene_building(self):
        """Run the post-build hooks: ``init_robots_pose`` then ``fix_particles``."""
        self.init_robots_pose()
        self.fix_particles()

    def custom_simulation(self, sid):
        """Hook taking ``sid``; no-op in the base class."""
        pass

    def after_simulation_step(self, svr):
        """Hook taking ``svr``; no-op in the base class."""
        pass

    def add_emitters(self):
        """Add emitters if needed for this case."""
        pass

    ## before scene building
    def detect_ground_plane(self, ground_plane):
        """
        Add a rigid ground plane anchored at the objects' occupied lower bound.

        The plane passes through ``all_obj_occupied_lower_bound`` with normal
        (0, 0, 1). Its material parameters come from the config keys
        ``plane_rho``, ``plane_friction``, ``plane_coup_friction`` and
        ``plane_coup_softness``, defaulting to 1000.0, 5, 5.0 and 0.002.

        Args:
            ground_plane: Not used by this base implementation.
        """
        self.ground_anchor = self.all_obj_occupied_lower_bound.cpu().numpy()
        self.normal = np.array([0, 0, 1])

        self.scene.add_entity(
            material=gs.materials.Rigid(
                rho=self._config_or_default('plane_rho', 1000.0),
                friction=self._config_or_default('plane_friction', 5),
                coup_friction=self._config_or_default('plane_coup_friction', 5.0),
                coup_softness=self._config_or_default('plane_coup_softness', 0.002),
            ),
            morph=gs.morphs.Plane(
                pos=(self.ground_anchor[0], self.ground_anchor[1], self.ground_anchor[2]),
                normal=self.normal,
            ),
        )

    def create_force_fields(self):
        """Create case-specific force fields."""
        pass

    def custom_setup(self):
        """Custom setup for this case."""
        pass

    def add_robots(self):
        """Setup robots if needed for this case."""
        pass

    ## after scene building
    def init_robots_pose(self):
        """Initialize robots pose if needed for this case."""
        pass

    def fix_particles(self):
        """Fix particles if needed for this case."""
        pass

    def extract_franka_mesh_data_combined(self, target_franka):
        """
        Extract and combine all mesh data into single arrays with transformations applied.

        For every visual geometry in ``target_franka.vgeoms`` the vertices are
        transformed by that geometry's render transform (read from
        ``target_franka.solver._vgeoms_render_T``), the face indices are
        shifted by the number of vertices gathered so far, and one color is
        repeated for every vertex.

        Returns:
            vertices: float32 torch tensor of all transformed vertices
            faces: int32 torch tensor of all faces (with proper indexing)
            colors: float32 torch tensor of per-vertex colors

        All three tensors are placed on ``self.device``.
        """
        all_vertices = []
        all_faces = []
        all_colors = []
        vertex_offset = 0

        sim_vgeoms_render_T = target_franka.solver._vgeoms_render_T

        for vgeom in target_franka.vgeoms:
            verts = vgeom.vmesh.verts  # presumably shape (N, 3) -- TODO confirm
            faces = vgeom.vmesh.faces

            # Transformation for this vgeom; index [0] drops the leading
            # (presumably batch) dimension, leaving a (4, 4) matrix.
            cur_render_T = sim_vgeoms_render_T[vgeom.idx][0]

            # Homogeneous coordinates (N, 4), transformed as
            # (N, 4) @ (4, 4)^T, then back to 3D (N, 3).
            verts_homogeneous = np.concatenate([verts, np.ones((len(verts), 1))], axis=1)
            verts_transformed = verts_homogeneous @ cur_render_T.T
            verts_transformed = verts_transformed[:, :3]

            # Color priority: diffuse texture color, then surface color,
            # then a mid-gray default.
            surface = vgeom.vmesh.surface
            if hasattr(surface, 'diffuse_texture') and surface.diffuse_texture is not None:
                color = surface.diffuse_texture.color
            elif surface.color is not None:
                color = surface.color
            else:
                color = (0.5, 0.5, 0.5)

            # Offset faces by the current vertex count so indices stay valid
            # in the combined vertex array.
            faces_offset = faces + vertex_offset

            # One copy of the color per vertex.
            vertex_colors = np.tile(color, (len(verts), 1))

            all_vertices.append(verts_transformed)
            all_faces.append(faces_offset)
            all_colors.append(vertex_colors)

            vertex_offset += len(verts)

        vertices = torch.from_numpy(np.vstack(all_vertices)).to(self.device, dtype=torch.float32)  # + self.franka_pos
        faces = torch.from_numpy(np.vstack(all_faces)).to(self.device, dtype=torch.int32)
        colors = torch.from_numpy(np.vstack(all_colors)).to(self.device, dtype=torch.float32)

        return vertices, faces, colors


def get_case_handler(case_name: str, config, all_obj_info, device) -> CaseHandler:
    """
    Factory function to return the corresponding CaseHandler instance based on the case name.

    Args:
        case_name: Name the handler class was registered under.
        config, all_obj_info, device: Forwarded positionally to the
            registered class's constructor.

    Raises:
        ValueError: if ``case_name`` is not in CASE_REGISTRY.
    """
    if case_name not in CASE_REGISTRY:
        raise ValueError(
            f"Unknown case name: '{case_name}'. "
            f"Available cases: {list(CASE_REGISTRY.keys())}"
        )

    # Dynamically get the class object registered under this name.
    CaseClass = CASE_REGISTRY[case_name]

    # Instantiate it with all the parameters required by CaseHandler.__init__.
    return CaseClass(config, all_obj_info, device)