| from __future__ import annotations |
|
|
| from collections.abc import Iterable |
| from dataclasses import dataclass |
| from fractions import Fraction |
|
|
| import numpy as np |
|
|
# Human-readable label for each integer motion-class id (0..3) returned by
# ``_motion_class``: bit 0 set means translation present, bit 1 set means
# rotation present.
MOTION_CLASS_NAMES = dict(
    enumerate(("still", "translate", "rotate", "translate_rotate"))
)
|
|
|
|
@dataclass(frozen=True)
class SpeedTransformConfig:
    """Immutable settings for speed augmentation of 7D delta actions.

    Action layout is [dx, dy, dz, droll, dpitch, dyaw, gripper]. The gripper
    channel never takes part in segmentation; it is copied per output frame.

    Resampling modes:
      * ``chunk_aligned_observation=False`` (legacy): each segment is cut by a
        uniform linspace.
      * ``chunk_aligned_observation=True``: speed is written as q/p and every
        segment is split into chunks of q source frames that become p output
        frames; leftover frames (< q) are passed through 1:1. In the output
        mask only chunk starts are 1 (plus, for phase 0, the trailing
        passthrough rows of a segment that holds at least one full chunk).
        ``chunk_phase`` r in [0, q) moves the chunk starts to r, r+q, r+2q, ...
        inside each segment.
    """

    # Motion-classification thresholds: a frame counts as translating /
    # rotating when the norm of its xyz / rpy part is >= the threshold.
    transl_eps: float = 1e-4
    rot_eps: float = 1e-4

    # Noise-cleaning thresholds handed to ``clean_near_zero_actions``: parts
    # whose norm is strictly below the threshold are zeroed. The default 0.0
    # therefore zeroes nothing.
    clean_transl_eps: float = 0.0
    clean_rot_eps: float = 0.0

    # A segment boundary is placed when the cosine between consecutive
    # translation (or rotation) vectors drops below this value.
    direction_cos_threshold: float = -0.25

    # Segments shorter than this are folded into the preceding segment;
    # values <= 1 disable merging.
    min_segment_len: int = 1

    # When False, segments classified as "still" (class 0) are dropped
    # before resampling.
    keep_still_segments: bool = True

    # NOTE(review): not read by any function visible in this file section;
    # presumably the dataset frame rate -- confirm against callers.
    fps: int = 20

    # Selects the chunk-aligned resampling mode described above.
    chunk_aligned_observation: bool = False

    # Chunk-start offset r; only accepted together with
    # ``chunk_aligned_observation=True``.
    chunk_phase: int = 0
|
|
|
|
def _as_float32_2d(values: Iterable[np.ndarray], expected_dim: int) -> np.ndarray:
    """Materialise *values* as a float32 matrix with ``expected_dim`` columns.

    The iterable is first collected into a list (so generators work), then
    converted to a float32 array.

    Raises:
        ValueError: if the result is not 2-D or its second dimension differs
            from ``expected_dim`` (an empty iterable yields shape ``(0,)`` and
            is rejected too).
    """
    matrix = np.asarray(list(values), dtype=np.float32)
    if matrix.ndim == 2 and matrix.shape[1] == expected_dim:
        return matrix
    raise ValueError(f"Expected shape (T, {expected_dim}), got {matrix.shape}")
|
|
|
|
def clean_near_zero_actions(
    actions: np.ndarray,
    transl_eps: float = 1e-4,
    rot_eps: float = 1e-4,
) -> tuple[np.ndarray, np.ndarray]:
    """Suppress tiny translation / rotation components of each action row.

    Works on a float32 copy; the input is not modified. For every row the
    xyz part (columns 0-2) is set to zero when its L2 norm is strictly below
    ``transl_eps``, and likewise the rotation part (columns 3-5) against
    ``rot_eps``. Columns from index 6 on (gripper) are never touched.

    Returns:
        ``(cleaned, zeroed)`` where ``zeroed`` is a boolean array of shape
        (T, 2) with columns [translation_zeroed, rotation_zeroed].

    Raises:
        ValueError: if the input is not 2-D or has fewer than 7 columns.
    """
    result = np.array(actions, dtype=np.float32, order="C")
    if result.ndim != 2 or result.shape[1] < 7:
        raise ValueError(f"Expected 7D actions, got {result.shape}")

    flags = []
    # The two column groups are disjoint, so handling them one after the
    # other gives the same result as computing both norms up front.
    for cols, eps in ((slice(0, 3), transl_eps), (slice(3, 6), rot_eps)):
        tiny = np.linalg.norm(result[:, cols], axis=1) < eps
        result[tiny, cols] = 0.0
        flags.append(tiny)
    return result, np.stack(flags, axis=1)
|
|
|
|
def _motion_class(action: np.ndarray, transl_eps: float, rot_eps: float) -> int:
    """Classify one action as 0=still, 1=translate, 2=rotate, 3=both.

    Translation is present when the norm of ``action[:3]`` is >= ``transl_eps``
    and rotation when the norm of ``action[3:6]`` is >= ``rot_eps``. The class
    id is the two-bit number (rotation_bit, translation_bit).
    """
    translating = float(np.linalg.norm(action[:3])) >= transl_eps
    rotating = float(np.linalg.norm(action[3:6])) >= rot_eps
    return int(translating) + 2 * int(rotating)
|
|
|
|
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between *a* and *b*.

    When the product of the two norms is at most 1e-12 (at least one vector
    is essentially zero) the vectors are treated as aligned and 1.0 is
    returned.
    """
    scale = float(np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 if scale <= 1e-12 else float(np.dot(a, b) / scale)
|
|
|
|
def segment_actions(
    actions: np.ndarray,
    config: SpeedTransformConfig,
) -> list[tuple[int, int, int]]:
    """Split an episode into runs of homogeneous motion.

    A cut is made in front of frame ``i`` when either
      (a) its motion class (still / translate / rotate / both) differs from
          that of frame ``i - 1``, or
      (b) the class is unchanged but the translation direction (classes 1, 3)
          or the rotation direction (classes 2, 3) flips, i.e. the cosine
          between the consecutive vectors is below
          ``config.direction_cos_threshold``.
    The gripper channel is never inspected here.

    When ``config.min_segment_len`` > 1, every run shorter than that is
    absorbed into the run before it (keeping the earlier run's class); the
    very first run is always kept as is.

    Returns:
        List of ``(start, end_exclusive, motion_class)``; empty for an empty
        episode.
    """
    n_frames = len(actions)
    if n_frames == 0:
        return []

    labels = np.fromiter(
        (_motion_class(row, config.transl_eps, config.rot_eps) for row in actions),
        dtype=np.int32,
        count=n_frames,
    )
    cos_floor = config.direction_cos_threshold
    translation, rotation = slice(0, 3), slice(3, 6)

    def reverses(idx: int, cols: slice) -> bool:
        # True when the selected 3-vector flips direction between idx-1 and idx.
        return _cosine(actions[idx - 1, cols], actions[idx, cols]) < cos_floor

    cuts: list[int] = []
    for idx in range(1, n_frames):
        label = labels[idx]
        if label != labels[idx - 1]:
            cuts.append(idx)
        elif label in (1, 3) and reverses(idx, translation):
            cuts.append(idx)
        elif label in (2, 3) and reverses(idx, rotation):
            cuts.append(idx)

    # Each run is labelled with the class of its last frame.
    edges = [0, *cuts, n_frames]
    runs = [(lo, hi, int(labels[hi - 1])) for lo, hi in zip(edges, edges[1:])]

    min_len = config.min_segment_len
    if min_len <= 1:
        return runs

    merged: list[tuple[int, int, int]] = []
    for lo, hi, label in runs:
        if merged and (hi - lo) < min_len:
            # Too short: extend the previous run over it.
            head_lo, _, head_label = merged[-1]
            merged[-1] = (head_lo, hi, head_label)
        else:
            merged.append((lo, hi, label))
    return merged
|
|
|
|
def _interp_cumulative(cumulative: np.ndarray, x: float) -> np.ndarray:
    """Linearly interpolate the rows of *cumulative* at fractional index *x*.

    *x* is clamped to ``[0, len(cumulative) - 1]``; at the upper end the right
    neighbour is clamped as well, so the last row is returned unchanged.
    """
    last = cumulative.shape[0] - 1
    pos = float(np.clip(x, 0.0, last))
    lo = int(np.floor(pos))
    hi = min(lo + 1, last)
    # Weight of the right-hand row, kept as a float32 scalar.
    alpha = np.float32(pos - lo)
    return (1.0 - alpha) * cumulative[lo] + alpha * cumulative[hi]
|
|
|
|
# Largest denominator p allowed when writing a speed as the fraction q/p.
_CHUNK_DENOM_LIMIT = 32


def _speed_chunk_ratio(speed: float) -> tuple[int, int]:
    """Decompose *speed* as q/p in lowest terms: q source frames -> p output frames.

    The denominator is limited to ``_CHUNK_DENOM_LIMIT``. Speeds such as
    0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0 and 4.0 are represented
    exactly.

    Args:
        speed: positive playback speed. Any real number convertible with
            ``float()`` is accepted (e.g. NumPy scalars).

    Returns:
        ``(q, p)`` with ``q >= 1`` and ``1 <= p <= _CHUNK_DENOM_LIMIT``.

    Raises:
        ValueError: if *speed* is NaN, zero or negative; if it is not within
            1e-6 of a fraction with an admissible denominator; or if it is so
            small that the closest such fraction is 0 (q would be 0, which
            cannot define a chunk).
        OverflowError: propagated from ``Fraction`` for an infinite speed.
    """
    speed = float(speed)
    # ``not speed > 0`` also rejects NaN, for which every comparison is False.
    if not speed > 0.0:
        raise ValueError(f"speed must be positive, got {speed}")

    ratio = Fraction(speed).limit_denominator(_CHUNK_DENOM_LIMIT)
    if abs(float(ratio) - speed) > 1e-6:
        raise ValueError(
            f"speed={speed} cannot be expressed as a rational with denominator <= "
            f"{_CHUNK_DENOM_LIMIT}; chunk_aligned_observation requires a rational speed."
        )
    # Speeds below the 1e-6 tolerance round to 0/1 and slip through the check
    # above; a chunk of q=0 source frames is meaningless.
    if ratio.numerator < 1:
        raise ValueError(
            f"speed={speed} is too small to be expressed as q/p with q >= 1 and "
            f"denominator <= {_CHUNK_DENOM_LIMIT}"
        )
    return ratio.numerator, ratio.denominator
|
|
|
|
def _segment_boundaries(
    src_actions: np.ndarray,
    speed: float,
    *,
    chunk_aligned: bool = False,
    chunk_phase: int = 0,
) -> np.ndarray:
    """Build resampling boundaries in source-frame coordinates.

    chunk_aligned=True layout (speed=q/p, phase=r):

        [0..r]            leading passthrough (1:1, only when r > 0)
        [r, r+q]          chunk 0 -> p sub-bins via linspace
        [r+q, r+2q]       chunk 1
        ...
        [r+Nq, n_src]     trailing passthrough (1:1)

    Only chunks that fit completely (start + q <= n_src) are emitted; the
    frames left over at either end get integer boundaries, i.e. one output
    bin per source frame.

    chunk_aligned=False: uniform ``linspace(0, n_src, n_out + 1)`` with
    ``n_out = max(1, round(n_src / speed))``.

    In both modes every index at which the gripper channel (column 6) jumps
    by more than 0.5 relative to the previous frame is added as an extra
    boundary, so a switch always coincides with a bin edge.

    Args:
        src_actions: (n_src, >=7) actions of one segment.
        speed: playback speed; must be expressible as q/p in chunk mode.
        chunk_aligned: select the chunk layout described above.
        chunk_phase: offset r of the first chunk start, 0 <= r < q.

    Returns:
        Strictly increasing float32 array of boundaries. Duplicates are
        removed *after* the float32 cast, so two float64 values that round to
        the same float32 (e.g. a linspace point 0.9999999999999999 and a
        gripper anchor 1.0) cannot produce a zero-width bin.

    Raises:
        ValueError: in chunk mode, if ``chunk_phase`` is outside [0, q) or the
            speed is rejected by ``_speed_chunk_ratio``.
    """
    n_src = len(src_actions)
    boundaries: list[float] = []

    if chunk_aligned:
        q, p = _speed_chunk_ratio(speed)
        if chunk_phase < 0 or chunk_phase >= q:
            raise ValueError(f"chunk_phase must be in [0, {q}), got {chunk_phase}")

        # Leading passthrough: integer edges 0..r (clipped to the segment).
        if chunk_phase > 0:
            boundaries.extend(float(i) for i in range(0, min(chunk_phase, n_src) + 1))

        # Full chunks: every start s = r + k*q with s + q <= n_src.
        chunk_starts = list(range(chunk_phase, max(n_src - q + 1, chunk_phase), q))
        for chunk_start in chunk_starts:
            boundaries.extend(
                np.linspace(float(chunk_start), float(chunk_start + q), p + 1).tolist()
            )

        # Trailing passthrough: integer edges from the end of the last chunk
        # (or from the end of the leading passthrough when no chunk fits).
        tail_start = chunk_starts[-1] + q if chunk_starts else min(chunk_phase, n_src)
        if tail_start < n_src:
            boundaries.extend(float(i) for i in range(tail_start, n_src + 1))
    else:
        n_out = max(1, int(round(n_src / speed)))
        boundaries.extend(float(x) for x in np.linspace(0.0, float(n_src), n_out + 1))

    # Gripper anchors: index i is added when |g[i] - g[i-1]| > 0.5.
    switch_indices = np.flatnonzero(np.abs(np.diff(src_actions[:, 6])) > 0.5) + 1
    boundaries.extend(float(i) for i in switch_indices)

    # np.unique sorts and de-duplicates in the final dtype.
    return np.unique(np.asarray(boundaries, dtype=np.float32))
|
|
|
|
def _chunk_aligned_mask(
    boundaries: np.ndarray,
    n_src: int,
    q: int,
    chunk_phase: int,
) -> np.ndarray:
    """Return the int8 observation mask for chunk-aligned resampling.

    An output bin is marked 1 when its left edge is a chunk start
    (``chunk_phase + k*q`` within 1e-6). For ``chunk_phase == 0`` the start
    only has to lie inside the segment and, if the segment holds at least one
    full chunk, every bin starting at or after the end of the last full chunk
    (the trailing passthrough) is marked as well. For ``chunk_phase > 0`` the
    chunk must fit completely (``start + q <= n_src``).
    """
    n_out = len(boundaries) - 1
    mask = np.zeros(n_out, dtype=np.int8)
    n_full = n_src // q
    tail_threshold = n_full * q
    for j in range(n_out):
        left_t = float(boundaries[j])
        if chunk_phase == 0:
            if left_t >= tail_threshold and n_full > 0:
                mask[j] = 1
            else:
                chunk_idx_f = left_t / q
                chunk_idx = round(chunk_idx_f)
                if abs(chunk_idx_f - chunk_idx) < 1e-6 and chunk_idx * q < n_src:
                    mask[j] = 1
        else:
            chunk_idx_f = (left_t - chunk_phase) / q
            chunk_idx = round(chunk_idx_f)
            chunk_start = chunk_phase + chunk_idx * q
            if abs(chunk_idx_f - chunk_idx) < 1e-6 and chunk_start + q <= n_src:
                mask[j] = 1
    return mask


def _resample_segment(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    start: int,
    end: int,
    motion_class: int,
    speed: float,
    chunk_aligned_observation: bool = False,
    chunk_phase: int = 0,
) -> dict[str, np.ndarray]:
    """Resample one segment ``[start, end)`` via cumulative sum + interpolation.

    For every output bin ``[left_t, right_t]`` taken from
    ``_segment_boundaries``:

        action[:6] = cumulative(right_t) - cumulative(left_t)

    where ``cumulative`` is the piecewise-linear running sum of the source
    motion. Consecutive bins share their edge value, so the outputs telescope
    to the segment's total motion. The running sum is accumulated in float64
    (outputs are stored as float32): with a float32 accumulator the
    difference of two partial sums carries the accumulated rounding error, so
    even 1:1 passthrough bins would not reproduce the source action.

    Per output ``j``:
      * state / source frame / source step come from source row
        ``floor(left_t)`` (clipped to the segment);
      * gripper comes from source row ``ceil(right_t) - 1`` (clipped).

    Observation mask:
      * legacy mode: all ones, except that for ``speed < 1`` an output whose
        source row was already used by an earlier output of this segment
        gets 0;
      * chunk-aligned mode: see ``_chunk_aligned_mask``.

    Returns:
        Dict with keys ``action`` (n_out, 7), ``state``,
        ``source_frame_index``, ``source_step_index``, ``observation_mask``,
        ``segment_id`` (filled with -1) and ``motion_class``.

    Raises:
        ValueError: if the segment is empty, or (chunk-aligned mode) if
            ``chunk_phase`` / ``speed`` are rejected by
            ``_segment_boundaries``.
    """
    n_src = end - start
    if n_src <= 0:
        raise ValueError("Cannot resample an empty segment")
    src_actions = actions[start:end]
    src_states = states[start:end]
    src_frames = source_frame_indices[start:end]

    boundaries = _segment_boundaries(
        src_actions,
        speed,
        chunk_aligned=chunk_aligned_observation,
        chunk_phase=chunk_phase,
    )
    n_out = len(boundaries) - 1

    # Row k holds the motion integrated over source frames [0, k).
    cumulative = np.concatenate(
        [
            np.zeros((1, 6), dtype=np.float64),
            np.cumsum(src_actions[:, :6], axis=0, dtype=np.float64),
        ],
        axis=0,
    )

    out_actions = np.zeros((n_out, 7), dtype=np.float32)
    out_states = np.zeros((n_out, states.shape[1]), dtype=np.float32)
    out_source_frames = np.zeros(n_out, dtype=np.int64)
    out_source_steps = np.zeros(n_out, dtype=np.int64)
    observation_mask = np.ones(n_out, dtype=np.int8)

    # The repeated-source bookkeeping only matters for the legacy mask.
    track_repeats = (not chunk_aligned_observation) and speed < 1.0
    used_sources: set[int] = set()

    for j in range(n_out):
        left_t = float(boundaries[j])
        right_t = float(boundaries[j + 1])
        out_actions[j, :6] = (
            _interp_cumulative(cumulative, right_t) - _interp_cumulative(cumulative, left_t)
        )

        source_local = min(int(np.floor(left_t)), n_src - 1)
        grip_local = min(max(int(np.ceil(right_t) - 1), 0), n_src - 1)
        out_actions[j, 6] = src_actions[grip_local, 6]
        out_states[j] = src_states[source_local]
        out_source_frames[j] = int(src_frames[source_local])
        out_source_steps[j] = start + source_local

        if track_repeats:
            if source_local in used_sources:
                observation_mask[j] = 0
            used_sources.add(source_local)

    if chunk_aligned_observation:
        # chunk_phase was already range-checked by _segment_boundaries above.
        q, _p = _speed_chunk_ratio(speed)
        observation_mask = _chunk_aligned_mask(boundaries, n_src, q, chunk_phase)

    return {
        "action": out_actions,
        "state": out_states,
        "source_frame_index": out_source_frames,
        "source_step_index": out_source_steps,
        "observation_mask": observation_mask,
        "segment_id": np.full(n_out, -1, dtype=np.int32),
        "motion_class": np.full(n_out, motion_class, dtype=np.int8),
    }
|
|
|
|
def transform_episode(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    speed: float,
    config: SpeedTransformConfig,
) -> tuple[dict[str, np.ndarray], dict[str, float | int]]:
    """Resample a whole episode: clean -> segment -> resample every segment.

    The per-segment outputs are concatenated into a single trajectory. One
    ``config.chunk_phase`` is used for all segments of the call.

    Returns:
        ``(merged, metrics)``. ``merged`` holds the concatenated arrays from
        ``_resample_segment`` plus ``cleaned_translation``,
        ``cleaned_rotation``, ``speed``, ``action_mask`` (all ones) and
        ``is_padded`` (``1 - observation_mask``). ``metrics`` is the result of
        ``compute_replay_metrics`` -- evaluated on the validated but
        *uncleaned* source actions -- extended with segment, padding and
        cleaning statistics.

    Raises:
        ValueError: for a non-positive speed, a non-zero ``chunk_phase``
            without ``chunk_aligned_observation``, inputs of wrong shape or
            unequal length, or when no segment remains to resample.
    """
    if speed <= 0:
        raise ValueError(f"Speed must be positive, got {speed}")
    if config.chunk_phase != 0 and not config.chunk_aligned_observation:
        raise ValueError("chunk_phase is only supported when chunk_aligned_observation=True")

    actions = _as_float32_2d(actions, 7)
    states = np.asarray(states, dtype=np.float32)
    if states.ndim != 2:
        raise ValueError(f"Expected 2D states, got {states.shape}")
    source_frame_indices = np.asarray(source_frame_indices, dtype=np.int64)
    if not (len(actions) == len(states) == len(source_frame_indices)):
        raise ValueError("actions, states, and source_frame_indices must have equal length")

    cleaned_actions, clean_mask = clean_near_zero_actions(
        actions, config.clean_transl_eps, config.clean_rot_eps
    )
    segments = segment_actions(cleaned_actions, config)
    if not config.keep_still_segments:
        segments = [seg for seg in segments if seg[2] != 0]

    pieces = []
    for segment_id, (seg_start, seg_end, seg_class) in enumerate(segments):
        piece = _resample_segment(
            cleaned_actions,
            states,
            source_frame_indices,
            seg_start,
            seg_end,
            seg_class,
            speed,
            chunk_aligned_observation=config.chunk_aligned_observation,
            chunk_phase=config.chunk_phase,
        )
        piece["segment_id"].fill(segment_id)
        pieces.append(piece)
    if not pieces:
        raise ValueError("Episode produced no segments")

    merged = {
        key: np.concatenate([piece[key] for piece in pieces], axis=0)
        for key in pieces[0]
    }
    n_output = len(merged["action"])
    # Per-output cleaning flags are looked up through the source step each
    # output row was taken from.
    source_steps = merged["source_step_index"]
    merged["cleaned_translation"] = clean_mask[source_steps, 0].astype(np.int8)
    merged["cleaned_rotation"] = clean_mask[source_steps, 1].astype(np.int8)
    merged["speed"] = np.full(n_output, float(speed), dtype=np.float32)
    merged["action_mask"] = np.ones(n_output, dtype=np.int8)
    merged["is_padded"] = (1 - merged["observation_mask"]).astype(np.int8)

    metrics = compute_replay_metrics(actions, merged["action"], speed)

    transl_zeroed = clean_mask[:, 0]
    rot_zeroed = clean_mask[:, 1]
    any_zeroed = transl_zeroed | rot_zeroed
    both_zeroed = transl_zeroed & rot_zeroed
    n_source = int(len(actions))
    denom = max(n_source, 1)

    seg_lens = np.asarray([hi - lo for (lo, hi, _) in segments], dtype=np.int64)
    seg_classes = np.asarray([cls for (_, _, cls) in segments], dtype=np.int64)
    have_segments = bool(seg_lens.size)

    def class_count(class_id: int) -> int:
        # Number of segments carrying the given motion class.
        return int(np.sum(seg_classes == class_id))

    metrics.update(
        {
            "segment_count": len(segments),
            "segments": len(segments),
            "source_frames": n_source,
            "output_frames": int(n_output),
            "padded_frames": int(np.sum(merged["is_padded"])),
            "padded_ratio": float(np.mean(merged["is_padded"])),
            "cleaned_translation_frames": int(np.sum(transl_zeroed)),
            "cleaned_rotation_frames": int(np.sum(rot_zeroed)),
            "cleaned_any_frames": int(np.sum(any_zeroed)),
            "cleaned_both_frames": int(np.sum(both_zeroed)),
            "cleaned_translation_ratio": float(np.sum(transl_zeroed) / denom),
            "cleaned_rotation_ratio": float(np.sum(rot_zeroed) / denom),
            "cleaned_any_ratio": float(np.sum(any_zeroed) / denom),
            "cleaned_both_ratio": float(np.sum(both_zeroed) / denom),
            "segment_len_min": int(seg_lens.min()) if have_segments else 0,
            "segment_len_max": int(seg_lens.max()) if have_segments else 0,
            "segment_len_mean": float(seg_lens.mean()) if have_segments else 0.0,
            "segment_len_median": float(np.median(seg_lens)) if have_segments else 0.0,
            "segment_len_p10": float(np.percentile(seg_lens, 10)) if have_segments else 0.0,
            "segment_len_p90": float(np.percentile(seg_lens, 90)) if have_segments else 0.0,
            "motion_class_still_count": class_count(0),
            "motion_class_translate_count": class_count(1),
            "motion_class_rotate_count": class_count(2),
            "motion_class_translate_rotate_count": class_count(3),
        }
    )
    return merged, metrics
|
|
|
|
def compute_replay_metrics(
    source_actions: np.ndarray,
    replay_actions: np.ndarray,
    target_speed: float | None = None,
) -> dict[str, float | int]:
    """Compare a resampled action sequence against its source.

    Reports step counts and the achieved speed (source steps / replay
    steps), translation and rotation path lengths (sum of per-step norms)
    with their replay/source ratio, the L2 error of the summed translation
    and rotation, and the number of gripper switches (jumps > 0.5 in column
    6) on both sides. This is a sanity check on motion preservation, not a
    task-success metric.

    When ``target_speed`` is None, ``target_speed`` is reported as 1.0 and
    ``speed_error`` as 0.0.

    Raises:
        ValueError: if either input is not of shape (T, 7).
    """
    src = _as_float32_2d(source_actions, 7)
    rep = _as_float32_2d(replay_actions, 7)

    def path_length(acts: np.ndarray, cols: slice) -> float:
        # Sum of per-step L2 norms over the selected columns.
        return float(np.linalg.norm(acts[:, cols], axis=1).sum())

    def switch_count(acts: np.ndarray) -> int:
        # Number of consecutive-frame gripper jumps larger than 0.5.
        return int(np.sum(np.abs(np.diff(acts[:, 6])) > 0.5))

    translation, rotation = slice(0, 3), slice(3, 6)

    net_src = src[:, :6].sum(axis=0)
    net_rep = rep[:, :6].sum(axis=0)

    n_src, n_rep = len(src), len(rep)
    actual_speed = n_src / max(n_rep, 1)
    if target_speed is None:
        reported_target, speed_error = 1.0, 0.0
    else:
        reported_target = float(target_speed)
        speed_error = float(abs(actual_speed - target_speed))

    transl_src = path_length(src, translation)
    transl_rep = path_length(rep, translation)
    rot_src = path_length(src, rotation)
    rot_rep = path_length(rep, rotation)
    grip_src = switch_count(src)
    grip_rep = switch_count(rep)

    out: dict[str, float | int] = {
        "target_speed": reported_target,
        "source_steps": int(n_src),
        "replay_steps": int(n_rep),
        "actual_speed": float(actual_speed),
        "speed_error": speed_error,
        "translation_path_source": transl_src,
        "translation_path_replay": transl_rep,
        "translation_path_ratio": transl_rep / max(transl_src, 1e-12),
        "rotation_path_source": rot_src,
        "rotation_path_replay": rot_rep,
        "rotation_path_ratio": rot_rep / max(rot_src, 1e-12),
        "integrated_translation_l2_error": float(np.linalg.norm(net_src[:3] - net_rep[:3])),
        "integrated_rotation_l2_error": float(np.linalg.norm(net_src[3:] - net_rep[3:])),
        "gripper_switches_source": grip_src,
        "gripper_switches_replay": grip_rep,
        "gripper_switch_delta": int(grip_rep - grip_src),
    }
    return out
|
|