| """Geometry utilities for SE(3) pose operations used in embodied robot learning. |
| |
| Pure torch implementation — all functions accept and return torch tensors, |
| support arbitrary batch dimensions, and work on both CPU and CUDA. |
| |
| Bimanual EEF state/action format (20-dim): |
| [left_xyz(3), left_rot6d(6), left_grip(1), right_xyz(3), right_rot6d(6), right_grip(1)] |
| |
| Rotation formats: |
| rpy — 3-dim, extrinsic xyz euler angles (roll-pitch-yaw) |
| xyzw — 4-dim, quaternion [x, y, z, w] (scipy/ROS convention) |
| wxyz — 4-dim, quaternion [w, x, y, z] (transforms3d/RoboTwin convention) |
| rot6d — 6-dim, first two columns of rotation matrix (Zhou et al., CVPR 2019) |
| rvec — 3-dim, axis-angle (Rodrigues vector), scipy `Rotation.as_rotvec()` convention. |
| rvec = angle * axis, angle in radians, canonical angle ∈ [0, π]. |
| """ |
|
|
| import torch |
|
|
# Number of scalar components used by each supported rotation representation.
ROT_DIM = dict(rpy=3, xyzw=4, wxyz=4, rot6d=6, rvec=3)
|
|
|
|
| |
| |
| |
|
|
|
|
def quat_xyzw_to_mat(q: torch.Tensor) -> torch.Tensor:
    """Quaternion [x, y, z, w] → rotation matrix. (..., 4) → (..., 3, 3).

    The quaternion does not have to be unit length: it is implicitly normalized
    by using the scale factor 2 / |q|^2 instead of a literal 2, matching scipy's
    `Rotation.from_quat`. A slightly denormalized input (for example stored in
    float32 or predicted by a network) therefore still yields an orthonormal
    matrix. An all-zero quaternion maps to the identity. Output is float64.
    """
    q = q.to(torch.float64)
    x, y, z, w = q.unbind(-1)
    # Homogeneous form of the quaternion→matrix map. The clamp avoids 0/0 for
    # q == 0; every product below is then 0, which gives the identity.
    s = 2.0 / (q * q).sum(dim=-1).clamp(min=1e-30)
    tx, ty, tz = s * x, s * y, s * z
    twx, twy, twz = w * tx, w * ty, w * tz
    txx, txy, txz = x * tx, x * ty, x * tz
    tyy, tyz, tzz = y * ty, y * tz, z * tz
    R = torch.stack([
        1 - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, 1 - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, 1 - (txx + tyy),
    ], dim=-1).reshape(q.shape[:-1] + (3, 3))
    return R
|
|
|
|
def mat_to_quat_xyzw(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → quaternion [x, y, z, w]. (..., 3, 3) → (..., 4).

    Uses Shepperd's branch selection. The quaternion component with the
    largest magnitude is recovered from a square root of the trace (w) or of a
    diagonal combination (x, y or z). The other three components come from
    off-diagonal sums/differences divided by that root, which keeps the
    division well conditioned. Radicands are floored at 1e-10 and the result is
    renormalized. The sign of the quaternion is not canonicalized. Output is
    float64.
    """
    mats = R.to(torch.float64)
    lead_shape = mats.shape[:-2]
    mats = mats.reshape(-1, 3, 3)

    r00, r01, r02 = mats[:, 0, 0], mats[:, 0, 1], mats[:, 0, 2]
    r10, r11, r12 = mats[:, 1, 0], mats[:, 1, 1], mats[:, 1, 2]
    r20, r21, r22 = mats[:, 2, 0], mats[:, 2, 1], mats[:, 2, 2]
    tr = r00 + r11 + r22

    def four_times_component(radicand):
        # 2 * sqrt(radicand) == 4 * |dominant component|; floored so it is never 0.
        return torch.sqrt(torch.clamp(radicand, min=1e-10)) * 2

    sw = four_times_component(tr + 1)
    sx = four_times_component(1 + r00 - r11 - r22)
    sy = four_times_component(1 + r11 - r00 - r22)
    sz = four_times_component(1 + r22 - r00 - r11)

    from_w = torch.stack([(r21 - r12) / sw, (r02 - r20) / sw, (r10 - r01) / sw, sw / 4], dim=-1)
    from_x = torch.stack([sx / 4, (r01 + r10) / sx, (r02 + r20) / sx, (r21 - r12) / sx], dim=-1)
    from_y = torch.stack([(r01 + r10) / sy, sy / 4, (r12 + r21) / sy, (r02 - r20) / sy], dim=-1)
    from_z = torch.stack([(r02 + r20) / sz, (r12 + r21) / sz, sz / 4, (r10 - r01) / sz], dim=-1)

    # Branch priority: positive trace → w; else largest diagonal among x, y, z.
    pick_w = (tr > 0)[:, None]
    pick_x = ((r00 > r11) & (r00 > r22))[:, None]
    pick_y = (r11 > r22)[:, None]
    quat = torch.where(
        pick_w,
        from_w,
        torch.where(pick_x, from_x, torch.where(pick_y, from_y, from_z)),
    )

    quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-10)
    return quat.reshape(*lead_shape, 4)
|
|
|
|
| def rpy_to_mat(rpy: torch.Tensor) -> torch.Tensor: |
| """RPY (extrinsic xyz euler angles) → rotation matrix. (..., 3) → (..., 3, 3). |
| |
| R = Rz(yaw) @ Ry(pitch) @ Rx(roll) |
| """ |
| rpy = rpy.to(torch.float64) |
| roll, pitch, yaw = rpy.unbind(-1) |
| cr, sr = torch.cos(roll), torch.sin(roll) |
| cp, sp = torch.cos(pitch), torch.sin(pitch) |
| cy, sy = torch.cos(yaw), torch.sin(yaw) |
| R = torch.stack([ |
| cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr, |
| sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr, |
| -sp, cp * sr, cp * cr, |
| ], dim=-1).reshape(rpy.shape[:-1] + (3, 3)) |
| return R |
|
|
|
|
def mat_to_rpy(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → RPY (extrinsic xyz euler angles). (..., 3, 3) → (..., 3).

    Inverse of R = Rz(yaw) @ Ry(pitch) @ Rx(roll). Pitch is returned in
    [-π/2, π/2]; roll and yaw come from atan2 and lie in [-π, π]. Output is
    float64.

    Gimbal lock: when cos(pitch) is ~0 (pitch = ±π/2), only roll ∓ yaw is
    observable. R[0, 0], R[1, 0], R[2, 1] and R[2, 2] are then all ~0, so the
    generic atan2 formulas would return roll = yaw = 0 and discard the in-plane
    rotation. In that case yaw is fixed to 0 and the whole in-plane rotation is
    assigned to roll, so the returned angles still reproduce R.
    """
    R = R.to(torch.float64)
    cos_pitch = torch.sqrt(R[..., 0, 0] ** 2 + R[..., 1, 0] ** 2)
    pitch = torch.atan2(-R[..., 2, 0], cos_pitch)

    # Threshold balances two errors. Below it, setting yaw = 0 costs O(cos_pitch).
    # Above it, the generic formulas stay accurate for float64 input.
    locked = cos_pitch < 1e-8

    yaw = torch.where(
        locked,
        torch.zeros_like(cos_pitch),
        torch.atan2(R[..., 1, 0], R[..., 0, 0]),
    )
    roll = torch.where(
        locked,
        # With yaw = 0: R = Ry(pitch) @ Rx(roll), whose row 1 is
        # [0, cos(roll), -sin(roll)] for any pitch.
        torch.atan2(-R[..., 1, 2], R[..., 1, 1]),
        torch.atan2(R[..., 2, 1], R[..., 2, 2]),
    )
    return torch.stack([roll, pitch, yaw], dim=-1)
|
|
|
|
| |
| |
| |
|
|
|
|
def rvec_to_mat(rvec: torch.Tensor) -> torch.Tensor:
    """Axis-angle vector → rotation matrix. (..., 3) → (..., 3, 3).

    With theta = |rvec|, builds the unit quaternion
    [sin(theta/2) * rvec / theta, cos(theta/2)] and hands it to
    `quat_xyzw_to_mat`, so both conversions share one code path. For
    theta < 1e-6 the ratio sin(theta/2)/theta is replaced by its Taylor series
    1/2 - theta^2/48, which avoids 0/0 at the identity.
    """
    vec = rvec.to(torch.float64)
    theta = vec.norm(dim=-1, keepdim=True)
    half_theta = 0.5 * theta

    near_zero = theta < 1e-6
    taylor = 0.5 - theta * theta / 48.0
    exact = torch.sin(half_theta) / theta.clamp(min=1e-30)
    scale = torch.where(near_zero, taylor, exact)

    quat = torch.cat([scale * vec, torch.cos(half_theta)], dim=-1)
    return quat_xyzw_to_mat(quat)
|
|
|
|
def mat_to_rvec(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → axis-angle vector. (..., 3, 3) → (..., 3).

    Goes through the quaternion [x, y, z, w] from `mat_to_quat_xyzw`. Because q
    and -q encode the same rotation, the quaternion is first flipped so that
    w >= 0. The angle 2 * atan2(|xyz|, w) then lies in [0, π], which matches
    scipy's `as_rotvec()` canonical form. Near the identity (|xyz| < 1e-8),
    angle / |xyz| is approximated by 2 / w to avoid 0/0.
    """
    quat = mat_to_quat_xyzw(R)
    flip = quat[..., 3:4] < 0
    quat = torch.where(flip, -quat, quat)

    vec_part = quat[..., :3]
    scalar_part = quat[..., 3]
    vec_len = vec_part.norm(dim=-1)
    theta = 2.0 * torch.atan2(vec_len, scalar_part)

    tiny = vec_len < 1e-8
    scale = torch.where(
        tiny,
        2.0 / scalar_part.clamp(min=1e-30),
        theta / vec_len.clamp(min=1e-30),
    )
    return scale[..., None] * vec_part
|
|
|
|
| |
| |
| |
|
|
|
|
def rotmat_to_rot6d(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → 6D rotation. (..., 3, 3) → (..., 6).

    The 6D vector is column 0 followed by column 1 of R (Zhou et al.). The
    input dtype is preserved.
    """
    first_two_cols = R[..., :, :2]                 # (..., 3, 2)
    return first_two_cols.transpose(-1, -2).flatten(-2)
|
|
|
|
def rot6d_to_rotmat(r6: torch.Tensor) -> torch.Tensor:
    """6D rotation → rotation matrix via Gram-Schmidt. (..., 6) → (..., 3, 3).

    The first 3 values give the direction of column 0. The next 3 values, with
    their component along column 0 removed, give column 1. Column 2 is their
    cross product. Norms are floored at 1e-8. Output is float64.
    """
    v = r6.to(torch.float64)
    raw_x, raw_y = v[..., :3], v[..., 3:]

    col_x = raw_x / raw_x.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    along_x = (col_x * raw_y).sum(dim=-1, keepdim=True)
    col_y = raw_y - along_x * col_x
    col_y = col_y / col_y.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    col_z = torch.linalg.cross(col_x, col_y)

    return torch.stack((col_x, col_y, col_z), dim=-1)
|
|
|
|
| |
| |
| |
|
|
|
|
def quat_to_rot6d(q: torch.Tensor) -> torch.Tensor:
    """Quaternion [x, y, z, w] → 6D rotation (first two matrix columns). (..., 4) → (..., 6)."""
    rotation = quat_xyzw_to_mat(q)
    return rotmat_to_rot6d(rotation)
|
|
|
|
def rot6d_to_quat(r6: torch.Tensor) -> torch.Tensor:
    """6D rotation → quaternion [x, y, z, w], via an orthonormalized matrix. (..., 6) → (..., 4)."""
    rotation = rot6d_to_rotmat(r6)
    return mat_to_quat_xyzw(rotation)
|
|
|
|
| |
| |
| |
|
|
|
|
def rot_to_mat(rot: torch.Tensor, fmt: str) -> torch.Tensor:
    """Rotation in format `fmt` → rotation matrix. (..., D) → (..., 3, 3).

    Supported formats: "rot6d", "xyzw", "wxyz", "rpy", "rvec".

    Raises:
        ValueError: if `fmt` is not a supported format.
    """
    if fmt == "xyzw" or fmt == "wxyz":
        # Scalar-first input is reordered [w, x, y, z] → [x, y, z, w] first.
        quat = rot if fmt == "xyzw" else rot[..., [1, 2, 3, 0]]
        return quat_xyzw_to_mat(quat)
    if fmt == "rot6d":
        return rot6d_to_rotmat(rot)
    if fmt == "rpy":
        return rpy_to_mat(rot)
    if fmt == "rvec":
        return rvec_to_mat(rot)
    raise ValueError(f"Unknown rotation format: {fmt!r}")
|
|
|
|
def mat_to_rot(R: torch.Tensor, fmt: str) -> torch.Tensor:
    """Rotation matrix → rotation in format `fmt`. (..., 3, 3) → (..., D).

    Supported formats: "rot6d", "xyzw", "wxyz", "rpy", "rvec".

    Raises:
        ValueError: if `fmt` is not a supported format.
    """
    if fmt == "xyzw" or fmt == "wxyz":
        quat = mat_to_quat_xyzw(R)
        # Reorder [x, y, z, w] → [w, x, y, z] when scalar-first output is requested.
        return quat if fmt == "xyzw" else quat[..., [3, 0, 1, 2]]
    if fmt == "rot6d":
        return rotmat_to_rot6d(R)
    if fmt == "rpy":
        return mat_to_rpy(R)
    if fmt == "rvec":
        return mat_to_rvec(R)
    raise ValueError(f"Unknown rotation format: {fmt!r}")
|
|
|
|
def convert_rot(rot: torch.Tensor, from_fmt: str, to_fmt: str) -> torch.Tensor:
    """Convert a rotation between two formats. (..., D_in) → (..., D_out).

    Conversions go through a rotation matrix. When both formats are the same,
    the input tensor itself is returned untouched (no copy, no dtype change,
    no format validation).
    """
    if from_fmt != to_fmt:
        as_matrix = rot_to_mat(rot, from_fmt)
        return mat_to_rot(as_matrix, to_fmt)
    return rot
|
|
|
|
| |
| |
| |
|
|
|
|
def eef16_to_eef20(states: torch.Tensor) -> torch.Tensor:
    """Convert bimanual endpose from quaternion to 6D rotation. (..., 16) → (..., 20).

    Input per arm:  [xyz(3), quat xyzw(4), grip(1)].
    Output per arm: [xyz(3), rot6d(6), grip(1)].
    The result has the same dtype and device as `states`.
    """
    lead_shape = states.shape[:-1]
    rows = states.reshape(-1, 16)
    pieces = []
    for base in (0, 8):  # start of left arm, start of right arm
        pieces.append(rows[:, base:base + 3])
        pieces.append(quat_to_rot6d(rows[:, base + 3:base + 7]).to(states.dtype))
        pieces.append(rows[:, base + 7:base + 8])
    return torch.cat(pieces, dim=-1).reshape(lead_shape + (20,))
|
|
|
|
| |
| |
| |
| |
|
|
# Width of one arm's EEF slot: xyz(3) + rot6d(6) + gripper(1).
_ARM_DIM = 3 + 6 + 1
|
|
|
|
def eef_relative(actions: torch.Tensor, init_state: torch.Tensor, num_arms: int = 2) -> torch.Tensor:
    """Express absolute EEF actions relative to an initial EEF state.

    Each arm occupies a 10-dim slot [xyz(3), rot6d(6), grip(1)]. Per arm:
        p_rel = p_target - p_init        (translation delta, world frame)
        R_rel = R_init^T @ R_target      (rotation delta in the init EEF frame,
                                          i.e. R_target = R_init @ R_rel)
        gripper stays absolute

    Args:
        actions: (..., D) target absolute EEF states with D >= num_arms*10.
            Typical shapes are (N, num_arms*10) or (num_arms*10,).
        init_state: (..., D) initial absolute EEF state (condition frame). Its
            leading dims must broadcast to those of `actions`. A plain (D,)
            vector applies to every action; a per-sample state for (B, T, D)
            actions can be passed as (B, 1, D).
        num_arms: 1 (single-arm, 10-dim) or 2 (bimanual, 20-dim).

    Returns:
        Tensor with the same shape and dtype as `actions`. Gripper columns and
        any columns beyond num_arms*10 are copied from `actions` unchanged.

    Raises:
        ValueError: if `actions` or `init_state` has fewer than num_arms*10
            trailing features.
    """
    needed = num_arms * _ARM_DIM
    if actions.shape[-1] < needed or init_state.shape[-1] < needed:
        raise ValueError(
            f"eef_relative: expected at least {needed} features for num_arms={num_arms}, "
            f"got actions {tuple(actions.shape)} and init_state {tuple(init_state.shape)}"
        )

    # Start from a copy rather than torch.empty_like. Gripper and any trailing
    # columns are then well defined instead of uninitialized memory.
    result = actions.clone()
    for arm in range(num_arms):
        off = arm * _ARM_DIM
        pos = slice(off, off + 3)
        rot = slice(off + 3, off + 9)

        result[..., pos] = actions[..., pos] - init_state[..., pos]

        R_init = rot6d_to_rotmat(init_state[..., rot])
        R_target = rot6d_to_rotmat(actions[..., rot])
        # R_init^T @ R_target; matmul broadcasts over the leading dims.
        R_rel = R_init.transpose(-1, -2) @ R_target
        result[..., rot] = rotmat_to_rot6d(R_rel).to(actions.dtype)
    return result
|
|
|
|
def eef_apply_relative(current_state: torch.Tensor, relative_actions: torch.Tensor, num_arms: int = 2) -> torch.Tensor:
    """Convert relative EEF actions back to absolute ones (inverse of eef_relative).

    Each arm occupies a 10-dim slot [xyz(3), rot6d(6), grip(1)]. Per arm:
        p_abs = p_rel + p_cur            (translation delta is world-frame)
        R_abs = R_cur @ R_rel            (rotation delta is in the current EEF frame)
        gripper stays absolute

    Args:
        current_state: (..., D) current absolute EEF state in rot6d. Its leading
            dims must broadcast to those of `relative_actions`. A plain (D,)
            vector applies to every action; a per-sample state for (B, T, D)
            actions can be passed as (B, 1, D).
        relative_actions: (..., D) relative EEF actions with D >= num_arms*10.
            Typical shapes are (N, num_arms*10) or (num_arms*10,).
        num_arms: 1 (single-arm, 10-dim) or 2 (bimanual, 20-dim).

    Returns:
        Tensor with the same shape and dtype as `relative_actions`. Gripper
        columns and any columns beyond num_arms*10 are copied unchanged.

    Raises:
        ValueError: if `relative_actions` or `current_state` has fewer than
            num_arms*10 trailing features.
    """
    needed = num_arms * _ARM_DIM
    if relative_actions.shape[-1] < needed or current_state.shape[-1] < needed:
        raise ValueError(
            f"eef_apply_relative: expected at least {needed} features for num_arms={num_arms}, "
            f"got relative_actions {tuple(relative_actions.shape)} and "
            f"current_state {tuple(current_state.shape)}"
        )

    # Start from a copy rather than torch.empty_like. Gripper and any trailing
    # columns are then well defined instead of uninitialized memory.
    result = relative_actions.clone()
    for arm in range(num_arms):
        off = arm * _ARM_DIM
        pos = slice(off, off + 3)
        rot = slice(off + 3, off + 9)

        result[..., pos] = relative_actions[..., pos] + current_state[..., pos]

        R_cur = rot6d_to_rotmat(current_state[..., rot])
        R_rel = rot6d_to_rotmat(relative_actions[..., rot])
        # R_cur @ R_rel; matmul broadcasts over the leading dims.
        R_abs = R_cur @ R_rel
        result[..., rot] = rotmat_to_rot6d(R_abs).to(relative_actions.dtype)
    return result
|
|