from typing import Tuple

import numpy as np
import torch
from spatialmath.base import r2q
from spatialmath.base.transforms3d import isrot

try:
    from pytorch3d.ops import corresponding_points_alignment
except ImportError:
    print("pytorch3d not installed")

from pfp import DEVICE


def transform_th(transform: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Apply a 4x4 transformation matrix to a set of points.

    Args:
        transform: homogeneous transform(s) of shape (..., 4, 4).
        points: points of shape (..., N, 3), or a single point of shape (3,).

    Returns:
        The transformed points, ``R @ p + t`` for every point ``p``.
    """
    rotation = transform[..., :3, :3]
    translation = transform[..., :3, 3]
    if points.dim() > 1:
        # Insert a point axis so a batch of translations (..., 3) broadcasts
        # against (..., N, 3) points instead of being matched with the N axis.
        translation = translation.unsqueeze(-2)
    # Row-vector form of R @ p: p @ R^T.
    new_points = points @ rotation.mT + translation
    return new_points


def vec_projection_np(v: np.ndarray, e: np.ndarray) -> np.ndarray:
    """Project vector v onto unit vector e (along the last axis).

    ``e`` is not normalized here; the result is ``(v . e) * e``.
    """
    dot = np.sum(v * e, axis=-1, keepdims=True)
    return dot * e


def vec_projection_th(v: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
    """Project vector v onto unit vector e (along the last dimension).

    ``e`` is not normalized here; the result is ``(v . e) * e``.
    """
    dot = torch.sum(v * e, dim=-1, keepdim=True)
    return dot * e


def grahm_schmidt_np(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """Compute orthonormal basis from two vectors.

    Args:
        v1: array of shape (..., 3); its direction becomes the first column.
        v2: array of shape (..., 3); orthogonalized against v1 for the second column.

    Returns:
        Array of shape (..., 3, 3) whose columns are e1, e2 and e1 x e2 (float64).
    """
    v1 = v1.astype(np.float64)
    v2 = v2.astype(np.float64)
    e1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Remove the component of v2 along e1 before normalizing.
    u2 = v2 - vec_projection_np(v2, e1)
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    e3 = np.cross(e1, e2, axis=-1)
    # Stack the basis vectors as matrix columns.
    rot_matrix = np.concatenate([e1[..., None], e2[..., None], e3[..., None]], axis=-1)
    return rot_matrix


def grahm_schmidt_th(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Compute orthonormal basis from two vectors.

    Args:
        v1: tensor of shape (..., 3); its direction becomes the first column.
        v2: tensor of shape (..., 3); orthogonalized against v1 for the second column.

    Returns:
        Tensor of shape (..., 3, 3) whose columns are e1, e2 and e1 x e2.
    """
    e1 = v1 / torch.norm(v1, dim=-1, keepdim=True)
    # Remove the component of v2 along e1 before normalizing.
    u2 = v2 - vec_projection_th(v2, e1)
    e2 = u2 / torch.norm(u2, dim=-1, keepdim=True)
    e3 = torch.cross(e1, e2, dim=-1)
    # Stack the basis vectors as matrix columns.
    rot_matrix = torch.cat(
        [e1.unsqueeze(dim=-1), e2.unsqueeze(dim=-1), e3.unsqueeze(dim=-1)], dim=-1
    )
    return rot_matrix


def pfp_to_pose_np(robot_states: np.ndarray) -> np.ndarray:
    """Convert pfp state (T, 10) to 4x4 poses (T, 4, 4).

    Columns 0:3 are used as the translation and columns 3:6 / 6:9 as the two
    vectors from which the rotation is built; the last column is ignored here.
    """
    T = robot_states.shape[0]
    # One identity matrix per time step, then fill in translation and rotation.
    poses = np.tile(np.eye(4)[np.newaxis, ...], (T, 1, 1))
    poses[:, :3, 3] = robot_states[:, :3]
    poses[:, :3, :3] = grahm_schmidt_np(robot_states[:, 3:6], robot_states[:, 6:9])
    return poses


def pfp_to_pose_th(robot_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert pfp state (B, T, 10) to 4x4 poses (B, T, 4, 4) and gripper (B, T, 1).

    Returns:
        A ``(poses, gripper)`` tuple, where ``gripper`` is the last state channel.
    """
    B = robot_states.shape[0]
    T = robot_states.shape[1]
    # expand() yields a non-writable broadcast view; contiguous() materializes
    # a real (B, T, 4, 4) tensor that can be assigned into.
    poses = (
        torch.eye(4, device=robot_states.device)
        .unsqueeze(0)
        .unsqueeze(0)
        .expand(B, T, 4, 4)
        .contiguous()
    )
    poses[..., :3, 3] = robot_states[..., :3]
    poses[..., :3, :3] = grahm_schmidt_th(robot_states[..., 3:6], robot_states[..., 6:9])
    gripper = robot_states[..., -1:]
    return poses, gripper


def rot6d_to_quat_np(rot6d: np.ndarray, order: str = "xyzs") -> np.ndarray:
    """Convert a 6d rotation representation (6,) to a quaternion.

    ``order`` is forwarded to spatialmath's ``r2q``.
    """
    rot = grahm_schmidt_np(rot6d[:3], rot6d[3:])
    quat = r2q(rot, order=order)
    return quat


def rot6d_to_rot_np(rot6d: np.ndarray) -> np.ndarray:
    """Convert a 6d rotation representation (6,) to a 3x3 rotation matrix."""
    return grahm_schmidt_np(rot6d[:3], rot6d[3:])


def check_valid_rot(rot: np.ndarray) -> bool:
    """Check if the 3x3 rotation matrix is valid."""
    # NOTE(review): spatialmath documents `tol` as a multiple of machine
    # epsilon, which would make 1e10 roughly 2e-6 absolute -- confirm intended.
    valid = isrot(rot, check=True, tol=1e10)
    return valid


def get_canonical_5p_th() -> torch.Tensor:
    """Return the (5,3) canonical 5points representation of the franka hand."""
    gripper_width = 0.08
    left_y = 0.5 * gripper_width
    right_y = -0.5 * gripper_width
    mid_z = -0.041
    top_z = -0.1034
    # All five points lie in the x = 0 plane.
    a = [0, 0, top_z]
    b = [0, left_y, mid_z]
    c = [0, right_y, mid_z]
    d = [0, left_y, 0]
    e = [0, right_y, 0]
    pose_5p = torch.tensor([a, b, c, d, e])
    return pose_5p


def pfp_to_state5p_th(robot_states: torch.Tensor) -> torch.Tensor:
    """
    Convert pfp state (B, T, 10) to 5points representation (B, T, 16).
    5p: [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4, gripper]
    """
    device = robot_states.device
    poses, gripper = pfp_to_pose_th(robot_states)
    canonical_5p = get_canonical_5p_th().to(device)
    # Append a homogeneous coordinate so the 4x4 poses can be applied directly.
    canonical_5p_homog = torch.cat([canonical_5p, torch.ones(5, 1, device=device)], dim=-1)
    # (B, T, 4, 4) @ (4, 5) -> (B, T, 4, 5); transpose back to one point per row.
    poses_5p_homog = (poses @ canonical_5p_homog.mT).mT
    # Drop the homogeneous coordinate and flatten the 5 points into 15 values.
    poses_5p = poses_5p_homog[..., :3].contiguous().flatten(start_dim=-2)
    state5p = torch.cat([poses_5p, gripper], dim=-1)
    return state5p


def state5p_to_pfp_th(state5p: torch.Tensor) -> torch.Tensor:
    """
    Convert 5points representation (B, T, 16) to pfp state (B, T, 10) using svd projection.
    """
    device = state5p.device
    leading_dims = state5p.shape[0:2]
    # Flatten the batch and time dimensions
    state5p = state5p.reshape(-1, *state5p.shape[2:])
    poses_5p, gripper = state5p[..., :-1], state5p[..., -1:]
    poses_5p = poses_5p.reshape(-1, 5, 3)
    canonical_5p = get_canonical_5p_th().expand(poses_5p.shape[0], 5, 3).to(device)
    # Autocast is disabled around the alignment -- presumably so it runs in
    # full precision even inside a mixed-precision region.
    with torch.cuda.amp.autocast(enabled=False):
        result = corresponding_points_alignment(canonical_5p, poses_5p)
    # NOTE(review): pytorch3d documents the alignment as X @ R + T (row
    # vectors), hence the transpose to column-vector convention -- confirm.
    rotations = result.R.mT
    translations = result.T
    # The pfp state stores only the first two columns of the rotation matrix.
    pfp_state = torch.cat([translations, rotations[..., 0], rotations[..., 1], gripper], dim=-1)
    # Reshape back to the batch and time dimensions
    pfp_state = pfp_state.reshape(*leading_dims, -1)
    return pfp_state


def init_random_traj_th(B: int, T: int, noise_scale: float) -> torch.Tensor:
    """
    Build a random trajectory of shape (B, T, 10) on DEVICE.

    Each trajectory moves along a random unit direction from a random start
    point, with one random orientation held for all time steps and the
    gripper channel set to 1.

    B: batch size
    T: number of time steps
    noise_scale: standard deviation of the random start position
    """
    # Position
    random_xyz = torch.randn((B, 1, 3), device=DEVICE) * noise_scale
    direction = torch.randn((B, 1, 3), device=DEVICE)
    direction = direction / torch.norm(direction, dim=-1, keepdim=True)
    t = torch.linspace(0, 1, T, device=DEVICE).unsqueeze(0).unsqueeze(-1)
    # Broadcasts (B, 1, 3) + (1, T, 1) * (B, 1, 3) -> (B, T, 3).
    random_xyz = random_xyz + t * direction
    # Rotation 6d
    random_r1 = torch.randn((B, 1, 3), device=DEVICE)
    random_r1 = random_r1 / torch.norm(random_r1, dim=-1, keepdim=True)
    random_r2 = torch.randn((B, 1, 3), device=DEVICE)
    # Make the second vector orthogonal to the first, then normalize it.
    random_r2 = random_r2 - vec_projection_th(random_r2, random_r1)
    random_r2 = random_r2 / torch.norm(random_r2, dim=-1, keepdim=True)
    random_r6d = torch.cat([random_r1, random_r2], dim=-1)
    random_r6d = random_r6d.expand(B, T, 6)
    # Gripper
    gripper = torch.ones((B, T, 1), device=DEVICE)
    random_traj = torch.cat([random_xyz, random_r6d, gripper], dim=-1)
    return random_traj