# Source: HF Space by HongzeFu (code-only, no binary assets), revision 06c11b0.
"""
Continuous-RPY utilities (quaternion normalization, sign alignment, RPY conversion and unwrapping), shared by the wrapper and the public scripts.
"""
from __future__ import annotations
from typing import Any
import numpy as np
import torch
def normalize_quat_wxyz_torch(quat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Normalize wxyz quaternion(s) along the last dimension.

    Entries that cannot be normalized -- norm <= eps, or any NaN/Inf
    component -- are replaced by the identity quaternion [1, 0, 0, 0].

    Args:
        quat: tensor-like whose last dimension holds (w, x, y, z). Non-floating
            inputs (integer/bool lists or tensors) are promoted to torch's
            default floating dtype, since torch.linalg.norm rejects them.
        eps: norms at or below this value are treated as invalid.

    Returns:
        Tensor with the same shape as ``quat``; each entry is unit-norm or
        the identity fallback.
    """
    quat = torch.as_tensor(quat)
    if not (quat.is_floating_point() or quat.is_complex()):
        quat = quat.to(torch.get_default_dtype())
    quat_norm = torch.linalg.norm(quat, dim=-1, keepdim=True)
    valid = (
        torch.isfinite(quat).all(dim=-1, keepdim=True)
        & torch.isfinite(quat_norm)
        & (quat_norm > eps)
    )
    # Divide by 1 on invalid rows so no 0/0 or inf/inf is produced; those rows
    # are replaced by the fallback below anyway.
    safe_norm = torch.where(valid, quat_norm, torch.ones_like(quat_norm))
    normalized = quat / safe_norm
    fallback = torch.zeros_like(normalized)
    fallback[..., 0] = 1.0
    return torch.where(valid.expand_as(normalized), normalized, fallback)
def align_quat_sign_with_prev_torch(quat: torch.Tensor, prev_quat: torch.Tensor | None) -> torch.Tensor:
"""
Align sign with previous frame's quaternion representation.
If dot(quat, prev_quat) < 0, flip current quaternion sign.
"""
if prev_quat is None:
return quat
if prev_quat.shape != quat.shape:
return quat
prev = prev_quat.to(device=quat.device, dtype=quat.dtype)
dot = torch.sum(quat * prev, dim=-1, keepdim=True)
sign = torch.where(dot < 0, -torch.ones_like(dot), torch.ones_like(dot))
return quat * sign
from scipy.spatial.transform import Rotation
def quat_wxyz_to_rpy_xyz_torch(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert wxyz quaternion(s) to XYZ-order RPY angles (radians).

    Uses scipy.spatial.transform.Rotation with the lowercase "xyz" sequence
    (extrinsic rotations in scipy's convention).

    Args:
        quat: tensor of shape (..., 4) in w, x, y, z order. Any number of
            leading batch dimensions is accepted.

    Returns:
        Tensor of shape (..., 3) holding [roll, pitch, yaw], on the input's
        device and with the input's dtype. Entries that cannot be converted
        (zero norm or any NaN/Inf component) yield [0, 0, 0]; other entries
        in the same batch are converted normally.

    Raises:
        ValueError: if the last dimension of ``quat`` is not 4.

    Note: the conversion is detached from autograd and runs on CPU via NumPy.
    """
    if quat.ndim == 0 or quat.shape[-1] != 4:
        raise ValueError(f"quat must have shape (..., 4), got {tuple(quat.shape)}")
    device, dtype = quat.device, quat.dtype
    batch_shape = tuple(quat.shape[:-1])
    # Move to CPU before upcasting (some backends, e.g. MPS, lack float64).
    # float64 also covers dtypes NumPy cannot hold (bfloat16); scipy works in
    # float64 internally, so float32 inputs give the same results as before.
    flat = quat.detach().cpu().to(torch.float64).numpy().reshape(-1, 4)

    rpy_flat = np.zeros((flat.shape[0], 3), dtype=np.float64)
    # scipy raises on any zero-norm quaternion, which would discard the whole
    # batch; mask bad rows out so only they fall back to zeros.
    valid = np.isfinite(flat).all(axis=1) & (np.linalg.norm(flat, axis=1) > 0)
    if valid.any():
        xyzw = flat[valid][:, [1, 2, 3, 0]]  # wxyz -> xyzw (scipy is scalar-last)
        rpy_flat[valid] = Rotation.from_quat(xyzw).as_euler("xyz", degrees=False)
    return torch.from_numpy(rpy_flat.reshape(batch_shape + (3,))).to(device=device, dtype=dtype)
def rpy_xyz_to_quat_wxyz_torch(rpy: torch.Tensor) -> torch.Tensor:
    """
    Convert XYZ-order RPY angles (radians) to wxyz quaternion(s).

    Inverse of quat_wxyz_to_rpy_xyz_torch; uses scipy.spatial.transform.Rotation
    with the same lowercase "xyz" sequence.

    Args:
        rpy: tensor of shape (..., 3) holding [roll, pitch, yaw]. Any number of
            leading batch dimensions is accepted.

    Returns:
        Tensor of shape (..., 4) in w, x, y, z order (unit-norm, as produced
        by scipy), on the input's device and with the input's dtype.

    Raises:
        ValueError: if the last dimension of ``rpy`` is not 3.

    Note: the conversion is detached from autograd and runs on CPU via NumPy.
    """
    if rpy.ndim == 0 or rpy.shape[-1] != 3:
        raise ValueError(f"rpy must have shape (..., 3), got {tuple(rpy.shape)}")
    device, dtype = rpy.device, rpy.dtype
    batch_shape = tuple(rpy.shape[:-1])
    # CPU first, then float64: avoids backends without float64 and NumPy's
    # lack of bfloat16; scipy computes in float64 anyway.
    flat = rpy.detach().cpu().to(torch.float64).numpy().reshape(-1, 3)
    if flat.shape[0] == 0:
        quat_flat = np.zeros((0, 4), dtype=np.float64)
    else:
        quat_xyzw = Rotation.from_euler("xyz", flat, degrees=False).as_quat()
        quat_flat = quat_xyzw[:, [3, 0, 1, 2]]  # xyzw -> wxyz
    return torch.from_numpy(quat_flat.reshape(batch_shape + (4,))).to(device=device, dtype=dtype)
def unwrap_rpy_with_prev_torch(rpy: torch.Tensor, prev_rpy: torch.Tensor | None) -> torch.Tensor:
"""
Unwrap RPY relative to previous frame: fold difference into (-pi, pi] then accumulate.
"""
if prev_rpy is None:
return rpy
if prev_rpy.shape != rpy.shape:
return rpy
prev = prev_rpy.to(device=rpy.device, dtype=rpy.dtype)
pi = torch.as_tensor(np.pi, dtype=rpy.dtype, device=rpy.device)
two_pi = torch.as_tensor(2.0 * np.pi, dtype=rpy.dtype, device=rpy.device)
delta = rpy - prev
delta = torch.remainder(delta + pi, two_pi) - pi
return prev + delta
def build_endeffector_pose_dict(
    position: torch.Tensor,
    quat_wxyz: torch.Tensor,
    prev_ee_quat_wxyz: torch.Tensor | None,
    prev_ee_rpy_xyz: torch.Tensor | None,
    eps: float = 1e-12,
) -> tuple[dict, torch.Tensor, torch.Tensor]:
    """
    Build a continuous end-effector pose from the current frame plus a cache.

    Processing order:
      1. normalize ``quat_wxyz`` (``eps`` is forwarded to the normalizer);
      2. flip its sign to agree with ``prev_ee_quat_wxyz``;
      3. convert the aligned quaternion to XYZ-order RPY;
      4. unwrap that RPY against ``prev_ee_rpy_xyz``.
    A None cache entry makes the step that uses it a pass-through.

    Args:
        position: xyz position, stored in the result as-is.
        quat_wxyz: current-frame orientation as a wxyz quaternion.
        prev_ee_quat_wxyz: cached quaternion from the previous frame, or None.
        prev_ee_rpy_xyz: cached RPY from the previous frame, or None.
        eps: norm threshold used when normalizing the quaternion.

    Returns:
        ``(pose_dict, new_prev_quat, new_prev_rpy)``, where ``pose_dict`` is
        ``{"pose": position, "quat": aligned wxyz quaternion, "rpy": unwrapped
        RPY}`` and the other two are detached clones of its "quat" and "rpy"
        entries, to be fed back in as the cache for the next frame.
    """
    unit_quat = normalize_quat_wxyz_torch(quat_wxyz, eps=eps)
    quat_out = align_quat_sign_with_prev_torch(unit_quat, prev_ee_quat_wxyz)
    rpy_out = unwrap_rpy_with_prev_torch(
        quat_wxyz_to_rpy_xyz_torch(quat_out), prev_ee_rpy_xyz
    )

    pose_dict = dict(pose=position, quat=quat_out, rpy=rpy_out)
    # Detached clones so the cache neither holds the autograd graph nor
    # aliases the tensors handed back to the caller.
    cache = (quat_out.detach().clone(), rpy_out.detach().clone())
    return (pose_dict, *cache)
def summarize_and_print_rpy_sequence(rpy_sequence: Any, label: str = "") -> dict[str, Any]:
    """
    Summarize an RPY sequence and log a report containing only count and delta.

    The report is emitted at DEBUG level.

    Args:
        rpy_sequence: array-like of RPY triples in radians. Accepted layouts:
            any shape (..., 3) with ndim >= 2 (flattened to (N, 3)), or a flat
            1-D sequence whose length is a multiple of 3.
        label: optional prefix for the log lines.

    Returns:
        dict with keys
            "count": number of RPY samples N;
            "axis_max_abs_delta_rad" / "axis_max_abs_delta_deg": per-axis
                (roll, pitch, yaw) max of |rpy[i+1] - rpy[i]|, zeros if N < 2;
            "axis_max_abs_delta_transition": per-axis [i, i+1] pair where that
                max first occurs, or None entries if N < 2.

    Raises:
        ValueError: if the input cannot be interpreted as (N, 3), including
            0-d scalars.
    """
    import logging

    # Resolve the logger locally so this function does not depend on a
    # module-level ``logger`` global being defined.
    logger = logging.getLogger(__name__)

    prefix = f"{label} " if label else ""
    rpy = np.asarray(rpy_sequence, dtype=np.float64)
    if rpy.size == 0:
        logger.debug(f"{prefix}RPY summary: no RPY samples.")
        return {
            "count": 0,
            "axis_max_abs_delta_rad": [0.0, 0.0, 0.0],
            "axis_max_abs_delta_deg": [0.0, 0.0, 0.0],
            "axis_max_abs_delta_transition": [None, None, None],
        }

    if rpy.ndim == 1:
        if rpy.shape[0] % 3 != 0:
            raise ValueError(f"Cannot reshape 1D rpy_sequence of shape {rpy.shape} to (*, 3)")
        rpy = rpy.reshape(-1, 3)
    elif rpy.ndim >= 2 and rpy.shape[-1] == 3:
        rpy = rpy.reshape(-1, 3)
    else:
        # Also catches 0-d input, whose empty shape has no last dimension.
        raise ValueError(f"rpy_sequence last dimension must be 3, got shape {rpy.shape}")

    count = int(rpy.shape[0])
    if count < 2:
        axis_max_abs_delta_rad = np.zeros(3, dtype=np.float64)
        axis_max_abs_delta_transition: list[Any] = [None, None, None]
    else:
        abs_diff = np.abs(np.diff(rpy, axis=0))
        axis_max_abs_delta_rad = abs_diff.max(axis=0)
        # argmax returns the first index of the maximum on each axis.
        axis_max_abs_delta_transition = [[int(i), int(i) + 1] for i in abs_diff.argmax(axis=0)]
    axis_max_abs_delta_deg = np.rad2deg(axis_max_abs_delta_rad)

    summary = {
        "count": count,
        "axis_max_abs_delta_rad": axis_max_abs_delta_rad.tolist(),
        "axis_max_abs_delta_deg": axis_max_abs_delta_deg.tolist(),
        "axis_max_abs_delta_transition": axis_max_abs_delta_transition,
    }
    logger.debug(f"{prefix}RPY summary (rad):")
    logger.debug(f" count={count}")
    logger.debug(
        " axis_max_abs_delta_rad (roll,pitch,yaw)="
        f"[{axis_max_abs_delta_rad[0]:.6f}, {axis_max_abs_delta_rad[1]:.6f}, {axis_max_abs_delta_rad[2]:.6f}]"
    )
    logger.debug(f" transitions={axis_max_abs_delta_transition}")
    logger.debug(f"{prefix}RPY summary (deg):")
    logger.debug(
        " axis_max_abs_delta_deg (roll,pitch,yaw)="
        f"[{axis_max_abs_delta_deg[0]:.6f}, {axis_max_abs_delta_deg[1]:.6f}, {axis_max_abs_delta_deg[2]:.6f}]"
    )
    return summary