from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from fractions import Fraction

import numpy as np

# Human-readable names for the integer motion classes produced by
# ``_motion_class`` / ``segment_actions``.
MOTION_CLASS_NAMES = {
    0: "still",
    1: "translate",
    2: "rotate",
    3: "translate_rotate",
}


@dataclass(frozen=True)
class SpeedTransformConfig:
    """Speed-augmentation knobs for LIBERO 7D delta actions
    [dx,dy,dz,droll,dpitch,dyaw,gripper].

    Gripper is never used for segmentation; it's copied discretely.

    Two modes:
    - chunk_aligned_observation=False (legacy): uniform linspace over the segment.
    - chunk_aligned_observation=True: speed q/p splits each segment into chunks
      of q source frames -> p output frames; the leftover < q frames pass
      through 1:1 verbatim. Only chunk-starts (and phase=0 passthrough rows)
      are mask=1 in the output.

    ``chunk_phase`` r in [0, q) shifts chunk starts to r, r+q, r+2q, ... within
    each segment (used by the online sliding sampler to randomize which source
    frames become chunk-starts).
    """

    # Norm thresholds used by ``segment_actions`` to decide whether a frame
    # has translation / rotation at all.
    transl_eps: float = 1e-4
    rot_eps: float = 1e-4
    # OFF by default: LIBERO demos have no frame below 1e-4 in practice.
    # Set 1e-4 explicitly for noisier datasets via the build CLI.
    clean_transl_eps: float = 0.0
    clean_rot_eps: float = 0.0
    # A cosine between consecutive translation (or rotation) deltas below this
    # value counts as a direction reversal and starts a new segment.
    direction_cos_threshold: float = -0.25
    # Segments shorter than this are merged into the preceding segment.
    min_segment_len: int = 1
    # When False, segments with motion class 0 ("still") are dropped.
    keep_still_segments: bool = True
    # NOTE(review): not read by any function visible in this module.
    fps: int = 20
    chunk_aligned_observation: bool = False
    chunk_phase: int = 0


def _as_float32_2d(values: Iterable[np.ndarray], expected_dim: int) -> np.ndarray:
    """Convert *values* to a float32 array and require shape (T, expected_dim).

    Raises:
        ValueError: if the result is not 2D or its second axis is not
            ``expected_dim`` (this includes empty input, which becomes 1D).
    """
    arr = np.asarray(list(values), dtype=np.float32)
    if arr.ndim != 2 or arr.shape[1] != expected_dim:
        raise ValueError(f"Expected shape (T, {expected_dim}), got {arr.shape}")
    return arr


def clean_near_zero_actions(
    actions: np.ndarray,
    transl_eps: float = 1e-4,
    rot_eps: float = 1e-4,
) -> tuple[np.ndarray, np.ndarray]:
    """Zero tiny translation/rotation noise; gripper untouched.

    Returns (cleaned_actions, zeroed_mask of shape (T, 2)) where columns are
    [translation_zeroed, rotation_zeroed].

    Raises:
        ValueError: if *actions* is not 2D with at least 7 columns.
    """
    cleaned = np.asarray(actions, dtype=np.float32).copy()
    if cleaned.ndim != 2 or cleaned.shape[1] < 7:
        raise ValueError(f"Expected 7D actions, got {cleaned.shape}")
    transl_norm = np.linalg.norm(cleaned[:, :3], axis=1)
    rot_norm = np.linalg.norm(cleaned[:, 3:6], axis=1)
    # Strict "<": an eps of 0.0 therefore never zeroes anything.
    transl_zeroed = transl_norm < transl_eps
    rot_zeroed = rot_norm < rot_eps
    cleaned[transl_zeroed, :3] = 0.0
    cleaned[rot_zeroed, 3:6] = 0.0
    return cleaned, np.stack([transl_zeroed, rot_zeroed], axis=1)


def _motion_class(action: np.ndarray, transl_eps: float, rot_eps: float) -> int:
    """Classify one action row as 0=still, 1=translate, 2=rotate, 3=both."""
    has_translation = float(np.linalg.norm(action[:3])) >= transl_eps
    has_rotation = float(np.linalg.norm(action[3:6])) >= rot_eps
    if has_translation and has_rotation:
        return 3
    if has_translation:
        return 1
    if has_rotation:
        return 2
    return 0


def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of *a* and *b*; 1.0 when either vector is ~zero."""
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom <= 1e-12:
        # Degenerate vectors are treated as "same direction" so they never
        # trigger a direction-reversal boundary.
        return 1.0
    return float(np.dot(a, b) / denom)


def segment_actions(
    actions: np.ndarray,
    config: SpeedTransformConfig,
) -> list[tuple[int, int, int]]:
    """Cut the episode into homogeneous-motion segments.

    A boundary is placed when (a) motion_class changes (still/translate/rotate/
    both) or (b) translation or rotation direction reverses (cosine below
    ``direction_cos_threshold``). Gripper open/close is NEVER a cut criterion --
    it's preserved later as an interval anchor in ``_segment_boundaries``.

    When ``config.min_segment_len > 1``, every segment shorter than that
    (except the very first one) is merged into the preceding segment and
    inherits its motion class.

    Returns a list of (start, end_exclusive, motion_class).
    """
    if len(actions) == 0:
        return []
    classes = np.asarray(
        [_motion_class(a, config.transl_eps, config.rot_eps) for a in actions],
        dtype=np.int32,
    )
    segments: list[tuple[int, int, int]] = []
    start = 0
    for i in range(1, len(actions)):
        boundary = classes[i] != classes[i - 1]
        # Direction checks only apply while the class is unchanged and the
        # relevant component is actually moving.
        if not boundary and classes[i] in (1, 3):
            boundary = _cosine(actions[i - 1, :3], actions[i, :3]) < config.direction_cos_threshold
        if not boundary and classes[i] in (2, 3):
            boundary = _cosine(actions[i - 1, 3:6], actions[i, 3:6]) < config.direction_cos_threshold
        if boundary:
            segments.append((start, i, int(classes[i - 1])))
            start = i
    segments.append((start, len(actions), int(classes[-1])))

    if config.min_segment_len <= 1:
        return segments

    merged: list[tuple[int, int, int]] = []
    for seg in segments:
        if not merged or (seg[1] - seg[0]) >= config.min_segment_len:
            merged.append(seg)
            continue
        # Too short: extend the previous segment over it, keeping its class.
        prev_start, _prev_end, prev_cls = merged[-1]
        merged[-1] = (prev_start, seg[1], prev_cls)
    return merged


def _interp_cumulative(cumulative: np.ndarray, x: float) -> np.ndarray:
    """Linearly interpolate the rows of *cumulative* at fractional index *x*.

    *x* is clipped to ``[0, len(cumulative) - 1]``.
    """
    x = float(np.clip(x, 0.0, cumulative.shape[0] - 1))
    left = int(np.floor(x))
    right = min(left + 1, cumulative.shape[0] - 1)
    alpha = np.float32(x - left)
    return (1.0 - alpha) * cumulative[left] + alpha * cumulative[right]


_CHUNK_DENOM_LIMIT = 32


def _speed_chunk_ratio(speed: float) -> tuple[int, int]:
    """Decompose speed as q/p with small denominator: q source frames -> p output frames.

    All ablation speeds (0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 4.0) fit
    within ``_CHUNK_DENOM_LIMIT=32``. Raises for irrational speeds.

    Raises:
        ValueError: if no fraction with denominator <= ``_CHUNK_DENOM_LIMIT``
            is within 1e-6 of *speed*.
    """
    # Coerce first so NumPy scalar types (e.g. np.float32) are accepted by
    # Fraction regardless of the Python version.
    speed = float(speed)
    f = Fraction(speed).limit_denominator(_CHUNK_DENOM_LIMIT)
    if abs(float(f) - speed) > 1e-6:
        raise ValueError(
            f"speed={speed} cannot be expressed as a rational with denominator <= "
            f"{_CHUNK_DENOM_LIMIT}; chunk_aligned_observation requires a rational speed."
        )
    return f.numerator, f.denominator


def _segment_boundaries(
    src_actions: np.ndarray,
    speed: float,
    *,
    chunk_aligned: bool = False,
    chunk_phase: int = 0,
) -> np.ndarray:
    """Build resampling boundaries in source-frame coordinates.

    chunk_aligned=True layout (speed=q/p, phase=r):
        [0..r]        leading passthrough (1:1, only when r > 0)
        [r, r+q]      chunk 0 -> p sub-bins via linspace
        [r+q, r+2q]   chunk 1
        ...
        [r+Nq, n_src] trailing passthrough (1:1)
    Each full chunk integrates exactly q source frames into p outputs. The
    leftover frames at both ends are emitted verbatim so total motion is
    preserved end-to-end.

    chunk_aligned=False: uniform linspace(0, n_src, n_out+1) over the segment.

    Gripper open/close indices are added as anchor boundaries afterward so a
    switch event never gets averaged inside a coarse bin (segmentation never
    looks at the gripper channel).

    Returns:
        Sorted, strictly increasing float32 array of boundaries.

    Raises:
        ValueError: if ``chunk_aligned`` and ``chunk_phase`` is outside [0, q),
            or the speed is not a small rational (see ``_speed_chunk_ratio``).
    """
    n_src = len(src_actions)
    if chunk_aligned:
        q, p = _speed_chunk_ratio(speed)
        if chunk_phase < 0 or chunk_phase >= q:
            raise ValueError(f"chunk_phase must be in [0, {q}), got {chunk_phase}")
        boundaries: list[float] = []
        # Leading passthrough [0..r]: integer 1:1 boundaries.
        if chunk_phase > 0:
            boundaries.extend(float(i) for i in range(0, min(chunk_phase, n_src) + 1))
        # Full chunks at r, r+q, r+2q, ... while the chunk fits.
        chunk_starts = list(range(chunk_phase, max(n_src - q + 1, chunk_phase), q))
        for start in chunk_starts:
            boundaries.extend(np.linspace(float(start), float(start + q), p + 1).tolist())
        # Trailing passthrough: each remaining source frame -> one verbatim output.
        tail_start = chunk_starts[-1] + q if chunk_starts else min(chunk_phase, n_src)
        if tail_start < n_src:
            boundaries.extend(float(i) for i in range(tail_start, n_src + 1))
    else:
        n_out = max(1, int(round(n_src / speed)))
        boundaries = [float(x) for x in np.linspace(0.0, float(n_src), n_out + 1)]

    # Gripper-switch anchors: a jump > 0.5 between frames i-1 and i puts a
    # boundary at i.
    switch_indices = np.flatnonzero(np.abs(np.diff(src_actions[:, 6])) > 0.5) + 1
    boundaries.extend(float(i) for i in switch_indices)

    # Deduplicate AFTER the float32 cast. A linspace value can differ from an
    # integer anchor by one float64 ULP (e.g. 100 * 0.07 == 7.000000000000001
    # vs. the anchor 7.0); deduplicating in float64 would keep both, and the
    # cast would then collapse them into a zero-width bin that yields a
    # spurious zero-motion output row. np.unique also returns sorted output.
    return np.unique(np.asarray(boundaries, dtype=np.float32))


def _chunk_aligned_mask(
    boundaries: np.ndarray,
    n_src: int,
    speed: float,
    chunk_phase: int,
) -> np.ndarray:
    """Observation mask for chunk-aligned mode (one int8 entry per output bin).

    mask=1 only at chunk-aligned outputs (and -- for phase=0 -- trailing
    passthrough rows). NOTE the phase asymmetry:
      phase=0 : full chunk starts + ALL trailing-passthrough rows.
      phase>0 : full chunk starts ONLY (leading + trailing passthrough
                rows are mask=0). Random phase per access in the online
                sampler rotates which source frames become valid.

    Raises:
        ValueError: if ``chunk_phase`` is outside [0, q).
    """
    q, _p = _speed_chunk_ratio(speed)
    if chunk_phase < 0 or chunk_phase >= q:
        raise ValueError(f"chunk_phase must be in [0, {q}), got {chunk_phase}")
    n_out = len(boundaries) - 1
    mask = np.zeros(n_out, dtype=np.int8)
    n_full = n_src // q
    tail_threshold = n_full * q  # phase=0 trailing passthrough starts here
    for j in range(n_out):
        left_t = float(boundaries[j])
        if chunk_phase == 0 and n_full > 0 and left_t >= tail_threshold:
            mask[j] = 1
            continue
        # Is left_t (within tolerance) of the form phase + k*q?
        chunk_idx_f = (left_t - chunk_phase) / q
        chunk_idx = round(chunk_idx_f)
        if abs(chunk_idx_f - chunk_idx) >= 1e-6:
            continue
        chunk_start = chunk_phase + chunk_idx * q
        if chunk_phase == 0:
            # phase=0 also accepts the start of a partial chunk (reached only
            # when the segment is shorter than q).
            valid = chunk_start < n_src
        else:
            # phase>0 requires the whole chunk to fit inside the segment.
            valid = chunk_start + q <= n_src
        if valid:
            mask[j] = 1
    return mask


def _resample_segment(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    start: int,
    end: int,
    motion_class: int,
    speed: float,
    chunk_aligned_observation: bool = False,
    chunk_phase: int = 0,
) -> dict[str, np.ndarray]:
    """Atomic chunk resampling: cumulative + linear interpolation.

    Key invariant: per output bin [left_t, right_t],
        action[:6] = cumulative(right_t) - cumulative(left_t)
    where cumulative is piecewise-linear over source indices. So summing
    output actions over a full chunk reproduces the source's integrated motion
    EXACTLY (no drift from non-integer speeds).

    state at output j is the source state at floor(left_t); at chunk-start
    boundaries (integer left_t) this is exact.
    gripper at output j is the source gripper at ceil(right_t)-1 --
    gripper-switch anchors in ``_segment_boundaries`` ensure a switch is never
    averaged inside a bin.

    Returns:
        Dict with keys ``action``, ``state``, ``source_frame_index``,
        ``source_step_index``, ``observation_mask``, ``segment_id`` (filled
        with -1; the caller assigns the real id) and ``motion_class``.

    Raises:
        ValueError: if the segment is empty, or on an invalid chunk phase /
            non-rational speed in chunk-aligned mode.
    """
    src_actions = actions[start:end]
    src_states = states[start:end]
    src_frames = source_frame_indices[start:end]
    n_src = end - start
    if n_src <= 0:
        raise ValueError("Cannot resample an empty segment")

    boundaries = _segment_boundaries(
        src_actions,
        speed,
        chunk_aligned=chunk_aligned_observation,
        chunk_phase=chunk_phase,
    )
    n_out = len(boundaries) - 1
    # cumulative[k] = sum of the first k source deltas (row 0 is all zeros).
    cumulative = np.concatenate(
        [np.zeros((1, 6), dtype=np.float32), np.cumsum(src_actions[:, :6], axis=0)],
        axis=0,
    )

    out_actions = np.zeros((n_out, 7), dtype=np.float32)
    out_states = np.zeros((n_out, states.shape[1]), dtype=np.float32)
    out_source_frames = np.zeros(n_out, dtype=np.int64)
    out_source_steps = np.zeros(n_out, dtype=np.int64)
    observation_mask = np.ones(n_out, dtype=np.int8)
    used_sources: set[int] = set()

    for j in range(n_out):
        left_t = float(boundaries[j])
        right_t = float(boundaries[j + 1])
        out_actions[j, :6] = _interp_cumulative(cumulative, right_t) - _interp_cumulative(cumulative, left_t)
        source_local = min(int(np.floor(left_t)), n_src - 1)
        grip_local = min(max(int(np.ceil(right_t) - 1), 0), n_src - 1)
        out_actions[j, 6] = src_actions[grip_local, 6]
        out_states[j] = src_states[source_local]
        out_source_frames[j] = int(src_frames[source_local])
        out_source_steps[j] = start + source_local
        # Legacy (non-chunk-aligned) slow-speed mask: when the same source frame
        # gets duplicated across consecutive bins, only the first copy is valid.
        if speed < 1.0 and source_local in used_sources:
            observation_mask[j] = 0
        used_sources.add(source_local)

    if chunk_aligned_observation:
        # Chunk-aligned mode replaces the legacy mask entirely.
        observation_mask = _chunk_aligned_mask(boundaries, n_src, speed, chunk_phase)

    return {
        "action": out_actions,
        "state": out_states,
        "source_frame_index": out_source_frames,
        "source_step_index": out_source_steps,
        "observation_mask": observation_mask,
        "segment_id": np.full(n_out, -1, dtype=np.int32),
        "motion_class": np.full(n_out, motion_class, dtype=np.int8),
    }


def transform_episode(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    speed: float,
    config: SpeedTransformConfig,
) -> tuple[dict[str, np.ndarray], dict[str, float | int]]:
    """End-to-end episode resampling: clean -> segment -> resample each segment.

    Outputs are concatenated across segments into one continuous trajectory.
    The same ``chunk_phase`` is applied to every segment (one phase per call).

    Returns:
        ``(merged, metrics)``. ``merged`` maps field names to per-output-row
        arrays (the ``_resample_segment`` fields plus ``cleaned_translation``,
        ``cleaned_rotation``, ``speed``, ``action_mask`` and ``is_padded``).
        ``metrics`` combines ``compute_replay_metrics`` (computed against the
        un-cleaned source actions) with segment and cleaning statistics.

    Raises:
        ValueError: on non-positive speed, a non-zero ``chunk_phase`` without
            chunk-aligned mode, malformed or length-mismatched inputs, or when
            no segment survives filtering.
    """
    if speed <= 0:
        raise ValueError(f"Speed must be positive, got {speed}")
    if config.chunk_phase != 0 and not config.chunk_aligned_observation:
        raise ValueError("chunk_phase is only supported when chunk_aligned_observation=True")

    actions = _as_float32_2d(actions, 7)
    states = np.asarray(states, dtype=np.float32)
    if states.ndim != 2:
        raise ValueError(f"Expected 2D states, got {states.shape}")
    source_frame_indices = np.asarray(source_frame_indices, dtype=np.int64)
    if len(actions) != len(states) or len(actions) != len(source_frame_indices):
        raise ValueError("actions, states, and source_frame_indices must have equal length")

    cleaned_actions, clean_mask = clean_near_zero_actions(actions, config.clean_transl_eps, config.clean_rot_eps)
    segments = segment_actions(cleaned_actions, config)
    if not config.keep_still_segments:
        segments = [s for s in segments if s[2] != 0]

    pieces = []
    for segment_id, (start, end, motion_class) in enumerate(segments):
        piece = _resample_segment(
            cleaned_actions,
            states,
            source_frame_indices,
            start,
            end,
            motion_class,
            speed,
            chunk_aligned_observation=config.chunk_aligned_observation,
            chunk_phase=config.chunk_phase,
        )
        piece["segment_id"][:] = segment_id
        pieces.append(piece)
    if not pieces:
        raise ValueError("Episode produced no segments")

    merged = {key: np.concatenate([piece[key] for piece in pieces], axis=0) for key in pieces[0]}
    # Map the per-source-frame cleaning flags onto the output rows.
    source_steps = merged["source_step_index"]
    merged["cleaned_translation"] = clean_mask[source_steps, 0].astype(np.int8)
    merged["cleaned_rotation"] = clean_mask[source_steps, 1].astype(np.int8)
    merged["speed"] = np.full(len(merged["action"]), float(speed), dtype=np.float32)
    merged["action_mask"] = np.ones(len(merged["action"]), dtype=np.int8)
    merged["is_padded"] = (1 - merged["observation_mask"]).astype(np.int8)

    metrics = compute_replay_metrics(actions, merged["action"], speed)
    transl_zeroed = clean_mask[:, 0]
    rot_zeroed = clean_mask[:, 1]
    any_zeroed = transl_zeroed | rot_zeroed
    both_zeroed = transl_zeroed & rot_zeroed
    n_source = int(len(actions))
    denom = max(n_source, 1)
    # Short segments => fewer full chunks => more 1:1 passthrough leftover at
    # non-integer speeds. seg_len distribution is the key diagnostic.
    seg_lens = np.asarray([end - start for (start, end, _) in segments], dtype=np.int64)
    seg_classes = np.asarray([cls for (_, _, cls) in segments], dtype=np.int64)
    metrics.update(
        {
            "segment_count": len(segments),
            "segments": len(segments),
            "source_frames": n_source,
            "output_frames": int(len(merged["action"])),
            "padded_frames": int(np.sum(merged["is_padded"])),
            "padded_ratio": float(np.mean(merged["is_padded"])),
            "cleaned_translation_frames": int(np.sum(transl_zeroed)),
            "cleaned_rotation_frames": int(np.sum(rot_zeroed)),
            "cleaned_any_frames": int(np.sum(any_zeroed)),
            "cleaned_both_frames": int(np.sum(both_zeroed)),
            "cleaned_translation_ratio": float(np.sum(transl_zeroed) / denom),
            "cleaned_rotation_ratio": float(np.sum(rot_zeroed) / denom),
            "cleaned_any_ratio": float(np.sum(any_zeroed) / denom),
            "cleaned_both_ratio": float(np.sum(both_zeroed) / denom),
            "segment_len_min": int(seg_lens.min()) if seg_lens.size else 0,
            "segment_len_max": int(seg_lens.max()) if seg_lens.size else 0,
            "segment_len_mean": float(seg_lens.mean()) if seg_lens.size else 0.0,
            "segment_len_median": float(np.median(seg_lens)) if seg_lens.size else 0.0,
            "segment_len_p10": float(np.percentile(seg_lens, 10)) if seg_lens.size else 0.0,
            "segment_len_p90": float(np.percentile(seg_lens, 90)) if seg_lens.size else 0.0,
            "motion_class_still_count": int(np.sum(seg_classes == 0)),
            "motion_class_translate_count": int(np.sum(seg_classes == 1)),
            "motion_class_rotate_count": int(np.sum(seg_classes == 2)),
            "motion_class_translate_rotate_count": int(np.sum(seg_classes == 3)),
        }
    )
    return merged, metrics


def compute_replay_metrics(
    source_actions: np.ndarray,
    replay_actions: np.ndarray,
    target_speed: float | None = None,
) -> dict[str, float | int]:
    """Sanity-check that resampling preserved integrated 6D motion.

    Reports per-axis path lengths, integrated translation/rotation L2 error,
    and gripper-switch count delta. NOT a task-success metric.

    When ``target_speed`` is None, ``target_speed`` is reported as 1.0 and
    ``speed_error`` as 0.0.

    Raises:
        ValueError: if either input is not of shape (T, 7).
    """
    source_actions = _as_float32_2d(source_actions, 7)
    replay_actions = _as_float32_2d(replay_actions, 7)

    source_motion = source_actions[:, :6].sum(axis=0)
    replay_motion = replay_actions[:, :6].sum(axis=0)
    source_steps = len(source_actions)
    replay_steps = len(replay_actions)
    actual_speed = source_steps / max(replay_steps, 1)

    translation_path_source = float(np.linalg.norm(source_actions[:, :3], axis=1).sum())
    translation_path_replay = float(np.linalg.norm(replay_actions[:, :3], axis=1).sum())
    rotation_path_source = float(np.linalg.norm(source_actions[:, 3:6], axis=1).sum())
    rotation_path_replay = float(np.linalg.norm(replay_actions[:, 3:6], axis=1).sum())
    # A gripper "switch" is a jump larger than 0.5 between consecutive rows.
    gripper_switches_source = int(np.sum(np.abs(np.diff(source_actions[:, 6])) > 0.5))
    gripper_switches_replay = int(np.sum(np.abs(np.diff(replay_actions[:, 6])) > 0.5))

    out: dict[str, float | int] = {
        "target_speed": float(target_speed) if target_speed is not None else 1.0,
        "source_steps": int(source_steps),
        "replay_steps": int(replay_steps),
        "actual_speed": float(actual_speed),
        "speed_error": float(abs(actual_speed - target_speed)) if target_speed is not None else 0.0,
        "translation_path_source": translation_path_source,
        "translation_path_replay": translation_path_replay,
        "translation_path_ratio": translation_path_replay / max(translation_path_source, 1e-12),
        "rotation_path_source": rotation_path_source,
        "rotation_path_replay": rotation_path_replay,
        "rotation_path_ratio": rotation_path_replay / max(rotation_path_source, 1e-12),
        "integrated_translation_l2_error": float(np.linalg.norm(source_motion[:3] - replay_motion[:3])),
        "integrated_rotation_l2_error": float(np.linalg.norm(source_motion[3:] - replay_motion[3:])),
        "gripper_switches_source": gripper_switches_source,
        "gripper_switches_replay": gripper_switches_replay,
        "gripper_switch_delta": int(gripper_switches_replay - gripper_switches_source),
    }
    return out