Wei Liu
init huggingface deployment
fc36e06
"""
Base Case Handler Template
Abstract base class for all simulation case handlers.
"""
from abc import ABC, abstractmethod
import numpy as np
import torch
import gstaichi as ti
import genesis as gs
import sys
# Maps a case-name string to the CaseHandler subclass registered under it.
CASE_REGISTRY = {}


def register_case(case_name: str):
    """
    Build a class decorator that records the decorated class in CASE_REGISTRY.

    Args:
        case_name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class under ``case_name``, prints a
        confirmation line and hands the class back untouched.

    Raises:
        ValueError: Raised by the returned decorator (at decoration time) when
            ``case_name`` is already present in CASE_REGISTRY.
    """

    def _store_in_registry(cls):
        name_taken = case_name in CASE_REGISTRY
        if name_taken:
            raise ValueError(f"Case name '{case_name}' already registered!")

        # The registry keeps the class object itself, not an instance of it.
        CASE_REGISTRY[case_name] = cls
        print(f"Registered Case: '{case_name}' -> {cls.__name__}")

        # Decoration is transparent: the very same class object is returned.
        return cls

    return _store_in_registry
class CaseHandler(ABC):
    """
    Abstract base class for handling case-specific simulation logic.

    Each simulation case should inherit from this class and override the hooks
    it needs. The base implementations of most hooks are no-ops. The hooks are
    invoked in this order by the methods below:

    * ``before_scene_building``: ``detect_ground_plane`` ->
      ``create_force_fields`` -> ``add_robots`` -> ``custom_setup`` ->
      ``add_emitters``
    * ``after_scene_building``: ``init_robots_pose`` -> ``fix_particles``

    ``set_simulation_bounds`` must have been called before
    ``get_simulation_bounds`` or ``detect_ground_plane`` is used, because both
    read attributes that only ``set_simulation_bounds`` assigns.
    """

    def __init__(self, config, all_obj_info: list[dict], device: torch.device):
        """
        Store the shared inputs used by every hook.

        Args:
            config: Case configuration. Only ``key in config`` and
                ``config[key]`` are used on it in this class.
            all_obj_info: One dict per object. This class reads the keys
                ``'center'``, ``'size'`` and ``'mesh_path'`` from each entry.
            device: Torch device on which tensors produced by this handler
                are placed.
        """
        self.config = config
        self.all_obj_info = all_obj_info
        self.device = device

    def _config_value(self, key, default):
        """Return ``self.config[key]`` if ``key`` is present, else ``default``."""
        return self.config[key] if key in self.config else default

    def set_simulation_bounds(self, all_obj_occupied_lower_bound, all_obj_occupied_upper_bound):
        """
        Record the region occupied by the objects and derive the simulation domain.

        The simulation domain is the occupied box grown by three times its own
        extent on every side.

        Args:
            all_obj_occupied_lower_bound: Lower corner of the occupied region.
            all_obj_occupied_upper_bound: Upper corner of the occupied region.
                Both are expected to be torch tensors (``.cpu()`` is called on
                them elsewhere in this class).
        """
        self.all_obj_occupied_lower_bound = all_obj_occupied_lower_bound
        self.all_obj_occupied_upper_bound = all_obj_occupied_upper_bound
        self.all_obj_occupied_size = self.all_obj_occupied_upper_bound - self.all_obj_occupied_lower_bound
        self.simulation_lower_bound = self.all_obj_occupied_lower_bound - 3 * self.all_obj_occupied_size
        self.simulation_upper_bound = self.all_obj_occupied_upper_bound + 3 * self.all_obj_occupied_size

    def get_simulation_bounds(self):
        """Return ``(lower, upper)`` simulation bounds as NumPy arrays."""
        return self.simulation_lower_bound.cpu().numpy(), self.simulation_upper_bound.cpu().numpy()

    def _add_entity_with_random_color(self, idx, morph):
        """Add ``morph`` to the scene with object ``idx``'s material and a random opaque color."""
        return self.scene.add_entity(
            material=self.obj_materials[idx],
            morph=morph,
            surface=gs.surfaces.Default(
                # Random RGB with alpha fixed to 1.0.
                color=tuple(np.random.rand(3).tolist() + [1.0]),
                vis_mode=self.obj_vis_modes[idx],
            ),
        )

    def _add_primitive_box(self, idx):
        """Add a box built from object ``idx``'s ``'center'`` and ``'size'`` to the scene."""
        obj_info = self.all_obj_info[idx]
        box_morph = gs.morphs.Box(
            pos=obj_info['center'].cpu().numpy().astype(np.float64),
            size=obj_info['size'].cpu().numpy().astype(np.float64),
            visualization=True,
            collision=True,
            # NOTE(review): the box is never fixed, even when 'is_obj_fixed'
            # marks this object as fixed for the mesh path - confirm intended.
            fixed=False,
        )
        return self._add_entity_with_random_color(idx, box_morph)

    def _add_mesh_entity(self, idx, per_obj_info, fixed):
        """Add the mesh at ``per_obj_info['mesh_path']``, positioned at the object's center."""
        mesh_morph = gs.morphs.Mesh(
            file=per_obj_info['mesh_path'],
            scale=1.0,
            pos=tuple(per_obj_info['center'].cpu().numpy().astype(np.float64)),
            euler=(0.0, 0.0, 0.0),
            fixed=fixed,
        )
        return self._add_entity_with_random_color(idx, mesh_morph)

    def add_entities_to_scene(self, scene, obj_materials, obj_vis_modes):
        """
        Add one scene entity per entry of ``self.all_obj_info``.

        When the config enables ``'use_primitive'`` every object is added as a
        box. Otherwise each object is added as a mesh, and if that fails for
        any reason the error is printed and a box is added instead.

        Args:
            scene: Scene the entities are added to (stored as ``self.scene``).
            obj_materials: Per-object materials, indexed like ``all_obj_info``.
            obj_vis_modes: Per-object visualization modes, indexed likewise.

        Returns:
            The list of created entities (also stored as ``self.objs``).
        """
        self.obj_materials = obj_materials
        self.obj_vis_modes = obj_vis_modes
        self.scene = scene
        self.objs = []

        # Objects are free to move unless the config says otherwise.
        is_obj_fixed = self._config_value('is_obj_fixed', [False] * len(self.all_obj_info))
        use_primitive = self._config_value('use_primitive', False)

        for idx, per_obj_info in enumerate(self.all_obj_info):
            if use_primitive:
                per_obj = self._add_primitive_box(idx)
            else:
                try:
                    per_obj = self._add_mesh_entity(idx, per_obj_info, is_obj_fixed[idx])
                except Exception as e:
                    # Deliberate best-effort: any failure while loading or
                    # adding the mesh falls back to a box so the scene can
                    # still be built.
                    print(e)
                    print("trying to add primitive mesh for object", idx)
                    per_obj = self._add_primitive_box(idx)
            self.objs.append(per_obj)
        return self.objs

    def before_scene_building(self, scene, all_objs, ground_plane):
        """
        Run every pre-build hook in a fixed order.

        Args:
            scene: Scene being assembled (stored as ``self.scene``).
            all_objs: Entities already added (stored as ``self.all_objs``).
            ground_plane: Forwarded to ``detect_ground_plane``.
        """
        self.scene = scene
        self.all_objs = all_objs
        self.detect_ground_plane(ground_plane)
        self.create_force_fields()
        self.add_robots()
        self.custom_setup()
        self.add_emitters()

    def after_scene_building(self):
        """Run every post-build hook in a fixed order."""
        self.init_robots_pose()
        self.fix_particles()

    def custom_simulation(self, sid):
        """Per-case simulation hook; no-op in the base class."""
        pass

    def after_simulation_step(self, svr):
        """Hook run after a simulation step; no-op in the base class."""
        pass

    def add_emitters(self):
        """Add emitters if needed for this case."""
        pass

    ## before scene building
    def detect_ground_plane(self, ground_plane):
        """
        Add a ground plane at the lower corner of the occupied region.

        The plane passes through ``all_obj_occupied_lower_bound`` with normal
        +Z. Its rigid material parameters come from the optional config keys
        ``'plane_rho'``, ``'plane_friction'``, ``'plane_coup_friction'`` and
        ``'plane_coup_softness'``.

        Args:
            ground_plane: Not used by the base implementation.
        """
        self.ground_anchor = self.all_obj_occupied_lower_bound.cpu().numpy()
        self.normal = np.array([0, 0, 1])
        self.scene.add_entity(
            material=gs.materials.Rigid(
                rho=self._config_value('plane_rho', 1000.0),
                friction=self._config_value('plane_friction', 5),
                coup_friction=self._config_value('plane_coup_friction', 5.0),
                coup_softness=self._config_value('plane_coup_softness', 0.002),
            ),
            morph=gs.morphs.Plane(
                pos=(self.ground_anchor[0], self.ground_anchor[1], self.ground_anchor[2]),
                normal=self.normal,
            ),
        )

    def create_force_fields(self):
        """Create case-specific force fields."""
        pass

    def custom_setup(self):
        """Custom setup for this case."""
        pass

    def add_robots(self):
        """Setup robots if needed for this case."""
        pass

    ## after scene building
    def init_robots_pose(self):
        """Initialize robots pose if needed for this case."""
        pass

    def fix_particles(self):
        """Fix particles if needed for this case."""
        pass

    def extract_franka_mesh_data_combined(self, target_franka):
        """
        Extract and combine all mesh data into single arrays with transformations applied.

        Args:
            target_franka: Object exposing ``vgeoms`` (each with ``idx`` and
                ``vmesh.verts`` / ``vmesh.faces`` / ``vmesh.surface``) and
                ``solver._vgeoms_render_T``, indexed as ``[vgeom.idx][0]`` to
                obtain a 4x4 transform.

        Returns:
            vertices: float32 torch tensor of all transformed vertices
            faces: int32 torch tensor of all faces, re-indexed into ``vertices``
            colors: float32 torch tensor of per-vertex colors

        Raises:
            ValueError: From ``np.vstack`` when ``target_franka.vgeoms`` is empty.
        """
        all_vertices = []
        all_faces = []
        all_colors = []
        vertex_offset = 0
        sim_vgeoms_render_T = target_franka.solver._vgeoms_render_T

        for vgeom in target_franka.vgeoms:
            verts = vgeom.vmesh.verts  # shape: (N, 3)
            faces = vgeom.vmesh.faces
            num_verts = len(verts)

            # First batch entry of this vgeom's render transform, shape (4, 4).
            cur_render_T = sim_vgeoms_render_T[vgeom.idx][0]

            # Lift to homogeneous coordinates, apply the transform as
            # (N, 4) @ (4, 4)^T, then drop the homogeneous column again.
            verts_homogeneous = np.concatenate([verts, np.ones((num_verts, 1))], axis=1)
            verts_transformed = (verts_homogeneous @ cur_render_T.T)[:, :3]

            # Color priority: diffuse texture color, then surface color, then gray.
            # NOTE(review): the gray fallback has 3 channels; if other vgeoms
            # carry 4-channel colors the final vstack would fail - confirm.
            surface = vgeom.vmesh.surface
            if hasattr(surface, 'diffuse_texture') and surface.diffuse_texture is not None:
                color = surface.diffuse_texture.color
            elif surface.color is not None:
                color = surface.color
            else:
                color = (0.5, 0.5, 0.5)

            all_vertices.append(verts_transformed)
            # Face indices are local to each vgeom; shift them into the combined array.
            all_faces.append(faces + vertex_offset)
            # Same color repeated for every vertex of this vgeom.
            all_colors.append(np.tile(color, (num_verts, 1)))
            vertex_offset += num_verts

        vertices = torch.from_numpy(np.vstack(all_vertices)).to(self.device, dtype=torch.float32)
        faces = torch.from_numpy(np.vstack(all_faces)).to(self.device, dtype=torch.int32)
        colors = torch.from_numpy(np.vstack(all_colors)).to(self.device, dtype=torch.float32)
        return vertices, faces, colors
def get_case_handler(case_name: str, config, all_obj_info, device) -> CaseHandler:
    """
    Instantiate the handler class registered under ``case_name``.

    Args:
        case_name: Key to look up in CASE_REGISTRY.
        config: Forwarded as the first constructor argument.
        all_obj_info: Forwarded as the second constructor argument.
        device: Forwarded as the third constructor argument.

    Returns:
        A new instance of the registered class, built as
        ``cls(config, all_obj_info, device)``.

    Raises:
        ValueError: If ``case_name`` is not in CASE_REGISTRY; the message
            lists the names that are available.
    """
    if case_name in CASE_REGISTRY:
        # Look the class up by name and construct it with the arguments
        # CaseHandler.__init__ expects.
        return CASE_REGISTRY[case_name](config, all_obj_info, device)

    raise ValueError(f"Unknown case name: '{case_name}'. Available cases: {list(CASE_REGISTRY.keys())}")