lsnu's picture
Add files using upload-large-folder tool
912c7e2 verified
from __future__ import annotations
import trimesh
import numpy as np
import open3d as o3d
from yourdfpy.urdf import URDF
from pfp.common.se3_utils import pfp_to_pose_np
try:
import rerun as rr
except ImportError:
print("WARNING: Rerun not installed. Visualization will not work.")
class RerunViewer:
def __init__(self, name: str, addr: str = None):
rr.init(name)
if addr is None:
addr = "127.0.0.1"
port = ":9876"
rr.connect(addr + port)
RerunViewer.clear()
return
@staticmethod
def add_obs_dict(obs_dict: dict, timestep: int = None):
if timestep is not None:
rr.set_time_sequence("timestep", timestep)
RerunViewer.add_rgb("rgb", obs_dict["image"])
RerunViewer.add_depth("depth", obs_dict["depth"])
RerunViewer.add_np_pointcloud(
"vis/pointcloud",
points=obs_dict["point_cloud"][:, :3],
colors_uint8=obs_dict["point_cloud"][:, 3:],
)
return
@staticmethod
def add_o3d_pointcloud(name: str, pointcloud: o3d.geometry.PointCloud, radii: float = None):
points = np.asanyarray(pointcloud.points)
colors = np.asanyarray(pointcloud.colors) if pointcloud.has_colors() else None
colors_uint8 = (colors * 255).astype(np.uint8) if pointcloud.has_colors() else None
RerunViewer.add_np_pointcloud(name, points, colors_uint8, radii)
return
@staticmethod
def add_np_pointcloud(
name: str, points: np.ndarray, colors_uint8: np.ndarray = None, radii: float = None
):
rr_points = rr.Points3D(positions=points, colors=colors_uint8, radii=radii)
rr.log(name, rr_points)
return
@staticmethod
def add_axis(name: str, pose: np.ndarray, size: float = 0.004, timeless: bool = False):
mesh = trimesh.creation.axis(origin_size=size, transform=pose)
RerunViewer.add_mesh_trimesh(name, mesh, timeless)
return
@staticmethod
def add_aabb(name: str, centers: np.ndarray, extents: np.ndarray, timeless=False):
rr.log(name, rr.Boxes3D(centers=centers, sizes=extents), timeless=timeless)
return
@staticmethod
def add_mesh_trimesh(name: str, mesh: trimesh.Trimesh, timeless: bool = False):
# Handle colors
if mesh.visual.kind in ["vertex", "face"]:
vertex_colors = mesh.visual.vertex_colors
elif mesh.visual.kind == "texture":
vertex_colors = mesh.visual.to_color().vertex_colors
else:
vertex_colors = None
# Log mesh
rr_mesh = rr.Mesh3D(
vertex_positions=mesh.vertices,
vertex_colors=vertex_colors,
vertex_normals=mesh.vertex_normals,
indices=mesh.faces,
)
rr.log(name, rr_mesh, timeless=timeless)
return
@staticmethod
def add_mesh_list_trimesh(name: str, meshes: list[trimesh.Trimesh]):
for i, mesh in enumerate(meshes):
RerunViewer.add_mesh_trimesh(name + f"/{i}", mesh)
return
@staticmethod
def add_rgb(name: str, rgb_uint8: np.ndarray):
if rgb_uint8.shape[0] == 3:
# CHW -> HWC
rgb_uint8 = np.transpose(rgb_uint8, (1, 2, 0))
rr.log(name, rr.Image(rgb_uint8))
@staticmethod
def add_depth(name: str, detph: np.ndarray):
rr.log(name, rr.DepthImage(detph))
@staticmethod
def add_traj(name: str, traj: np.ndarray):
"""
name: str
traj: np.ndarray (T, 10)
"""
poses = pfp_to_pose_np(traj)
for i, pose in enumerate(poses):
RerunViewer.add_axis(name + f"/{i}t", pose)
return
@staticmethod
def clear():
rr.log("vis", rr.Clear(recursive=True))
return
class RerunTraj:
def __init__(self) -> None:
self.traj_shape = None
return
def add_traj(self, name: str, traj: np.ndarray, size: float = 0.004):
"""
name: str
traj: np.ndarray (T, 10)
"""
if self.traj_shape is None or self.traj_shape != traj.shape:
self.traj_shape = traj.shape
for i in range(traj.shape[0]):
RerunViewer.add_axis(name + f"/{i}t", np.eye(4), size)
poses = pfp_to_pose_np(traj)
for i, pose in enumerate(poses):
rr.log(
name + f"/{i}t",
rr.Transform3D(mat3x3=pose[:3, :3], translation=pose[:3, 3]),
)
return
class RerunURDF:
def __init__(self, name: str, urdf_path: str, meshes_root: str):
self.name = name
self.urdf: URDF = URDF.load(urdf_path, mesh_dir=meshes_root)
return
def update_vis(
self,
joint_state: list | np.ndarray,
root_pose: np.ndarray = np.eye(4),
name_suffix: str = "",
):
self._update_joints(joint_state)
scene = self.urdf.scene
trimeshes = self._scene_to_trimeshes(scene)
trimeshes = [t.apply_transform(root_pose) for t in trimeshes]
RerunViewer.add_mesh_list_trimesh(self.name + name_suffix, trimeshes)
return
def _update_joints(self, joint_state: list | np.ndarray):
assert len(joint_state) == len(self.urdf.actuated_joints), "Wrong number of joint values."
self.urdf.update_cfg(joint_state)
return
def _scene_to_trimeshes(self, scene: trimesh.Scene) -> list[trimesh.Trimesh]:
"""
Convert a trimesh.Scene to a list of trimesh.Trimesh.
Skips objects that are not an instance of trimesh.Trimesh.
"""
trimeshes = []
scene_dump = scene.dump()
geometries = [scene_dump] if not isinstance(scene_dump, list) else scene_dump
for geometry in geometries:
if isinstance(geometry, trimesh.Trimesh):
trimeshes.append(geometry)
elif isinstance(geometry, trimesh.Scene):
trimeshes.extend(self._scene_to_trimeshes(geometry))
return trimeshes