File size: 8,990 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
"""
Continuous roll/pitch/yaw (RPY) utilities shared by the wrapper and the public scripts.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import torch


def normalize_quat_wxyz_torch(quat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Return the unit-length version of a wxyz quaternion (batched over leading dims).

    Any quaternion that cannot be normalized (norm <= eps, or a NaN/Inf in
    its components or its norm) is replaced by the identity [1, 0, 0, 0].
    """
    q = torch.as_tensor(quat)
    norm = torch.linalg.norm(q, dim=-1, keepdim=True)

    is_valid = (
        q.isfinite().all(dim=-1, keepdim=True)
        & norm.isfinite()
        & (norm > eps)
    )

    # Divide invalid rows by 1 so no inf/NaN is produced before the select below.
    denom = torch.where(is_valid, norm, torch.ones_like(norm))
    unit = q / denom

    identity = torch.zeros_like(unit)
    identity[..., 0] = 1.0

    return torch.where(is_valid.expand_as(unit), unit, identity)


def align_quat_sign_with_prev_torch(quat: torch.Tensor, prev_quat: torch.Tensor | None) -> torch.Tensor:
    """
    Flip ``quat`` onto the same hemisphere as the previous frame's quaternion.

    q and -q encode the same rotation. Each quaternion whose dot product with
    ``prev_quat`` is negative is negated, so consecutive frames stay close.
    If there is no previous quaternion, or its shape differs, ``quat`` is
    returned unchanged (same object).
    """
    if prev_quat is None or prev_quat.shape != quat.shape:
        return quat

    reference = prev_quat.to(device=quat.device, dtype=quat.dtype)
    similarity = (quat * reference).sum(dim=-1, keepdim=True)
    unit = torch.ones_like(similarity)
    multiplier = torch.where(similarity < 0, -unit, unit)
    return quat * multiplier


from scipy.spatial.transform import Rotation


def quat_wxyz_to_rpy_xyz_torch(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert wxyz quaternion(s) to XYZ-order RPY angles in radians.

    Uses ``scipy.spatial.transform.Rotation`` (``from_quat`` + ``as_euler("xyz")``).
    The result is therefore detached from the autograd graph, and the data
    round-trips through CPU float64 memory.

    Args:
        quat: tensor whose last dimension is 4, laid out (w, x, y, z). Any
            number of leading batch dimensions (including none) is accepted.

    Returns:
        Tensor of shape ``quat.shape[:-1] + (3,)`` holding (roll, pitch, yaw),
        on the input's device and cast to its dtype. Quaternions that are
        all-zero or contain NaN/Inf map to (0, 0, 0), the angles of the
        identity rotation. Other rows in the batch are unaffected.

    Raises:
        ValueError: if the last dimension of ``quat`` is not 4.
    """
    if quat.shape[-1] != 4:
        raise ValueError(f"quat last dimension must be 4 (wxyz), got shape {tuple(quat.shape)}")

    device = quat.device
    dtype = quat.dtype
    batch_shape = tuple(quat.shape[:-1])

    # scipy computes in float64 anyway. Casting first also makes dtypes
    # without a numpy equivalent (e.g. bfloat16) convertible.
    quat_np = quat.detach().to(dtype=torch.float64).cpu().numpy().reshape(-1, 4)

    # wxyz -> xyzw (scipy's scalar-last convention).
    quat_xyzw = quat_np[:, [1, 2, 3, 0]]

    # scipy raises on zero-norm quaternions. Handle invalid rows individually
    # so one bad sample does not wipe out the whole batch.
    norms = np.linalg.norm(quat_xyzw, axis=-1)
    valid = np.isfinite(quat_xyzw).all(axis=-1) & np.isfinite(norms) & (norms > 0.0)

    rpy_np = np.zeros((quat_xyzw.shape[0], 3), dtype=np.float64)
    if np.any(valid):
        rpy_np[valid] = Rotation.from_quat(quat_xyzw[valid]).as_euler("xyz", degrees=False)

    rpy_np = rpy_np.reshape(batch_shape + (3,))
    return torch.from_numpy(rpy_np).to(device=device, dtype=dtype)


def rpy_xyz_to_quat_wxyz_torch(rpy: torch.Tensor) -> torch.Tensor:
    """
    Convert XYZ-order RPY angles (radians) to wxyz quaternion(s).

    Inverse of ``quat_wxyz_to_rpy_xyz_torch``, implemented with
    ``scipy.spatial.transform.Rotation.from_euler("xyz")``. The result is
    detached from the autograd graph, and the data round-trips through CPU
    float64 memory.

    Args:
        rpy: tensor whose last dimension is 3 (roll, pitch, yaw). Any number
            of leading batch dimensions (including none) is accepted.

    Returns:
        Unit quaternions of shape ``rpy.shape[:-1] + (4,)`` laid out
        (w, x, y, z), on the input's device and cast to its dtype. The sign
        is whatever scipy's ``as_quat()`` returns; no hemisphere is enforced.

    Raises:
        ValueError: if the last dimension of ``rpy`` is not 3.
    """
    if rpy.shape[-1] != 3:
        raise ValueError(f"rpy last dimension must be 3, got shape {tuple(rpy.shape)}")

    device = rpy.device
    dtype = rpy.dtype
    batch_shape = tuple(rpy.shape[:-1])

    # Flatten to the (N, 3) layout scipy accepts. Casting to float64 first also
    # makes dtypes without a numpy equivalent (e.g. bfloat16) convertible.
    rpy_np = rpy.detach().to(dtype=torch.float64).cpu().numpy().reshape(-1, 3)

    if rpy_np.shape[0] == 0:
        quat_wxyz = np.zeros((0, 4), dtype=np.float64)
    else:
        quat_xyzw = Rotation.from_euler("xyz", rpy_np, degrees=False).as_quat()
        # xyzw (scipy, scalar-last) -> wxyz
        quat_wxyz = quat_xyzw[:, [3, 0, 1, 2]]

    quat_wxyz = np.ascontiguousarray(quat_wxyz.reshape(batch_shape + (4,)))
    return torch.from_numpy(quat_wxyz).to(device=device, dtype=dtype)


def unwrap_rpy_with_prev_torch(rpy: torch.Tensor, prev_rpy: torch.Tensor | None) -> torch.Tensor:
    """
    Unwrap RPY angles relative to the previous frame.

    The element-wise difference ``rpy - prev_rpy`` is folded into [-pi, pi)
    and added back onto ``prev_rpy``. The result therefore never moves more
    than pi from the previous frame, instead of jumping by 2*pi when the
    principal value wraps. Because the interval is half-open on the right,
    a difference of exactly +pi maps to -pi.

    Args:
        rpy: current-frame angles in radians.
        prev_rpy: previous frame's (already unwrapped) angles, or None.

    Returns:
        Unwrapped angles with ``rpy``'s device and dtype. ``rpy`` itself is
        returned unchanged when ``prev_rpy`` is None or has a different shape.
    """
    if prev_rpy is None:
        return rpy
    if prev_rpy.shape != rpy.shape:
        return rpy

    prev = prev_rpy.to(device=rpy.device, dtype=rpy.dtype)
    pi = torch.as_tensor(np.pi, dtype=rpy.dtype, device=rpy.device)
    two_pi = torch.as_tensor(2.0 * np.pi, dtype=rpy.dtype, device=rpy.device)
    delta = rpy - prev
    # torch.remainder takes the sign of the divisor, so this lands in [-pi, pi).
    delta = torch.remainder(delta + pi, two_pi) - pi
    return prev + delta


def build_endeffector_pose_dict(
    position: torch.Tensor,
    quat_wxyz: torch.Tensor,
    prev_ee_quat_wxyz: torch.Tensor | None,
    prev_ee_rpy_xyz: torch.Tensor | None,
    eps: float = 1e-12,
) -> tuple[dict, torch.Tensor, torch.Tensor]:
    """
    Run the end-effector orientation continuity pipeline for one frame.

    Stages:
      1. normalize the quaternion (invalid input becomes identity);
      2. flip its sign onto the previous frame's hemisphere;
      3. convert to principal-value XYZ RPY;
      4. unwrap the RPY against the previous frame's RPY;
      5. produce detached copies to serve as the next frame's cache.

    Args:
        position: xyz position, passed through unchanged under "pose".
        quat_wxyz: this frame's wxyz quaternion.
        prev_ee_quat_wxyz: cached quaternion from the previous frame, or None.
        prev_ee_rpy_xyz: cached unwrapped RPY from the previous frame, or None.
        eps: norm threshold below which the quaternion counts as invalid.

    Returns:
        (pose_dict, new_prev_quat, new_prev_rpy) where pose_dict is
        {"pose": position, "quat": aligned quat, "rpy": continuous RPY}, and
        the two caches are detached clones of the aligned quat and unwrapped RPY.
    """
    # Stages 1-2: unit quaternion on the same hemisphere as last frame.
    aligned = align_quat_sign_with_prev_torch(
        normalize_quat_wxyz_torch(quat_wxyz, eps=eps),
        prev_ee_quat_wxyz,
    )

    # Stages 3-4: principal-value RPY, then made continuous w.r.t. last frame.
    continuous_rpy = unwrap_rpy_with_prev_torch(
        quat_wxyz_to_rpy_xyz_torch(aligned),
        prev_ee_rpy_xyz,
    )

    result = {
        "pose": position,
        "quat": aligned,
        "rpy": continuous_rpy,
    }

    # Stage 5: caches must not keep the autograd graph alive or alias outputs.
    return result, aligned.detach().clone(), continuous_rpy.detach().clone()


def _rpy_as_rows(rpy: np.ndarray) -> np.ndarray:
    """Reshape a non-empty RPY array to (N, 3) rows, raising ValueError if impossible."""
    if rpy.ndim == 1:
        # A flat sequence is read as consecutive (roll, pitch, yaw) triples.
        if rpy.shape[0] % 3 != 0:
            raise ValueError(f"Cannot reshape 1D rpy_sequence of shape {rpy.shape} to (*, 3)")
        return rpy.reshape(-1, 3)
    if rpy.ndim == 0 or rpy.shape[-1] != 3:
        raise ValueError(f"rpy_sequence last dimension must be 3, got shape {rpy.shape}")
    return rpy.reshape(-1, 3)


def summarize_and_print_rpy_sequence(rpy_sequence: Any, label: str = "") -> dict[str, Any]:
    """
    Summarize an RPY sequence by its sample count and largest per-axis jump.

    The report is emitted at DEBUG level through the standard ``logging``
    module (logger named after this module).

    Args:
        rpy_sequence: anything ``np.asarray`` accepts; either a flat sequence
            whose length is a multiple of 3, or an array whose last dimension
            is 3. Angles are assumed to be radians.
        label: optional prefix for the report headers.

    Returns:
        dict with keys "count", "axis_max_abs_delta_rad",
        "axis_max_abs_delta_deg" (3-element lists in roll/pitch/yaw order) and
        "axis_max_abs_delta_transition" (per-axis [i, i+1] sample indices of
        the largest absolute step, or None when there are fewer than 2 samples).

    Raises:
        ValueError: if the input cannot be interpreted as rows of 3 angles.
    """
    import logging

    # No ``logger`` is imported or defined at module level in this file, so
    # referencing one raised NameError. Resolve a stdlib logger locally.
    logger = logging.getLogger(__name__)
    prefix = f"{label} " if label else ""

    rpy = np.asarray(rpy_sequence, dtype=np.float64)
    if rpy.size == 0:
        logger.debug("%sRPY summary: no RPY samples.", prefix)
        return {
            "count": 0,
            "axis_max_abs_delta_rad": [0.0, 0.0, 0.0],
            "axis_max_abs_delta_deg": [0.0, 0.0, 0.0],
            "axis_max_abs_delta_transition": [None, None, None],
        }

    rows = _rpy_as_rows(rpy)
    count = int(rows.shape[0])

    if count < 2:
        axis_max_abs_delta_rad = np.zeros(3, dtype=np.float64)
        axis_max_abs_delta_deg = np.zeros(3, dtype=np.float64)
        axis_max_abs_delta_transition = [None, None, None]
    else:
        abs_diff = np.abs(np.diff(rows, axis=0))
        axis_max_abs_delta_rad = np.max(abs_diff, axis=0)
        axis_max_abs_delta_deg = np.rad2deg(axis_max_abs_delta_rad)
        # Step i is the transition between samples i and i + 1.
        peak_indices = np.argmax(abs_diff, axis=0)
        axis_max_abs_delta_transition = [[int(i), int(i) + 1] for i in peak_indices]

    summary = {
        "count": count,
        "axis_max_abs_delta_rad": axis_max_abs_delta_rad.tolist(),
        "axis_max_abs_delta_deg": axis_max_abs_delta_deg.tolist(),
        "axis_max_abs_delta_transition": axis_max_abs_delta_transition,
    }

    logger.debug("%sRPY summary (rad):", prefix)
    logger.debug("  count=%d", count)
    logger.debug(
        "  axis_max_abs_delta_rad (roll,pitch,yaw)=[%.6f, %.6f, %.6f]",
        *axis_max_abs_delta_rad,
    )
    logger.debug("  transitions=%s", axis_max_abs_delta_transition)
    logger.debug("%sRPY summary (deg):", prefix)
    logger.debug(
        "  axis_max_abs_delta_deg (roll,pitch,yaw)=[%.6f, %.6f, %.6f]",
        *axis_max_abs_delta_deg,
    )

    return summary