# rc_inference/utils/geom.py
# rxcui's picture
# Upload folder using huggingface_hub
# cb65626 verified
"""Geometry utilities for SE(3) pose operations used in embodied robot learning.
Pure torch implementation — all functions accept and return torch tensors,
support arbitrary batch dimensions, and work on both CPU and CUDA.
Bimanual EEF state/action format (20-dim):
[left_xyz(3), left_rot6d(6), left_grip(1), right_xyz(3), right_rot6d(6), right_grip(1)]
Rotation formats:
rpy — 3-dim, extrinsic xyz euler angles (roll-pitch-yaw)
xyzw — 4-dim, quaternion [x, y, z, w] (scipy/ROS convention)
wxyz — 4-dim, quaternion [w, x, y, z] (transforms3d/RoboTwin convention)
rot6d — 6-dim, first two columns of rotation matrix (Zhou et al., CVPR 2019)
rvec — 3-dim, axis-angle (Rodrigues vector), scipy `Rotation.as_rotvec()` convention.
rvec = angle * axis, angle in radians, canonical angle ∈ [0, π].
"""
import torch
# Length of the last tensor axis for each supported rotation format name.
ROT_DIM = {"rpy": 3, "xyzw": 4, "wxyz": 4, "rot6d": 6, "rvec": 3}
# ---------------------------------------------------------------------------
# Low-level rotation utilities
# ---------------------------------------------------------------------------
def quat_xyzw_to_mat(q: torch.Tensor) -> torch.Tensor:
    """Quaternion [x, y, z, w] → rotation matrix. (..., 4) → (..., 3, 3).

    The quaternion is normalized before conversion, so inputs that are only
    approximately unit length (or deliberately unnormalized) still produce a
    proper orthonormal rotation matrix. An all-zero quaternion maps to the
    identity matrix.

    Args:
        q: (..., 4) quaternion(s) in [x, y, z, w] order, any floating dtype.

    Returns:
        (..., 3, 3) float64 rotation matrices.
    """
    q = q.to(torch.float64)
    # Without this step a quaternion of norm s would yield a matrix whose
    # off-identity terms are scaled by s**2, i.e. not a rotation.
    # The clamp keeps the zero quaternion finite (it stays zero → identity).
    q = q / q.norm(dim=-1, keepdim=True).clamp(min=1e-12)
    x, y, z, w = q.unbind(-1)
    tx, ty, tz = 2 * x, 2 * y, 2 * z
    twx, twy, twz = w * tx, w * ty, w * tz
    txx, txy, txz = x * tx, x * ty, x * tz
    tyy, tyz, tzz = y * ty, y * tz, z * tz
    # Entries are stacked row-major and then folded into (..., 3, 3).
    R = torch.stack([
        1 - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, 1 - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, 1 - (txx + tyy),
    ], dim=-1).reshape(q.shape[:-1] + (3, 3))
    return R
def mat_to_quat_xyzw(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → quaternion [x, y, z, w]. (..., 3, 3) → (..., 4).

    Four candidate solutions are evaluated, one per choice of pivot
    component (w, x, y or z), and the one matching the branch condition is
    selected per matrix. The result is float64 and renormalized to unit
    length; the sign of the quaternion is not canonicalized.
    """
    R = R.to(torch.float64)
    lead_shape = R.shape[:-2]
    M = R.reshape(-1, 3, 3)

    r00, r01, r02 = M[:, 0, 0], M[:, 0, 1], M[:, 0, 2]
    r10, r11, r12 = M[:, 1, 0], M[:, 1, 1], M[:, 1, 2]
    r20, r21, r22 = M[:, 2, 0], M[:, 2, 1], M[:, 2, 2]
    trace = r00 + r11 + r22

    def _four_times_pivot(radicand: torch.Tensor) -> torch.Tensor:
        # 2 * sqrt(.) equals four times the pivot component; the clamp keeps
        # the divisor strictly positive for candidates that are not selected.
        return torch.sqrt(torch.clamp(radicand, min=1e-10)) * 2

    sw = _four_times_pivot(trace + 1)
    sx = _four_times_pivot(1 + r00 - r11 - r22)
    sy = _four_times_pivot(1 + r11 - r00 - r22)
    sz = _four_times_pivot(1 + r22 - r00 - r11)

    # Candidate quaternions, each laid out as [x, y, z, w].
    cand_w = torch.stack([(r21 - r12) / sw, (r02 - r20) / sw, (r10 - r01) / sw, sw / 4], dim=-1)
    cand_x = torch.stack([sx / 4, (r01 + r10) / sx, (r02 + r20) / sx, (r21 - r12) / sx], dim=-1)
    cand_y = torch.stack([(r01 + r10) / sy, sy / 4, (r12 + r21) / sy, (r02 - r20) / sy], dim=-1)
    cand_z = torch.stack([(r02 + r20) / sz, (r12 + r21) / sz, sz / 4, (r10 - r01) / sz], dim=-1)

    # Branch order: positive trace first, then the largest diagonal entry.
    use_w = trace > 0
    use_x = (~use_w) & (r00 > r11) & (r00 > r22)
    use_y = (~use_w) & (~use_x) & (r11 > r22)

    quat = torch.where(
        use_w.unsqueeze(-1),
        cand_w,
        torch.where(
            use_x.unsqueeze(-1),
            cand_x,
            torch.where(use_y.unsqueeze(-1), cand_y, cand_z),
        ),
    )

    quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-10)
    return quat.reshape(lead_shape + (4,))
def rpy_to_mat(rpy: torch.Tensor) -> torch.Tensor:
    """Roll-pitch-yaw (extrinsic xyz euler angles) → rotation matrix.

    (..., 3) → (..., 3, 3), computed in float64 as the closed form of
    R = Rz(yaw) @ Ry(pitch) @ Rx(roll).
    """
    angles = rpy.to(torch.float64)
    sin_r, sin_p, sin_y = torch.sin(angles).unbind(-1)
    cos_r, cos_p, cos_y = torch.cos(angles).unbind(-1)

    # Build the matrix one row at a time, then stack the rows on axis -2.
    top = torch.stack([
        cos_y * cos_p,
        cos_y * sin_p * sin_r - sin_y * cos_r,
        cos_y * sin_p * cos_r + sin_y * sin_r,
    ], dim=-1)
    middle = torch.stack([
        sin_y * cos_p,
        sin_y * sin_p * sin_r + cos_y * cos_r,
        sin_y * sin_p * cos_r - cos_y * sin_r,
    ], dim=-1)
    bottom = torch.stack([-sin_p, cos_p * sin_r, cos_p * cos_r], dim=-1)
    return torch.stack([top, middle, bottom], dim=-2)
def mat_to_rpy(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → RPY (extrinsic xyz euler angles). (..., 3, 3) → (..., 3).

    Inverse of R = Rz(yaw) @ Ry(pitch) @ Rx(roll), with pitch ∈ [-π/2, π/2].

    At gimbal lock (|pitch| = π/2, i.e. cos(pitch) ≈ 0) roll and yaw are not
    separately observable. In that case yaw is set to 0 and the combined
    rotation is reported as roll, so the returned angles still reproduce R.

    Args:
        R: (..., 3, 3) rotation matrices, any floating dtype.

    Returns:
        (..., 3) float64 tensor ordered [roll, pitch, yaw].
    """
    R = R.to(torch.float64)
    gimbal_eps = 1e-6  # below this cos(pitch) is treated as zero

    cos_pitch = torch.sqrt(R[..., 0, 0] ** 2 + R[..., 1, 0] ** 2)
    pitch = torch.atan2(-R[..., 2, 0], cos_pitch)

    # Regular case: yaw and roll come from the first column / last row.
    yaw_regular = torch.atan2(R[..., 1, 0], R[..., 0, 0])
    roll_regular = torch.atan2(R[..., 2, 1], R[..., 2, 2])

    # Gimbal lock: the entries used above are all ~0, so atan2 would return
    # 0 (or noise) and drop the rotation. With yaw fixed to 0 the middle row
    # is [0, cos(roll), -sin(roll)] for both pitch = +π/2 and pitch = -π/2.
    roll_locked = torch.atan2(-R[..., 1, 2], R[..., 1, 1])

    locked = cos_pitch < gimbal_eps
    roll = torch.where(locked, roll_locked, roll_regular)
    yaw = torch.where(locked, torch.zeros_like(yaw_regular), yaw_regular)
    return torch.stack([roll, pitch, yaw], dim=-1)
# ---------------------------------------------------------------------------
# Axis-angle (rvec / Rodrigues vector), scipy convention
# ---------------------------------------------------------------------------
def rvec_to_mat(rvec: torch.Tensor) -> torch.Tensor:
    """Axis-angle vector (rvec = angle * axis) → rotation matrix.

    (..., 3) → (..., 3, 3). The vector is first turned into an [x, y, z, w]
    quaternion and then handed to `quat_xyzw_to_mat`, so the result goes
    through the same quaternion → matrix code path as the rest of the module.
    """
    rvec = rvec.to(torch.float64)
    theta = rvec.norm(dim=-1, keepdim=True)  # rotation angle, shape (..., 1)
    half_theta = 0.5 * theta

    # The quaternion vector part is rvec * sin(theta/2) / theta. Near zero the
    # ratio is replaced by its Taylor series 1/2 - theta^2/48 to avoid 0/0.
    series = 0.5 - theta * theta / 48.0
    closed_form = torch.sin(half_theta) / theta.clamp(min=1e-30)
    scale = torch.where(theta < 1e-6, series, closed_form)

    quat = torch.cat([scale * rvec, torch.cos(half_theta)], dim=-1)
    return quat_xyzw_to_mat(quat)
def mat_to_rvec(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → axis-angle vector. (..., 3, 3) → (..., 3).

    Goes matrix → [x, y, z, w] quaternion → rvec. The quaternion is flipped
    onto the w ≥ 0 hemisphere first, which puts the resulting angle in
    [0, π] (the scipy `as_rotvec()` convention).
    """
    quat = mat_to_quat_xyzw(R)  # (..., 4); the helper casts to float64
    negative_w = quat[..., 3:4] < 0
    quat = torch.where(negative_w, -quat, quat)

    vec, w = quat[..., :3], quat[..., 3]
    vec_len = vec.norm(dim=-1)  # (...,)
    theta = 2.0 * torch.atan2(vec_len, w)

    # rvec = (theta / |vec|) * vec. When |vec| → 0 (and w → 1) the ratio tends
    # to 2 / w, which is used instead to sidestep the 0/0.
    near_identity = vec_len < 1e-8
    limit_ratio = 2.0 / w.clamp(min=1e-30)
    general_ratio = theta / vec_len.clamp(min=1e-30)
    ratio = torch.where(near_identity, limit_ratio, general_ratio)
    return ratio.unsqueeze(-1) * vec
# ---------------------------------------------------------------------------
# 6D rotation representation (Zhou et al., CVPR 2019)
# ---------------------------------------------------------------------------
def rotmat_to_rot6d(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrix → 6D rotation. (..., 3, 3) → (..., 6).

    The output is column 0 followed by column 1 of R; dtype is unchanged.
    """
    leading_two_cols = R[..., :2]  # (..., 3, 2)
    # Swap the last two axes so each column becomes a contiguous run of 3.
    return leading_two_cols.transpose(-1, -2).reshape(R.shape[:-2] + (6,))
def rot6d_to_rotmat(r6: torch.Tensor) -> torch.Tensor:
    """6D rotation → rotation matrix via Gram-Schmidt. (..., 6) → (..., 3, 3).

    The first three values are normalized into column 0, the remaining values
    are orthogonalized against it and normalized into column 1, and column 2
    is their cross product. The result is float64.
    """
    def _unit(v: torch.Tensor) -> torch.Tensor:
        # The clamp keeps the division finite for (near-)zero vectors.
        return v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    r6 = r6.to(torch.float64)
    first, second = r6[..., :3], r6[..., 3:]

    col0 = _unit(first)
    overlap = (col0 * second).sum(dim=-1, keepdim=True)
    col1 = _unit(second - overlap * col0)
    col2 = torch.linalg.cross(col0, col1)
    return torch.stack([col0, col1, col2], dim=-1)
# ---------------------------------------------------------------------------
# Convenience conversions
# ---------------------------------------------------------------------------
def quat_to_rot6d(q: torch.Tensor) -> torch.Tensor:
    """Quaternion [x, y, z, w] → 6D rotation, going through a rotation matrix."""
    rotation = quat_xyzw_to_mat(q)
    return rotmat_to_rot6d(rotation)
def rot6d_to_quat(r6: torch.Tensor) -> torch.Tensor:
    """6D rotation → quaternion [x, y, z, w], going through a rotation matrix."""
    rotation = rot6d_to_rotmat(r6)
    return mat_to_quat_xyzw(rotation)
# ---------------------------------------------------------------------------
# Generic rotation format dispatch
# ---------------------------------------------------------------------------
def rot_to_mat(rot: torch.Tensor, fmt: str) -> torch.Tensor:
    """Convert a rotation given in format `fmt` to a rotation matrix.

    (..., D) → (..., 3, 3). `fmt` is one of "rot6d", "xyzw", "wxyz", "rpy"
    or "rvec"; any other value raises ValueError.
    """
    if fmt == "rot6d":
        mat = rot6d_to_rotmat(rot)
    elif fmt == "xyzw":
        mat = quat_xyzw_to_mat(rot)
    elif fmt == "wxyz":
        # Reorder [w, x, y, z] into the [x, y, z, w] layout the helper expects.
        wxyz_to_xyzw = [1, 2, 3, 0]
        mat = quat_xyzw_to_mat(rot[..., wxyz_to_xyzw])
    elif fmt == "rpy":
        mat = rpy_to_mat(rot)
    elif fmt == "rvec":
        mat = rvec_to_mat(rot)
    else:
        raise ValueError(f"Unknown rotation format: {fmt!r}")
    return mat
def mat_to_rot(R: torch.Tensor, fmt: str) -> torch.Tensor:
    """Convert a rotation matrix to the rotation format `fmt`.

    (..., 3, 3) → (..., D). `fmt` is one of "rot6d", "xyzw", "wxyz", "rpy"
    or "rvec"; any other value raises ValueError.
    """
    if fmt == "rot6d":
        out = rotmat_to_rot6d(R)
    elif fmt == "xyzw":
        out = mat_to_quat_xyzw(R)
    elif fmt == "wxyz":
        # The helper yields [x, y, z, w]; move w to the front.
        xyzw_to_wxyz = [3, 0, 1, 2]
        out = mat_to_quat_xyzw(R)[..., xyzw_to_wxyz]
    elif fmt == "rpy":
        out = mat_to_rpy(R)
    elif fmt == "rvec":
        out = mat_to_rvec(R)
    else:
        raise ValueError(f"Unknown rotation format: {fmt!r}")
    return out
def convert_rot(rot: torch.Tensor, from_fmt: str, to_fmt: str) -> torch.Tensor:
    """Convert a rotation from `from_fmt` to `to_fmt`. (..., D_in) → (..., D_out).

    When both formats are equal the input tensor is returned as-is (same
    object, no conversion); otherwise the rotation is routed through a
    rotation matrix.
    """
    if from_fmt != to_fmt:
        as_matrix = rot_to_mat(rot, from_fmt)
        return mat_to_rot(as_matrix, to_fmt)
    return rot
# ---------------------------------------------------------------------------
# Endpose format conversion: 16-dim (quat) → 20-dim (rot6d)
# ---------------------------------------------------------------------------
def eef16_to_eef20(states: torch.Tensor) -> torch.Tensor:
    """Convert 16-dim endpose to the 20-dim rot6d endpose.

    Input layout per arm:  [xyz(3), quaternion xyzw(4), grip(1)], two arms.
    Output layout per arm: [xyz(3), rot6d(6), grip(1)], two arms.
    Leading batch dimensions are preserved, as are dtype and device.
    """
    lead_shape = states.shape[:-1]
    flat = states.reshape(-1, 16)

    pieces = []
    for start in (0, 8):  # column where each arm begins in the 16-dim layout
        arm = flat[:, start:start + 8]
        pieces.append(arm[:, 0:3])
        # The rotation helper works in float64; cast back to the input dtype.
        pieces.append(quat_to_rot6d(arm[:, 3:7]).to(states.dtype))
        pieces.append(arm[:, 7:8])

    return torch.cat(pieces, dim=-1).reshape(lead_shape + (20,))
# ---------------------------------------------------------------------------
# High-level bimanual EEF utilities (20-dim rot6d)
# World-frame relative: p_rel = p_target - p_init, R_rel = R_init^T @ R_target
# ---------------------------------------------------------------------------
# Width of one arm's slice in the EEF vector: xyz(3) + rot6d(6) + grip(1).
_ARM_DIM = 10
def eef_relative(actions: torch.Tensor, init_state: torch.Tensor, num_arms: int = 2) -> torch.Tensor:
    """Compute relative EEF actions with respect to an initial state.

    For each arm:
        p_rel = p_target - p_init
        R_rel = R_init^T @ R_target
        gripper stays absolute

    Any number of leading batch dimensions is accepted. Columns that do not
    belong to the first `num_arms` arms (e.g. the second arm when
    `num_arms=1` is used on a 20-dim tensor) are copied through unchanged.

    Args:
        actions: (..., num_arms*10) target absolute EEF states, rot6d layout
            [xyz(3), rot6d(6), grip(1)] per arm.
        init_state: (num_arms*10,) initial absolute EEF state (condition
            frame), or any shape whose leading dimensions broadcast against
            those of `actions`.
        num_arms: 1 (single-arm, 10-dim) or 2 (bimanual, 20-dim).

    Returns:
        Tensor with the same shape and dtype as `actions` holding the
        relative EEF actions.
    """
    # Start from a copy so gripper values and any untouched columns are
    # well-defined instead of being left as uninitialized memory.
    result = actions.clone()
    for arm in range(num_arms):
        base = arm * _ARM_DIM
        pos = slice(base, base + 3)
        rot = slice(base + 3, base + 9)

        result[..., pos] = actions[..., pos] - init_state[..., pos]

        R_init = rot6d_to_rotmat(init_state[..., rot])
        R_target = rot6d_to_rotmat(actions[..., rot])
        # matmul broadcasts R_init over the batch dimensions of R_target.
        R_rel = torch.matmul(R_init.transpose(-1, -2), R_target)
        # Rotation helpers compute in float64; cast back to the input dtype.
        result[..., rot] = rotmat_to_rot6d(R_rel).to(actions.dtype)
    return result
def eef_apply_relative(current_state: torch.Tensor, relative_actions: torch.Tensor, num_arms: int = 2) -> torch.Tensor:
    """Convert relative EEF actions back to absolute (inverse of eef_relative).

    For each arm:
        p_abs = p_rel + p_cur
        R_abs = R_cur @ R_rel
        gripper is passed through unchanged

    Any number of leading batch dimensions is accepted. Columns that do not
    belong to the first `num_arms` arms are copied through unchanged.

    Args:
        current_state: (num_arms*10,) current absolute EEF state in rot6d
            layout, or any shape whose leading dimensions broadcast against
            those of `relative_actions`.
        relative_actions: (..., num_arms*10) relative EEF actions.
        num_arms: 1 (single-arm, 10-dim) or 2 (bimanual, 20-dim).

    Returns:
        Tensor with the same shape and dtype as `relative_actions` holding
        the absolute EEF actions.
    """
    # Start from a copy so gripper values and any untouched columns are
    # well-defined instead of being left as uninitialized memory.
    result = relative_actions.clone()
    for arm in range(num_arms):
        base = arm * _ARM_DIM
        pos = slice(base, base + 3)
        rot = slice(base + 3, base + 9)

        result[..., pos] = relative_actions[..., pos] + current_state[..., pos]

        R_cur = rot6d_to_rotmat(current_state[..., rot])
        R_rel = rot6d_to_rotmat(relative_actions[..., rot])
        # matmul broadcasts R_cur over the batch dimensions of R_rel.
        R_abs = torch.matmul(R_cur, R_rel)
        # Rotation helpers compute in float64; cast back to the input dtype.
        result[..., rot] = rotmat_to_rot6d(R_abs).to(relative_actions.dtype)
    return result