lsnu's picture
Add files using upload-large-folder tool
912c7e2 verified
import torch
import numpy as np
from spatialmath.base import r2q
from spatialmath.base.transforms3d import isrot
try:
from pytorch3d.ops import corresponding_points_alignment
except ImportError:
print("pytorch3d not installed")
from pfp import DEVICE
def transform_th(transform: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
"""Apply a 4x4 transformation matrix to a set of points."""
new_points = points @ transform[..., :3, :3].mT + transform[..., :3, 3]
return new_points
def vec_projection_np(v: np.ndarray, e: np.ndarray) -> np.ndarray:
"""Project vector v onto unit vector e."""
proj = np.sum(v * e, axis=-1, keepdims=True) * e
return proj
def vec_projection_th(v: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
"""Project vector v onto unit vector e."""
proj = torch.sum(v * e, dim=-1, keepdim=True) * e
return proj
def grahm_schmidt_np(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
"""Compute orthonormal basis from two vectors."""
v1 = v1.astype(np.float64)
v2 = v2.astype(np.float64)
u1 = v1
e1 = u1 / np.linalg.norm(u1, axis=-1, keepdims=True)
u2 = v2 - vec_projection_np(v2, e1)
e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
e3 = np.cross(e1, e2, axis=-1)
rot_matrix = np.concatenate([e1[..., None], e2[..., None], e3[..., None]], axis=-1)
return rot_matrix
def grahm_schmidt_th(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
"""Compute orthonormal basis from two vectors."""
u1 = v1
e1 = u1 / torch.norm(u1, dim=-1, keepdim=True)
u2 = v2 - vec_projection_th(v2, e1)
e2 = u2 / torch.norm(u2, dim=-1, keepdim=True)
e3 = torch.cross(e1, e2, dim=-1)
rot_matrix = torch.cat(
[e1.unsqueeze(dim=-1), e2.unsqueeze(dim=-1), e3.unsqueeze(dim=-1)], dim=-1
)
return rot_matrix
def pfp_to_pose_np(robot_states: np.ndarray) -> np.ndarray:
"""Convert pfp state (T, 10) to 4x4 poses (T, 4, 4)."""
T = robot_states.shape[0]
poses = np.eye(4)[np.newaxis, ...]
poses = np.tile(poses, (T, 1, 1))
poses[:, :3, 3] = robot_states[:, :3]
poses[:, :3, :3] = grahm_schmidt_np(robot_states[:, 3:6], robot_states[:, 6:9])
return poses
def pfp_to_pose_th(robot_states: torch.Tensor) -> torch.Tensor:
"""Convert pfp state (B, T, 10) to 4x4 poses (B, T, 4, 4) and gripper (B, T, 1)."""
B = robot_states.shape[0]
T = robot_states.shape[1]
poses = (
torch.eye(4, device=robot_states.device)
.unsqueeze(0)
.unsqueeze(0)
.expand(B, T, 4, 4)
.contiguous()
)
poses[..., :3, 3] = robot_states[..., :3]
poses[..., :3, :3] = grahm_schmidt_th(robot_states[..., 3:6], robot_states[..., 6:9])
gripper = robot_states[..., -1:]
return poses, gripper
def rot6d_to_quat_np(rot6d: np.ndarray, order: str = "xyzs") -> np.ndarray:
"""Convert 6d rotation matrix to quaternion."""
rot = grahm_schmidt_np(rot6d[:3], rot6d[3:])
quat = r2q(rot, order=order)
return quat
def rot6d_to_rot_np(rot6d: np.ndarray) -> np.ndarray:
"""Convert 6d rotation matrix to 3x3 rotation matrix."""
rot = grahm_schmidt_np(rot6d[:3], rot6d[3:])
return rot
def check_valid_rot(rot: np.ndarray) -> bool:
"""Check if the 3x3 rotation matrix is valid."""
valid = isrot(rot, check=True, tol=1e10)
return valid
def get_canonical_5p_th() -> torch.Tensor:
"""Return the (5,3) canonical 5points representation of the franka hand."""
gripper_width = 0.08
left_y = 0.5 * gripper_width
right_y = -0.5 * gripper_width
mid_z = -0.041
top_z = -0.1034
a = [0, 0, top_z]
b = [0, left_y, mid_z]
c = [0, right_y, mid_z]
d = [0, left_y, 0]
e = [0, right_y, 0]
pose_5p = torch.tensor([a, b, c, d, e])
return pose_5p
def pfp_to_state5p_th(robot_states: torch.Tensor) -> torch.Tensor:
"""
Convert pfp state (B, T, 10) to 5points representation (B, T, 16).
5p: [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4, gripper]
"""
device = robot_states.device
poses, gripper = pfp_to_pose_th(robot_states)
canonical_5p = get_canonical_5p_th().to(device)
canonical_5p_homog = torch.cat([canonical_5p, torch.ones(5, 1, device=device)], dim=-1)
poses_5p_homog = (poses @ canonical_5p_homog.mT).mT
poses_5p = poses_5p_homog[..., :3].contiguous().flatten(start_dim=-2)
state5p = torch.cat([poses_5p, gripper], dim=-1)
return state5p
def state5p_to_pfp_th(state5p: torch.Tensor) -> torch.Tensor:
"""
Convert 5points representation (B, T, 16) to pfp state (B, T, 10) using svd projection.
"""
device = state5p.device
leading_dims = state5p.shape[0:2]
# Flatten the batch and time dimensions
state5p = state5p.reshape(-1, *state5p.shape[2:])
poses_5p, gripper = state5p[..., :-1], state5p[..., -1:]
poses_5p = poses_5p.reshape(-1, 5, 3)
canonical_5p = get_canonical_5p_th().expand(poses_5p.shape[0], 5, 3).to(device)
with torch.cuda.amp.autocast(enabled=False):
result = corresponding_points_alignment(canonical_5p, poses_5p)
rotations = result.R.mT
translations = result.T
pfp_state = torch.cat([translations, rotations[..., 0], rotations[..., 1], gripper], dim=-1)
# Reshape back to the batch and time dimensions
pfp_state = pfp_state.reshape(*leading_dims, -1)
return pfp_state
def init_random_traj_th(B: int, T: int, noise_scale: float) -> torch.Tensor:
"""
B: batch size
T: number of time steps
"""
# Position
random_xyz = torch.randn((B, 1, 3), device=DEVICE) * noise_scale
direction = torch.randn((B, 1, 3), device=DEVICE)
direction = direction / torch.norm(direction, dim=-1, keepdim=True)
t = torch.linspace(0, 1, T, device=DEVICE).unsqueeze(0).unsqueeze(-1)
random_xyz = random_xyz + t * direction
# Rotation 6d
random_r1 = torch.randn((B, 1, 3), device=DEVICE)
random_r1 = random_r1 / torch.norm(random_r1, dim=-1, keepdim=True)
random_r2 = torch.randn((B, 1, 3), device=DEVICE)
random_r2 = random_r2 - vec_projection_th(random_r2, random_r1)
random_r2 = random_r2 / torch.norm(random_r2, dim=-1, keepdim=True)
random_r6d = torch.cat([random_r1, random_r2], dim=-1)
random_r6d = random_r6d.expand(B, T, 6)
# Gripper
gripper = torch.ones((B, T, 1), device=DEVICE)
random_traj = torch.cat([random_xyz, random_r6d, gripper], dim=-1)
return random_traj