import argparse
import json
import math
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from natsort import natsorted
from PIL import Image
from pyrep.const import ConfigurationPathAlgorithms as Algos
from pyrep.errors import ConfigurationPathError
from pyrep.objects.joint import Joint
from pyrep.objects.object import Object
from pyrep.objects.shape import Shape
from rlbench.action_modes.action_mode import BimanualJointPositionActionMode
from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
from rlbench.backend.const import DEPTH_SCALE
from rlbench.backend.utils import image_to_float_array, rgb_handles_to_mask
from rlbench.bimanual_tasks.bimanual_take_tray_out_of_oven import (
    BimanualTakeTrayOutOfOven,
)
from rlbench.demo import Demo
from rlbench.environment import Environment
from rlbench.observation_config import CameraConfig, ObservationConfig
from sklearn.metrics import (
    average_precision_score,
    f1_score,
    roc_auc_score,
)

# Every camera whose segmentation masks are consulted when inferring which mask
# handles belong to the tray (iterated by _infer_tray_mask_handle_ids).
FULL_CAMERA_SET = [
    "front",
    "overhead",
    "wrist_right",
    "wrist_left",
    "over_shoulder_left",
    "over_shoulder_right",
]
# Three-camera subset of FULL_CAMERA_SET.
THREE_VIEW_SET = ["front", "wrist_right", "wrist_left"]
# X display used by _configure_runtime when DISPLAY is not already set.
DISPLAY = ":99"
# NOTE(review): presumably the per-frame timestep of the recorded demos in
# seconds — confirm against the recording setup.
DEMO_DT = 0.05
# Resolution given to the single (image-less) camera in the replay config.
DEFAULT_IMAGE_SIZE = (128, 128)
# Length scale in exp(-path_length / scale) when a planned path length is turned
# into a score.
DEFAULT_PATH_SCALE = 0.75
# Threshold ("tau") constants.
# NOTE(review): presumably cut-offs for scoring/labelling; confirm their meaning
# at the use sites.
DEFAULT_PPRE_TAU = 0.45
DEFAULT_VISIBILITY_TAU = 0.35
DEFAULT_PEXT_TAU = 0.45
DEFAULT_DOOR_SPEED_TAU = 0.08
DEFAULT_PHASE_SCORE_TAU = 0.5
DEFAULT_APPROACH_SPEED_TAU = 0.01
DEFAULT_APPROACH_PROGRESS_TAU = 0.02
DEFAULT_PREGRASP_LABEL_PROGRESS_TAU = 0.60
# The approach-onset search starts at most this many frames before the left
# gripper closes (see _detect_pregrasp_approach_onset).
DEFAULT_APPROACH_ONSET_WINDOW = 96
# How many of the latest approach templates become pre-grasp candidates.
DEFAULT_PREGRASP_CANDIDATE_COUNT = 4
# Budgets forwarded to the arm's get_path() on every planning call.
DEFAULT_PLAN_TRIALS = 48
DEFAULT_PLAN_MAX_CONFIGS = 4
DEFAULT_PLAN_MAX_TIME_MS = 10
DEFAULT_PLAN_TRIALS_PER_GOAL = 4
# Planning attempts per target, and how many must succeed for the plan to be
# considered reliable (see _stable_plan / _plan_is_reliable).
DEFAULT_PLAN_ATTEMPTS = 2
DEFAULT_PLAN_MIN_SUCCESSES = 2
# NOTE(review): persistence counts, presumably in frames; confirm at their use
# sites.
DEFAULT_READY_PERSISTENCE = 3
DEFAULT_RETRIEVE_PERSISTENCE = 3
DEFAULT_PREGRASP_PERSISTENCE = 3
# Upper bound on the number of segmentation handles attributed to the tray.
DEFAULT_MASK_HANDLE_COUNT = 2
@dataclass
class SimulatorSnapshot:
    """Simulator state captured by `_capture_snapshot` and re-applied by `_restore_snapshot`.

    Besides the task and robot configuration trees, it records which objects
    each gripper was holding. This lets the grasp attachments be re-created
    after a restore.
    """

    # Result of task._task.get_state(). Element 0 is also used directly as a
    # configuration tree when restore_state() raises RuntimeError.
    task_state: Tuple[bytes, int]
    # Opaque configuration trees from get_configuration_tree() for each arm/gripper.
    right_arm_tree: bytes
    right_gripper_tree: bytes
    left_arm_tree: bytes
    left_gripper_tree: bytes
    # Joint positions, stored as plain Python floats.
    right_arm_joints: Tuple[float, ...]
    left_arm_joints: Tuple[float, ...]
    right_gripper_joints: Tuple[float, ...]
    left_gripper_joints: Tuple[float, ...]
    # Names of the objects each gripper reported as grasped.
    right_grasped: Tuple[str, ...]
    left_grasped: Tuple[str, ...]
    # Grasped object name -> name of the parent the gripper recorded for it
    # (None when that recorded parent was None).
    right_grasped_old_parents: Dict[str, Optional[str]]
    left_grasped_old_parents: Dict[str, Optional[str]]
    # Object name -> pose for every object in the tree of each grasped object.
    grasped_subtree_poses: Dict[str, Tuple[float, ...]]


@dataclass
class ReplayState:
    """Per-frame summary of the replayed scene, as built by `ReplayCache.current_state`."""

    frame_index: int
    # get_pose() of the "tray" shape: position followed by quaternion.
    tray_pose: np.ndarray
    # Joint position of "oven_door_joint".
    door_angle: float
    right_gripper_pose: np.ndarray
    left_gripper_pose: np.ndarray
    right_gripper_open: float
    left_gripper_open: float
    # Optional full snapshot; ReplayCache.current_state always leaves it None.
    snapshot: Optional[SimulatorSnapshot] = None


@dataclass
class MotionTemplates:
    """Left-arm motion templates extracted from one reference demo by `_derive_templates`.

    Poses are 7-vectors ``[x, y, z, qx, qy, qz, qw]`` (the layout read by
    `_pose_to_matrix`). Tray-relative poses are re-anchored on a new tray
    pose with `_apply_relative_pose`.
    """

    # Left gripper pose relative to the tray at the pre-grasp frame.
    pregrasp_rel_pose: np.ndarray
    # Left gripper pose relative to the tray at the frame the left gripper closes.
    grasp_rel_pose: np.ndarray
    # Post-grasp left gripper poses relative to the task base (not the tray).
    retreat_rel_poses: List[np.ndarray]
    # Left gripper position at grasp time, expressed in the tray frame.
    grasp_local_center: np.ndarray
    # Half-widths of the box around grasp_local_center sampled by _sample_grasp_points.
    grasp_region_extents: np.ndarray
    # Oven door joint position when the right gripper first reopens.
    hold_open_angle: float
    # max(0.12, a quarter of the door travel between right close and right open).
    open_more_delta: float
    # Tray z-coordinate at the frame the left gripper closes.
    reference_tray_height: float
    # Tray-relative left gripper poses sampled along the approach, before grasping.
    approach_rel_poses: List[np.ndarray] = field(default_factory=list)
    # Segmentation-mask handle ids attributed to the tray.
    mask_handle_ids: List[int] = field(default_factory=list)

    def to_json(self) -> Dict[str, object]:
        """Return a dict of plain lists/floats/ints suitable for `json.dump`."""
        return {
            "approach_rel_poses": [pose.tolist() for pose in self.approach_rel_poses],
            "pregrasp_rel_pose": self.pregrasp_rel_pose.tolist(),
            "grasp_rel_pose": self.grasp_rel_pose.tolist(),
            "retreat_rel_poses": [pose.tolist() for pose in self.retreat_rel_poses],
            "grasp_local_center": self.grasp_local_center.tolist(),
            "grasp_region_extents": self.grasp_region_extents.tolist(),
            "hold_open_angle": float(self.hold_open_angle),
            "open_more_delta": float(self.open_more_delta),
            "reference_tray_height": float(self.reference_tray_height),
            "mask_handle_ids": [int(handle) for handle in self.mask_handle_ids],
        }

    @classmethod
    def from_json(cls, payload: Dict[str, object]) -> "MotionTemplates":
        """Rebuild templates from `to_json` output.

        "retreat_rel_poses", "approach_rel_poses" and "mask_handle_ids" are
        optional and default to empty lists. All other keys are required
        (a missing key raises KeyError).
        """
        return cls(
            pregrasp_rel_pose=np.asarray(payload["pregrasp_rel_pose"], dtype=np.float64),
            grasp_rel_pose=np.asarray(payload["grasp_rel_pose"], dtype=np.float64),
            retreat_rel_poses=[
                np.asarray(pose, dtype=np.float64)
                for pose in payload.get("retreat_rel_poses", [])
            ],
            grasp_local_center=np.asarray(payload["grasp_local_center"], dtype=np.float64),
            grasp_region_extents=np.asarray(payload["grasp_region_extents"], dtype=np.float64),
            hold_open_angle=float(payload["hold_open_angle"]),
            open_more_delta=float(payload["open_more_delta"]),
            reference_tray_height=float(payload["reference_tray_height"]),
            approach_rel_poses=[
                np.asarray(pose, dtype=np.float64)
                for pose in payload.get("approach_rel_poses", [])
            ],
            mask_handle_ids=[int(handle) for handle in payload.get("mask_handle_ids", [])],
        )


@dataclass
class EpisodeArtifacts:
    """Container for per-episode outputs.

    NOTE(review): the column schemas of `dense`/`keyframes` and the keys of
    `metrics` are defined by their producer, not here — confirm there.
    """

    episode_name: str
    dense: pd.DataFrame
    keyframes: pd.DataFrame
    metrics: Dict[str, object]
    template_frames: Dict[str, int]


def _configure_runtime() -> None:
    """Point the current process at the CoppeliaSim install in /workspace/coppelia_sim.

    DISPLAY and XDG_RUNTIME_DIR are only filled in when unset.
    COPPELIASIM_ROOT and QT_QPA_PLATFORM_PLUGIN_PATH are always overwritten.
    The install dir is appended to LD_LIBRARY_PATH unless it already occurs
    there as a substring.
    """
    os.environ["DISPLAY"] = os.environ.get("DISPLAY", DISPLAY)
    os.environ["COPPELIASIM_ROOT"] = "/workspace/coppelia_sim"
    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    if "/workspace/coppelia_sim" not in ld_library_path:
        os.environ["LD_LIBRARY_PATH"] = (
            f"{ld_library_path}:/workspace/coppelia_sim"
            if ld_library_path
            else "/workspace/coppelia_sim"
        )
    os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = "/workspace/coppelia_sim"
    os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp/runtime-root")


def _minimal_camera_config() -> Dict[str, CameraConfig]:
    """Return a single "front" camera config with rgb, depth, point cloud and mask all disabled."""
    return {
        "front": CameraConfig(
            rgb=False,
            depth=False,
            point_cloud=False,
            mask=False,
            image_size=DEFAULT_IMAGE_SIZE,
        )
    }


def _make_observation_config() -> ObservationConfig:
    """Build the bimanual observation config used for replay.

    Joint positions/velocities, gripper open state, gripper pose and gripper
    joint positions are enabled. Forces, gripper matrix, touch forces and
    task low-dim state are disabled.
    """
    return ObservationConfig(
        camera_configs=_minimal_camera_config(),
        joint_velocities=True,
        joint_positions=True,
        joint_forces=False,
        gripper_open=True,
        gripper_pose=True,
        gripper_matrix=False,
        gripper_joint_positions=True,
        gripper_touch_forces=False,
        task_low_dim_state=False,
        robot_name="bimanual",
    )


def _launch_replay_env() -> Environment:
    """Configure the runtime, then launch a headless dual-Panda environment.

    The environment uses bimanual joint-position actions. The caller owns the
    returned environment and must call `shutdown()` on it
    (`_derive_templates` does this in a `finally`).
    """
    _configure_runtime()
    env = Environment(
        action_mode=BimanualJointPositionActionMode(),
        obs_config=_make_observation_config(),
        headless=True,
        robot_setup="dual_panda",
    )
    env.launch()
    return env


def _load_demo(episode_dir: Path) -> Demo:
    """Unpickle the episode's `low_dim_obs.pkl`."""
    # NOTE: pickle.load can execute arbitrary code; only load trusted datasets.
    with episode_dir.joinpath("low_dim_obs.pkl").open("rb") as handle:
        return pickle.load(handle)


def _load_descriptions(episode_dir: Path) -> List[str]:
    """Unpickle the episode's `variation_descriptions.pkl`."""
    # NOTE: pickle.load can execute arbitrary code; only load trusted datasets.
    with episode_dir.joinpath("variation_descriptions.pkl").open("rb") as handle:
        return pickle.load(handle)


def _episode_dirs(dataset_root: Path) -> List[Path]:
    """List `<dataset_root>/all_variations/episodes/*` that contain `low_dim_obs.pkl`, in natural order."""
    episodes_dir = dataset_root.joinpath("all_variations", "episodes")
    return [
        episodes_dir.joinpath(name)
        for name in natsorted(os.listdir(episodes_dir))
        if episodes_dir.joinpath(name, "low_dim_obs.pkl").exists()
    ]


def _camera_file(episode_dir: Path, camera_name: str, kind: str, frame_index: int) -> Path:
    """Return the image path for one camera/frame.

    ``kind`` must be "rgb", "depth" or "mask"; any other value raises ValueError.
    """
    if kind == "rgb":
        return episode_dir.joinpath(f"{camera_name}_rgb", f"rgb_{frame_index:04d}.png")
    if kind == "depth":
        return episode_dir.joinpath(
            f"{camera_name}_depth", f"depth_{frame_index:04d}.png"
        )
    if kind == "mask":
        return episode_dir.joinpath(f"{camera_name}_mask", f"mask_{frame_index:04d}.png")
    raise ValueError(f"unknown kind: {kind}")


def _load_depth_meters(episode_dir: Path, demo: Demo, frame_index: int, camera_name: str) -> np.ndarray:
    """Decode a stored depth PNG and map it linearly onto the camera's [near, far] range.

    The near/far values come from the demo observation's misc dict for
    that frame and camera.
    """
    image = Image.open(_camera_file(episode_dir, camera_name, "depth", frame_index))
    depth = image_to_float_array(image, DEPTH_SCALE)
    near = demo[frame_index].misc[f"{camera_name}_camera_near"]
    far = demo[frame_index].misc[f"{camera_name}_camera_far"]
    return near + depth * (far - near)


def _load_mask(episode_dir: Path, frame_index: int, camera_name: str) -> np.ndarray:
    """Load a stored segmentation PNG and decode it to per-pixel handle ids.

    Greyscale images are broadcast to three channels. Values above 1.0 are
    rescaled from [0, 255] to [0, 1] before `rgb_handles_to_mask`.
    Callers index the result as ``mask[row, col]``.
    """
    image = np.asarray(
        Image.open(_camera_file(episode_dir, camera_name, "mask", frame_index)),
        dtype=np.float32,
    )
    if image.ndim == 2:
        image = np.repeat(image[..., None], 3, axis=2)
    if image.max() > 1.0:
        image /= 255.0
    return rgb_handles_to_mask(image.copy())


def _capture_snapshot(task) -> SimulatorSnapshot:
    """Record task state, robot configuration and grasp attachments into a `SimulatorSnapshot`.

    This reads the grippers' private ``_old_parents`` lists.
    NOTE(review): it assumes ``_old_parents`` is index-aligned with
    ``get_grasped_objects()`` — confirm against the gripper implementation.
    """
    robot = task._scene.robot
    right_grasped = tuple(robot.right_gripper.get_grasped_objects())
    left_grasped = tuple(robot.left_gripper.get_grasped_objects())
    # Store poses for whole subtrees so child objects of a held item can be put back too.
    grasped_subtree_poses: Dict[str, Tuple[float, ...]] = {}
    for grasped_object in right_grasped + left_grasped:
        for subtree_object in grasped_object.get_objects_in_tree(exclude_base=False):
            grasped_subtree_poses[subtree_object.get_name()] = tuple(
                float(value) for value in subtree_object.get_pose()
            )
    return SimulatorSnapshot(
        task_state=task._task.get_state(),
        right_arm_tree=robot.right_arm.get_configuration_tree(),
        right_gripper_tree=robot.right_gripper.get_configuration_tree(),
        left_arm_tree=robot.left_arm.get_configuration_tree(),
        left_gripper_tree=robot.left_gripper.get_configuration_tree(),
        right_arm_joints=tuple(float(value) for value in robot.right_arm.get_joint_positions()),
        left_arm_joints=tuple(float(value) for value in robot.left_arm.get_joint_positions()),
        right_gripper_joints=tuple(float(value) for value in robot.right_gripper.get_joint_positions()),
        left_gripper_joints=tuple(float(value) for value in robot.left_gripper.get_joint_positions()),
        right_grasped=tuple(obj.get_name() for obj in right_grasped),
        left_grasped=tuple(obj.get_name() for obj in left_grasped),
        right_grasped_old_parents={
            obj.get_name(): (
                None if old_parent is None else old_parent.get_name()
            )
            for obj, old_parent in zip(right_grasped, robot.right_gripper._old_parents)
        },
        left_grasped_old_parents={
            obj.get_name(): (
                None if old_parent is None else old_parent.get_name()
            )
            for obj, old_parent in zip(left_grasped, robot.left_gripper._old_parents)
        },
        grasped_subtree_poses=grasped_subtree_poses,
    )


def _restore_snapshot(task, snapshot: SimulatorSnapshot) -> None:
    """Put the simulator back into the state recorded by `_capture_snapshot`.

    Steps:
    1. Restore the task state, falling back to its configuration tree if
       `restore_state` raises RuntimeError.
    2. Restore the arm/gripper trees and set joint positions and targets.
    3. If the snapshot held objects, reset their subtree poses.
    4. Step once, then re-attach the held objects to their grippers, and
       step again.
    """
    robot = task._scene.robot
    snapshot_has_grasp = bool(snapshot.right_grasped or snapshot.left_grasped)
    if not snapshot_has_grasp:
        # Nothing should be held in the target state: drop current attachments first.
        robot.release_gripper()
    try:
        task._task.restore_state(snapshot.task_state)
    except RuntimeError:
        task._pyrep.set_configuration_tree(snapshot.task_state[0])
    task._pyrep.set_configuration_tree(snapshot.right_arm_tree)
    task._pyrep.set_configuration_tree(snapshot.right_gripper_tree)
    task._pyrep.set_configuration_tree(snapshot.left_arm_tree)
    task._pyrep.set_configuration_tree(snapshot.left_gripper_tree)
    # Teleport joints, and also set matching targets so the controllers hold the pose.
    robot.right_arm.set_joint_positions(list(snapshot.right_arm_joints), disable_dynamics=True)
    robot.left_arm.set_joint_positions(list(snapshot.left_arm_joints), disable_dynamics=True)
    robot.right_arm.set_joint_target_positions(list(snapshot.right_arm_joints))
    robot.left_arm.set_joint_target_positions(list(snapshot.left_arm_joints))
    robot.right_gripper.set_joint_positions(
        list(snapshot.right_gripper_joints), disable_dynamics=True
    )
    robot.left_gripper.set_joint_positions(
        list(snapshot.left_gripper_joints), disable_dynamics=True
    )
    robot.right_gripper.set_joint_target_positions(list(snapshot.right_gripper_joints))
    robot.left_gripper.set_joint_target_positions(list(snapshot.left_gripper_joints))
    if snapshot_has_grasp:
        # Detach everything, then move the held objects (and their children) back
        # to their recorded world poses before re-attaching below.
        robot.release_gripper()
        for name, pose in snapshot.grasped_subtree_poses.items():
            Object.get_object(name).set_pose(np.asarray(pose, dtype=np.float64))
    task._pyrep.step()
    if not snapshot_has_grasp:
        robot.release_gripper()
    for name in snapshot.right_grasped:
        _force_attach_grasped_object(
            robot.right_gripper,
            Shape(name),
            snapshot.right_grasped_old_parents.get(name),
        )
    for name in snapshot.left_grasped:
        _force_attach_grasped_object(
            robot.left_gripper,
            Shape(name),
            snapshot.left_grasped_old_parents.get(name),
        )
    task._pyrep.step()


def _force_attach_grasped_object(
    gripper, obj: Shape, old_parent_name: Optional[str]
) -> None:
    """Register ``obj`` as grasped by ``gripper`` without running the gripper's own grasp check.

    This mutates the gripper's private ``_grasped_objects`` / ``_old_parents``
    lists. It re-parents ``obj`` to ``_attach_point`` while keeping its world
    pose. It does nothing if an object with the same name is already grasped.
    When ``old_parent_name`` is None, the object's current parent is recorded.
    """
    if any(grasped.get_name() == obj.get_name() for grasped in gripper.get_grasped_objects()):
        return
    gripper._grasped_objects.append(obj)
    old_parent = obj.get_parent() if old_parent_name is None else Object.get_object(old_parent_name)
    gripper._old_parents.append(old_parent)
    obj.set_parent(gripper._attach_point, keep_in_place=True)


def _build_joint_action(target_obs) -> np.ndarray:
    """Build the 16-element bimanual action for ``target_obs``.

    Layout: ``[right 7 joints, right gripper_open, left 7 joints,
    left gripper_open]``. Each arm prefers
    ``misc["<side>_executed_demo_joint_position_action"]`` and falls back to
    the observed joint positions.

    Raises:
        ValueError: if neither source is a 1-D vector of length 7.
    """

    def _joint_vector(value, fallback) -> np.ndarray:
        """Return ``value`` as a 7-vector, or ``fallback`` when ``value`` is None or mis-shaped."""
        array = np.asarray(fallback if value is None else value, dtype=np.float64)
        if array.ndim != 1 or array.shape[0] != 7:
            array = np.asarray(fallback, dtype=np.float64)
        if array.ndim != 1 or array.shape[0] != 7:
            raise ValueError(f"invalid joint vector shape: {array.shape}")
        return array

    right = _joint_vector(
        target_obs.misc.get("right_executed_demo_joint_position_action"),
        target_obs.right.joint_positions,
    )
    left = _joint_vector(
        target_obs.misc.get("left_executed_demo_joint_position_action"),
        target_obs.left.joint_positions,
    )
    return np.concatenate(
        [
            right,
            np.array([target_obs.right.gripper_open], dtype=np.float64),
            left,
            np.array([target_obs.left.gripper_open], dtype=np.float64),
        ]
    )


class ReplayCache:
    """Replay a recorded demo in the simulator one frame at a time.

    A snapshot is stored every ``checkpoint_stride`` frames. Seeking backwards
    restores the nearest earlier checkpoint and replays forward from there.
    Call `reset()` before seeking, so that checkpoint 0 exists.
    """

    def __init__(self, task, demo: Demo, checkpoint_stride: int = 16):
        self.task = task
        self.demo = demo
        self.checkpoint_stride = checkpoint_stride
        self.current_index = 0
        self.current_obs = None
        self.checkpoints: Dict[int, SimulatorSnapshot] = {}
        self.discrete_gripper = BimanualDiscrete()

    def reset(self) -> None:
        """Reset the task to the demo's initial state and keep only checkpoint 0."""
        _, self.current_obs = self.task.reset_to_demo(self.demo)
        self.current_index = 0
        self.checkpoints = {0: _capture_snapshot(self.task)}

    def step_to(self, target_index: int):
        """Advance (or rewind then advance) the simulation to ``target_index`` and return its observation.

        Each forward step executes the recorded joint action for the next
        frame, then replays that frame's gripper state.
        """
        if target_index < self.current_index:
            # Rewind to the latest checkpoint at or before the target.
            checkpoint_index = max(i for i in self.checkpoints if i <= target_index)
            _restore_snapshot(self.task, self.checkpoints[checkpoint_index])
            self.current_index = checkpoint_index
            self.current_obs = self._observation_from_scene()
        while self.current_index < target_index:
            next_index = self.current_index + 1
            self.task._action_mode.action(
                self.task._scene, _build_joint_action(self.demo[next_index])
            )
            self._apply_gripper_replay(self.demo[next_index])
            self.current_obs = self._observation_from_scene()
            self.current_index = next_index
            if self.current_index % self.checkpoint_stride == 0:
                self.checkpoints[self.current_index] = _capture_snapshot(self.task)
        return self.current_obs

    def current_state(self) -> ReplayState:
        """Summarise the current frame.

        Tray pose and door angle are read live from the scene; gripper values
        come from the cached observation.
        """
        return ReplayState(
            frame_index=self.current_index,
            tray_pose=Shape("tray").get_pose(),
            door_angle=Joint("oven_door_joint").get_joint_position(),
            right_gripper_pose=self.current_obs.right.gripper_pose.copy(),
            left_gripper_pose=self.current_obs.left.gripper_pose.copy(),
            right_gripper_open=float(self.current_obs.right.gripper_open),
            left_gripper_open=float(self.current_obs.left.gripper_open),
            snapshot=None,
        )

    def snapshot(self) -> SimulatorSnapshot:
        """Capture the current simulator state."""
        return _capture_snapshot(self.task)

    def restore(self, snapshot: SimulatorSnapshot) -> None:
        """Restore ``snapshot`` and refresh the cached observation (frame index is left unchanged)."""
        _restore_snapshot(self.task, snapshot)
        self.current_obs = self._observation_from_scene()

    def restore_to_index(self, snapshot: SimulatorSnapshot, frame_index: int) -> None:
        """Restore ``snapshot`` and declare it to be ``frame_index``.

        The pairing is not validated; the caller must pass a matching pair.
        """
        self.restore(snapshot)
        self.current_index = frame_index

    def _observation_from_scene(self):
        """Return a fresh observation from the scene."""
        return self.task._scene.get_observation()

    def _apply_gripper_replay(self, target_obs) -> None:
        """Drive both grippers toward the open/closed state recorded in ``target_obs``.

        A gripper counts as open when every open-amount reading exceeds 0.9.
        The discrete gripper action is issued only when the current state
        differs from the desired one. Grasp attachments are then
        re-synchronised.
        """
        desired = np.array(
            [target_obs.right.gripper_open, target_obs.left.gripper_open],
            dtype=np.float64,
        )
        current = np.array(
            [
                float(all(x > 0.9 for x in self.task._scene.robot.right_gripper.get_open_amount())),
                float(all(x > 0.9 for x in self.task._scene.robot.left_gripper.get_open_amount())),
            ],
            dtype=np.float64,
        )
        if not np.allclose(current, desired):
            self.discrete_gripper.action(self.task._scene, desired)
            self.current_obs = self._observation_from_scene()
        self._maintain_grasp_state(desired)
        self.current_obs = self._observation_from_scene()

    def _maintain_grasp_state(self, desired: np.ndarray) -> None:
        """Keep grasp attachments consistent with ``desired`` = ``[right_open, left_open]``.

        For a closed gripper (value <= 0.5), try to grasp every graspable object
        not already held by the other gripper. For an open gripper, release
        anything it still holds.
        """
        scene = self.task._scene
        robot = scene.robot
        if float(desired[0]) <= 0.5:
            left_grasped = {obj.get_name() for obj in robot.left_gripper.get_grasped_objects()}
            for graspable in scene.task.get_graspable_objects():
                if graspable.get_name() not in left_grasped:
                    robot.right_gripper.grasp(graspable)
        elif robot.right_gripper.get_grasped_objects():
            robot.right_gripper.release()
        if float(desired[1]) <= 0.5:
            right_grasped = {obj.get_name() for obj in robot.right_gripper.get_grasped_objects()}
            for graspable in scene.task.get_graspable_objects():
                if graspable.get_name() not in right_grasped:
                    robot.left_gripper.grasp(graspable)
        elif robot.left_gripper.get_grasped_objects():
            robot.left_gripper.release()


def _quat_to_matrix(quat: Sequence[float]) -> np.ndarray:
    """Convert an ``[x, y, z, w]`` quaternion to a 3x3 rotation matrix.

    The input is assumed to be unit length; it is not normalised here.
    """
    x, y, z, w = quat
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z
    return np.array(
        [
            [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
            [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
            [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
        ],
        dtype=np.float64,
    )


def _matrix_to_quat(matrix: np.ndarray) -> np.ndarray:
    """Convert a 3x3 rotation matrix to a normalised ``[x, y, z, w]`` quaternion.

    The branch is chosen on the trace / largest diagonal element to avoid
    dividing by a small number. The sign is not canonicalised, so ``w`` may be
    negative.
    """
    trace = np.trace(matrix)
    if trace > 0.0:
        s = math.sqrt(trace + 1.0) * 2.0
        w = 0.25 * s
        x = (matrix[2, 1] - matrix[1, 2]) / s
        y = (matrix[0, 2] - matrix[2, 0]) / s
        z = (matrix[1, 0] - matrix[0, 1]) / s
    elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
        s = math.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2.0
        w = (matrix[2, 1] - matrix[1, 2]) / s
        x = 0.25 * s
        y = (matrix[0, 1] + matrix[1, 0]) / s
        z = (matrix[0, 2] + matrix[2, 0]) / s
    elif matrix[1, 1] > matrix[2, 2]:
        s = math.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2.0
        w = (matrix[0, 2] - matrix[2, 0]) / s
        x = (matrix[0, 1] + matrix[1, 0]) / s
        y = 0.25 * s
        z = (matrix[1, 2] + matrix[2, 1]) / s
    else:
        s = math.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2.0
        w = (matrix[1, 0] - matrix[0, 1]) / s
        x = (matrix[0, 2] + matrix[2, 0]) / s
        y = (matrix[1, 2] + matrix[2, 1]) / s
        z = 0.25 * s
    quat = np.array([x, y, z, w], dtype=np.float64)
    return quat / np.linalg.norm(quat)


def _pose_to_matrix(pose: Sequence[float]) -> np.ndarray:
    """Convert a 7-vector pose ``[x, y, z, qx, qy, qz, qw]`` into a 4x4 homogeneous transform."""
    matrix = np.eye(4, dtype=np.float64)
    matrix[:3, :3] = _quat_to_matrix(pose[3:])
    matrix[:3, 3] = pose[:3]
    return matrix


def _matrix_to_pose(matrix: np.ndarray) -> np.ndarray:
    """Convert a 4x4 homogeneous transform into a 7-vector pose (inverse of `_pose_to_matrix`)."""
    return np.concatenate([matrix[:3, 3], _matrix_to_quat(matrix[:3, :3])], axis=0)


def _relative_pose(reference_pose: Sequence[float], target_pose: Sequence[float]) -> np.ndarray:
    """Express ``target_pose`` in the frame of ``reference_pose`` (``inv(ref) @ target``)."""
    reference = _pose_to_matrix(reference_pose)
    target = _pose_to_matrix(target_pose)
    rel = np.linalg.inv(reference) @ target
    return _matrix_to_pose(rel)


def _apply_relative_pose(reference_pose: Sequence[float], relative_pose: Sequence[float]) -> np.ndarray:
    """Map a pose from the ``reference_pose`` frame to world (``ref @ rel``); inverse of `_relative_pose`."""
    reference = _pose_to_matrix(reference_pose)
    relative = _pose_to_matrix(relative_pose)
    world = reference @ relative
    return _matrix_to_pose(world)


def _world_to_local(reference_pose: Sequence[float], point_world: Sequence[float]) -> np.ndarray:
    """Transform a 3-D world point into the frame of ``reference_pose``."""
    reference = _pose_to_matrix(reference_pose)
    point = np.concatenate([np.asarray(point_world, dtype=np.float64), [1.0]])
    return (np.linalg.inv(reference) @ point)[:3]


def _local_to_world(reference_pose: Sequence[float], point_local: Sequence[float]) -> np.ndarray:
    """Transform a 3-D point from the frame of ``reference_pose`` into world coordinates."""
    reference = _pose_to_matrix(reference_pose)
    point = np.concatenate([np.asarray(point_local, dtype=np.float64), [1.0]])
    return (reference @ point)[:3]


def _first_transition(demo: Demo, side: str, open_to_closed: bool) -> int:
    """Return the first frame index where ``side``'s gripper_open crosses 0.5 in the requested direction.

    Both comparisons are strict, so a value of exactly 0.5 never counts as a
    crossing.

    Raises:
        RuntimeError: if no such transition exists.
    """
    values = [getattr(demo[i], side).gripper_open for i in range(len(demo))]
    for i in range(1, len(values)):
        if open_to_closed and values[i - 1] > 0.5 and values[i] < 0.5:
            return i
        if not open_to_closed and values[i - 1] < 0.5 and values[i] > 0.5:
            return i
    raise RuntimeError(f"no gripper transition found for {side}")


def _detect_pregrasp_approach_onset(
    cache: ReplayCache,
    left_close: int,
    pregrasp_rel_pose: np.ndarray,
) -> int:
    """Estimate the frame at which the left gripper starts its final approach to the pre-grasp pose.

    Frames ``0 .. left_close - 1`` are replayed through ``cache``, which
    advances the simulator. For each frame, the distance from the left
    gripper position to the tray-relative pre-grasp target is recorded.

    The search covers frames from ``search_start`` up to ``left_close - 6``.
    It returns the first frame whose next 6 distances drop by at least 0.01
    overall and have a mean per-frame change below -0.002.

    Fallbacks:
    - fewer than 8 samples: returns ``max(0, left_close - 24)``;
    - no match found: returns ``max(search_start, left_close - 24)``.
    """
    distances: List[float] = []
    for frame_index in range(left_close):
        cache.step_to(frame_index)
        tray_pose = Shape("tray").get_pose()
        target_pose = _apply_relative_pose(tray_pose, pregrasp_rel_pose)
        current_pose = np.asarray(cache.current_obs.left.gripper_pose, dtype=np.float64)
        distances.append(float(np.linalg.norm(current_pose[:3] - target_pose[:3])))
    distances_arr = np.asarray(distances, dtype=np.float64)
    if len(distances_arr) < 8:
        return max(0, left_close - 24)
    search_start = max(8, left_close - DEFAULT_APPROACH_ONSET_WINDOW)
    for frame_index in range(search_start, max(search_start, left_close - 6)):
        window = distances_arr[frame_index : frame_index + 6]
        if len(window) < 6:
            break
        short_drop = float(window[0] - window[-1])
        if short_drop < 0.01:
            continue
        if np.mean(np.diff(window)) < -0.002:
            return frame_index
    return max(search_start, left_close - 24)


def _derive_templates(dataset_root: Path, template_episode_dir: Path) -> Tuple[MotionTemplates, Dict[str, object]]:
    """Replay one reference episode and extract `MotionTemplates` from it.

    A dedicated environment is launched and always shut down, even on error.

    Returns:
        ``(templates, template_frames)``. ``template_frames`` maps
        "pregrasp", "grasp", "right_close" and "right_open" to frame indices,
        and "approach" and "retreat" to sorted lists of frame indices.

    Note:
        ``dataset_root`` is not used in this body.
    """
    env = _launch_replay_env()
    try:
        demo = _load_demo(template_episode_dir)
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=8)
        cache.reset()
        base_pose = task._task.get_base().get_pose()
        left_close = _first_transition(demo, "left", open_to_closed=True)
        left_open = _first_transition(demo, "left", open_to_closed=False)
        # Pre-grasp is taken 5 frames before the left gripper closes.
        pregrasp_index = max(0, left_close - 5)
        right_close = _first_transition(demo, "right", open_to_closed=True)
        right_open = _first_transition(demo, "right", open_to_closed=False)
        bootstrap_indices = sorted({pregrasp_index, left_close, right_close, right_open, left_open})
        states: Dict[int, ReplayState] = {}
        for frame_index in bootstrap_indices:
            cache.step_to(frame_index)
            states[frame_index] = cache.current_state()
        pregrasp_rel_pose = _relative_pose(
            states[pregrasp_index].tray_pose, states[pregrasp_index].left_gripper_pose
        )
        grasp_rel_pose = _relative_pose(
            states[left_close].tray_pose, states[left_close].left_gripper_pose
        )
        approach_onset = _detect_pregrasp_approach_onset(cache, left_close, pregrasp_rel_pose)
        # Up to 6 evenly spaced frames from onset to pre-grasp, plus 2 frames before closing.
        approach_indices = sorted(
            {
                *np.linspace(
                    approach_onset,
                    max(approach_onset, pregrasp_index),
                    num=min(6, max(2, pregrasp_index - approach_onset + 1)),
                    dtype=int,
                ).tolist(),
                pregrasp_index,
                max(0, left_close - 2),
            }
        )
        # Fixed offsets after the grasp (clipped to the demo length), plus frames
        # leading up to the left gripper reopening.
        retreat_indices = sorted(
            {
                *[
                    min(len(demo) - 1, max(left_close + 1, left_close + offset))
                    for offset in (5, 10, 15, 20, 30, 40, 50)
                ],
                max(left_close + 1, left_open - 20),
                max(left_close + 1, left_open - 10),
                max(left_close + 1, left_open - 5),
                left_open,
            }
        )
        interesting = sorted(
            {
                *bootstrap_indices,
                *approach_indices,
                *retreat_indices,
            }
        )
        # Visit frames in ascending order so the cache mostly steps forward.
        for frame_index in interesting:
            if frame_index not in states:
                cache.step_to(frame_index)
                states[frame_index] = cache.current_state()
        approach_rel_poses = [
            _relative_pose(states[index].tray_pose, states[index].left_gripper_pose)
            for index in approach_indices
            if index < left_close
        ]
        # Retreat poses are anchored on the task base rather than the (moving) tray.
        retreat_rel_poses = [
            _relative_pose(base_pose, states[index].left_gripper_pose)
            for index in retreat_indices
            if index > left_close
        ]
        grasp_local_center = _world_to_local(
            states[left_close].tray_pose, states[left_close].left_gripper_pose[:3]
        )
        templates = MotionTemplates(
            pregrasp_rel_pose=pregrasp_rel_pose,
            grasp_rel_pose=grasp_rel_pose,
            retreat_rel_poses=retreat_rel_poses,
            grasp_local_center=grasp_local_center,
            grasp_region_extents=np.array([0.03, 0.015, 0.004], dtype=np.float64),
            hold_open_angle=float(states[right_open].door_angle),
            open_more_delta=max(
                0.12,
                float(states[right_open].door_angle - states[right_close].door_angle) * 0.25,
            ),
            reference_tray_height=float(states[left_close].tray_pose[2]),
            approach_rel_poses=approach_rel_poses,
        )
        templates.mask_handle_ids = _infer_tray_mask_handle_ids(
            episode_dir=template_episode_dir,
            demo=demo,
            cache=cache,
            templates=templates,
            reference_frames=approach_indices[-4:] + [left_close],
        )
        template_frames = {
            "pregrasp": pregrasp_index,
            "grasp": left_close,
            "right_close": right_close,
            "right_open": right_open,
            "approach": approach_indices,
            "retreat": retreat_indices,
        }
        return templates, template_frames
    finally:
        env.shutdown()


def _camera_projection(extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Build the 3x4 projection and world-to-camera matrices.

    ``extrinsics`` is treated as a camera-to-world transform: its rotation
    block and camera position are inverted as ``[R^T | -R^T c]``.
    """
    camera_pos = extrinsics[:3, 3:4]
    rotation = extrinsics[:3, :3]
    world_to_camera = np.concatenate([rotation.T, -(rotation.T @ camera_pos)], axis=1)
    projection = intrinsics @ world_to_camera
    return projection, world_to_camera


def _project_points(points_world: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Project (N, 3) world points to pixel ``uv`` and also return their camera-frame xyz.

    No guard is applied for points at or behind the camera (z <= 0); callers
    filter on the returned camera z.
    """
    projection, world_to_camera = _camera_projection(extrinsics, intrinsics)
    homogeneous = np.concatenate([points_world, np.ones((len(points_world), 1))], axis=1)
    camera_xyz = (world_to_camera @ homogeneous.T).T
    image_xyz = (projection @ homogeneous.T).T
    uv = image_xyz[:, :2] / image_xyz[:, 2:3]
    return uv, camera_xyz


def _sample_grasp_points(templates: MotionTemplates, tray_pose: np.ndarray) -> np.ndarray:
    """Sample a 9x5x3 grid (135 points) over the tray-frame grasp box and return it in world coordinates."""
    center = templates.grasp_local_center
    extents = templates.grasp_region_extents
    xs = np.linspace(center[0] - extents[0], center[0] + extents[0], 9)
    ys = np.linspace(center[1] - extents[1], center[1] + extents[1], 5)
    zs = np.linspace(center[2] - extents[2], center[2] + extents[2], 3)
    points_local = np.array([[x, y, z] for x in xs for y in ys for z in zs], dtype=np.float64)
    return np.array([_local_to_world(tray_pose, point) for point in points_local], dtype=np.float64)


def _sample_full_tray_points(tray_pose: np.ndarray) -> np.ndarray:
    """Sample a 10x12x3 grid (360 points) over the tray's bounding box and return it in world coordinates.

    The bounding box ``[min_x, max_x, min_y, max_y, min_z, max_z]`` is read
    from the live "tray" shape. It is mapped to world via ``tray_pose``,
    which assumes the box is in the tray's own frame (PyRep convention).
    """
    tray = Shape("tray")
    bbox = np.asarray(tray.get_bounding_box(), dtype=np.float64)
    xs = np.linspace(bbox[0], bbox[1], 10)
    ys = np.linspace(bbox[2], bbox[3], 12)
    zs = np.linspace(bbox[4], bbox[5], 3)
    points_local = np.array([[x, y, z] for x in xs for y in ys for z in zs], dtype=np.float64)
    return np.array([_local_to_world(tray_pose, point) for point in points_local], dtype=np.float64)


def _infer_tray_mask_handle_ids(
    episode_dir: Path,
    demo: Demo,
    cache: ReplayCache,
    templates: MotionTemplates,
    reference_frames: Sequence[int],
    max_handles: int = DEFAULT_MASK_HANDLE_COUNT,
) -> List[int]:
    """Vote on which segmentation handles cover the tray's grasp region.

    For each unique reference frame, ``cache`` is stepped to that frame
    (side effect). The grasp-region sample points are then projected into
    every camera in FULL_CAMERA_SET, and the non-zero mask handle under each
    in-view point gets one vote.

    Selection rule:
    - Keep handles with at least ``max(4, ceil(top_count / 2))`` votes, at
      most ``max_handles`` of them.
    - If none qualify, return only the top handle.
    - If there are no votes at all, return ``[]``.
    """
    counts: Dict[int, int] = {}
    unique_frames = sorted({int(frame_index) for frame_index in reference_frames})
    for frame_index in unique_frames:
        cache.step_to(frame_index)
        grasp_points = _sample_grasp_points(templates, Shape("tray").get_pose())
        for camera_name in FULL_CAMERA_SET:
            mask = _load_mask(episode_dir, frame_index, camera_name)
            extrinsics = demo[frame_index].misc[f"{camera_name}_camera_extrinsics"]
            intrinsics = demo[frame_index].misc[f"{camera_name}_camera_intrinsics"]
            uv, camera_xyz = _project_points(grasp_points, extrinsics, intrinsics)
            height, width = mask.shape
            for (u, v), (_, _, camera_depth) in zip(uv, camera_xyz):
                # Skip points behind the camera or outside the image.
                if camera_depth <= 0 or not (0 <= u < width and 0 <= v < height):
                    continue
                px = min(max(int(round(float(u))), 0), width - 1)
                py = min(max(int(round(float(v))), 0), height - 1)
                handle = int(mask[py, px])
                if handle == 0:
                    continue
                counts[handle] = counts.get(handle, 0) + 1
    if not counts:
        return []
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    top_count = ranked[0][1]
    selected = [
        handle
        for handle, count in ranked
        if count >= max(4, int(math.ceil(top_count * 0.5)))
    ][:max_handles]
    return selected if selected else [ranked[0][0]]


def _mask_visibility_ratio(
    points_world: np.ndarray,
    mask: np.ndarray,
    handle_ids: Sequence[int],
    extrinsics: np.ndarray,
    intrinsics: np.ndarray,
) -> float:
    """Return the fraction of in-view projected points whose mask pixel is one of ``handle_ids``.

    Only points in front of the camera and inside the image are counted.
    Returns 0.0 when ``handle_ids`` is empty or no point is in view.
    """
    if not handle_ids:
        return 0.0
    handle_set = {int(handle) for handle in handle_ids}
    uv, camera_xyz = _project_points(points_world, extrinsics, intrinsics)
    height, width = mask.shape
    visible = 0
    total = 0
    for (u, v), (_, _, camera_depth) in zip(uv, camera_xyz):
        if camera_depth <= 0:
            continue
        if not (0 <= u < width and 0 <= v < height):
            continue
        total += 1
        px = int(round(u))
        py = int(round(v))
        px = min(max(px, 0), width - 1)
        py = min(max(py, 0), height - 1)
        if int(mask[py, px]) in handle_set:
            visible += 1
    return float(visible / total) if total else 0.0


def _union_visibility(values: Iterable[float]) -> float:
    """Combine per-view visibilities as a noisy-OR: ``1 - prod(1 - v)`` (0.0 for no values)."""
    product = 1.0
    for value in values:
        product *= 1.0 - float(value)
    return 1.0 - product


def _keypoint_discovery(demo: Demo, stopping_delta: float = 0.1) -> List[int]:
    """Heuristically detect keyframe indices in a bimanual demo.

    Frame ``i > 0`` is a keypoint when any of these holds:
    - either gripper's open state differs from the previous frame;
    - it is the last frame;
    - both arms are "stopped": joint velocities are within ``stopping_delta``
      of zero and the gripper state is constant over frames ``i-2 .. i+1``.

    After a stop is detected, stop detection is suppressed for the next four
    frames. If the final two keypoints are adjacent, the earlier one is
    dropped.
    """
    keypoints: List[int] = []
    right_prev = demo[0].right.gripper_open
    left_prev = demo[0].left.gripper_open
    stopped_buffer = 0
    for i, obs in enumerate(demo._observations):
        # The stop test needs neighbours at i-2 and i+1.
        if i < 2 or i >= len(demo) - 1:
            right_stopped = left_stopped = False
        else:
            right_stopped = (
                np.allclose(obs.right.joint_velocities, 0, atol=stopping_delta)
                and obs.right.gripper_open == demo[i + 1].right.gripper_open
                and obs.right.gripper_open == demo[i - 1].right.gripper_open
                and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open
            )
            left_stopped = (
                np.allclose(obs.left.joint_velocities, 0, atol=stopping_delta)
                and obs.left.gripper_open == demo[i + 1].left.gripper_open
                and obs.left.gripper_open == demo[i - 1].left.gripper_open
                and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open
            )
        stopped = stopped_buffer <= 0 and right_stopped and left_stopped
        stopped_buffer = 4 if stopped else stopped_buffer - 1
        last = i == len(demo) - 1
        state_changed = (
            obs.right.gripper_open != right_prev or obs.left.gripper_open != left_prev
        )
        if i != 0 and (state_changed or last or stopped):
            keypoints.append(i)
        right_prev = obs.right.gripper_open
        left_prev = obs.left.gripper_open
    if len(keypoints) > 1 and (keypoints[-1] - 1) == keypoints[-2]:
        keypoints.pop(-2)
    return keypoints


def _plan_path(scene, arm_name: str, pose: np.ndarray, ignore_collisions: bool = False):
    """Run one RRTConnect planning call for the "left" arm (any other name selects the right arm).

    ``pose[:3]`` is the target position and ``pose[3:]`` the quaternion.
    Any exception from the planner is treated as "no path" and yields None;
    this is a deliberate best-effort.
    """
    arm = scene.robot.left_arm if arm_name == "left" else scene.robot.right_arm
    try:
        return arm.get_path(
            pose[:3],
            quaternion=pose[3:],
            ignore_collisions=ignore_collisions,
            trials=DEFAULT_PLAN_TRIALS,
            max_configs=DEFAULT_PLAN_MAX_CONFIGS,
            max_time_ms=DEFAULT_PLAN_MAX_TIME_MS,
            trials_per_goal=DEFAULT_PLAN_TRIALS_PER_GOAL,
            algorithm=Algos.RRTConnect,
        )
    except Exception:
        return None


def _stable_plan(
    scene,
    arm_name: str,
    pose: np.ndarray,
    ignore_collisions: bool = False,
) -> Tuple[Optional[object], float, float]:
    """Plan to ``pose`` several times (DEFAULT_PLAN_ATTEMPTS) to smooth out planner randomness.

    Returns:
        ``(shortest_path, median_length, success_fraction)`` over the
        successful attempts, or ``(None, inf, 0.0)`` if every attempt failed.
    """
    attempts = max(1, DEFAULT_PLAN_ATTEMPTS)
    successful_paths: List[Tuple[float, object]] = []
    for _ in range(attempts):
        path = _plan_path(scene, arm_name, pose, ignore_collisions=ignore_collisions)
        length = _path_length(path)
        if path is None or not np.isfinite(length):
            continue
        successful_paths.append((float(length), path))
    if not successful_paths:
        return None, math.inf, 0.0
    successful_paths.sort(key=lambda item: item[0])
    stable_length = float(np.median([length for length, _ in successful_paths]))
    reliability = float(len(successful_paths) / attempts)
    return successful_paths[0][1], stable_length, reliability


def _plan_is_reliable(reliability: float) -> bool:
    """Return True when the success fraction meets DEFAULT_PLAN_MIN_SUCCESSES out of DEFAULT_PLAN_ATTEMPTS.

    The required count is clamped to ``[1, attempts]``. For example, with 2
    attempts and 2 required successes the reliability must be 1.0.
    """
    attempts = max(1, DEFAULT_PLAN_ATTEMPTS)
    required = min(attempts, max(1, DEFAULT_PLAN_MIN_SUCCESSES))
    return reliability >= (required / attempts)


def _path_length(path) -> float:
    """Return the last entry of the path's private ``_get_path_point_lengths()``, or inf.

    Presumably this is the cumulative path length. Returns inf when ``path``
    is None or the lookup fails for any reason.
    """
    if path is None:
        return math.inf
    try:
        return float(path._get_path_point_lengths()[-1])
    except Exception:
        return math.inf


def _dedupe_pose_list(poses: Iterable[np.ndarray], precision: int = 4) -> List[np.ndarray]:
    """Drop poses equal to an earlier one after rounding to ``precision`` decimals.

    Output poses are float64 arrays and first occurrences are kept. Comparison
    is element-wise, so ``q`` and ``-q`` (the same rotation) count as
    different poses.
    """
    unique: List[np.ndarray] = []
    seen = set()
    for pose in poses:
        key = tuple(np.round(np.asarray(pose, dtype=np.float64), precision))
        if key in seen:
            continue
        seen.add(key)
        unique.append(np.asarray(pose, dtype=np.float64))
    return unique


def _pregrasp_progress_and_distance(
    current_pose: np.ndarray,
    tray_pose: np.ndarray,
    templates: MotionTemplates,
) -> Tuple[float, float]:
    """Return ``(progress, distance)`` of the gripper position toward the tray-relative pre-grasp target.

    ``progress = clip(1 - distance / span, 0, 1)``. ``span`` is the distance
    from the first approach template to the target, or 0.12 when there are
    no approach templates. Only the position part of the pose is compared.
    """
    goal_pose = _apply_relative_pose(tray_pose, templates.pregrasp_rel_pose)
    if templates.approach_rel_poses:
        start_pose = _apply_relative_pose(tray_pose, templates.approach_rel_poses[0])
        span = float(np.linalg.norm(start_pose[:3] - goal_pose[:3]))
    else:
        span = 0.12
    distance = float(np.linalg.norm(current_pose[:3] - goal_pose[:3]))
    progress = 1.0 - (distance / max(span, 1e-6))
    return float(np.clip(progress, 0.0, 1.0)), distance


def _pregrasp_candidates(tray_pose: np.ndarray, templates: MotionTemplates) -> List[np.ndarray]:
    """Return world-frame pre-grasp candidate poses for the given tray pose.

    Candidates are the last DEFAULT_PREGRASP_CANDIDATE_COUNT approach
    templates (or the single pre-grasp template when there are none). The
    last one is also added shifted by ±0.02 along world x. Duplicates are
    removed.
    """
    if templates.approach_rel_poses:
        rel_poses = templates.approach_rel_poses[-min(
            DEFAULT_PREGRASP_CANDIDATE_COUNT, len(templates.approach_rel_poses)
        ) :]
    else:
        rel_poses = [templates.pregrasp_rel_pose]
    candidates = [_apply_relative_pose(tray_pose, rel_pose) for rel_pose in rel_poses]
    for rel_pose in rel_poses[-1:]:
        base = _apply_relative_pose(tray_pose, rel_pose)
        for dx in (-0.02, 0.02):
            perturbed = base.copy()
            perturbed[0] += dx
            candidates.append(perturbed)
    return _dedupe_pose_list(candidates)


def _extract_sequence_poses(
    tray_pose: np.ndarray, task_base_pose: np.ndarray, templates: MotionTemplates
) -> List[np.ndarray]:
    """Return world-frame waypoints for the extraction sequence.

    The waypoints are the pre-grasp and grasp poses (anchored on
    ``tray_pose``), followed by up to three evenly spaced retreat poses
    (anchored on ``task_base_pose``).
    """
    poses = [
        _apply_relative_pose(tray_pose, templates.pregrasp_rel_pose),
        _apply_relative_pose(tray_pose, templates.grasp_rel_pose),
    ]
    if templates.retreat_rel_poses:
        retreat_indices = np.linspace(
            0,
            len(templates.retreat_rel_poses) - 1,
            num=min(3, len(templates.retreat_rel_poses)),
            dtype=int,
        )
        poses.extend(
            _apply_relative_pose(task_base_pose, templates.retreat_rel_poses[index])
            for index in sorted(set(retreat_indices.tolist()))
        )
    return poses


def _extract_height_threshold(templates: MotionTemplates) -> float:
    """Return the tray height that counts as extracted: reference grasp height + 0.06."""
    return templates.reference_tray_height + 0.06


def _extraction_progress_score(current_height: float, templates: MotionTemplates) -> float:
    """Score tray lift on [0, 1] as a piecewise-linear function of height.

    - Returns 0 at or below the reference height.
    - Rises linearly to 0.8 at `_extract_height_threshold`.
    - Above the threshold it continues as ``0.8 + margin / 0.12``, clamped
      at 1.0.
    """
    baseline = float(templates.reference_tray_height)
    threshold = _extract_height_threshold(templates)
    current_height = float(current_height)
    if current_height <= baseline:
        return 0.0
    if current_height < threshold:
        lift_fraction = (current_height - baseline) / max(threshold - baseline, 1e-6)
        return float(np.clip(0.8 * lift_fraction, 0.0, 0.8))
    margin = current_height - threshold
    # Past the threshold, keep rising linearly (slope 1/0.12); the score clamps
    # at 1.0 once the margin reaches 0.024.
    return float(min(1.0, 0.8 + margin / 0.12))


def _pregrasp_score_and_success(task, templates: MotionTemplates) -> Tuple[float, bool]:
    tray = Shape("tray")
    if any(
        grasped.get_name() == tray.get_name()
        for grasped in task._scene.robot.left_gripper.get_grasped_objects()
    ):
        return 1.0, True
    tray_pose = Shape("tray").get_pose()
    current_pose = np.asarray(task._scene.robot.left_gripper.get_pose(), dtype=np.float64)
    progress, distance_to_pregrasp = _pregrasp_progress_and_distance(
        current_pose, tray_pose, templates
    )
    best = progress
    success = False
    # NOTE(review): -min(1, n) keeps at most the last approach pose, so
    # corridor_targets always has exactly one entry and the
    # `start_index > 0` guard below can never fire — confirm intent.
    late_approach_poses = [
        _apply_relative_pose(tray_pose, rel_pose)
        for rel_pose in templates.approach_rel_poses[-min(1, len(templates.approach_rel_poses)) :]
    ]
    corridor_targets = _dedupe_pose_list(late_approach_poses)
    if not corridor_targets:
        corridor_targets = [_apply_relative_pose(tray_pose, templates.pregrasp_rel_pose)]
    snapshot = _capture_snapshot(task)
    try:
        for start_index, start_pose in enumerate(corridor_targets):
            start_distance = float(np.linalg.norm(current_pose[:3] - start_pose[:3]))
            if start_index > 0 and start_distance > 0.08:
                continue
            _restore_snapshot(task, snapshot)
            stage_scores: List[float] = []
            reliable_stage_count = 0
            stage_success = True
            for target_pose in corridor_targets[start_index:]:
                live_pose = np.asarray(task._scene.robot.left_gripper.get_pose(), dtype=np.float64)
                proximity = math.exp(-float(np.linalg.norm(live_pose[:3] - target_pose[:3])) / 0.09)
                path, length, reliability = _stable_plan(
                    task._scene,
                    "left",
                    target_pose,
                    ignore_collisions=False,
                )
                if path is None or not np.isfinite(length):
                    stage_scores.append(0.25 * proximity)
                    stage_success = False
                    break
                planner_score = reliability * math.exp(-length / DEFAULT_PATH_SCALE)
                stage_scores.append(0.7 * planner_score + 0.3 * proximity)
                if not _plan_is_reliable(reliability):
                    stage_success = False
                    break
                path.set_to_end(disable_dynamics=True)
                task._pyrep.step()
                reliable_stage_count += 1
            if stage_scores:
                normalized_stage_score = float(np.mean(stage_scores))
                best
= max(best, 0.35 * progress + 0.65 * normalized_stage_score) else: best = max(best, 0.75 * progress) if stage_success and reliable_stage_count == len(corridor_targets[start_index:]): success = True finally: _restore_snapshot(task, snapshot) return best, success def _extract_score_and_success(task, templates: MotionTemplates) -> Tuple[float, bool]: tray = Shape("tray") robot = task._scene.robot snapshot = _capture_snapshot(task) try: total_length = 0.0 current_height = float(tray.get_position()[2]) already_grasped = any( grasped.get_name() == tray.get_name() for grasped in robot.left_gripper.get_grasped_objects() ) if already_grasped and current_height >= _extract_height_threshold(templates): return _extraction_progress_score(current_height, templates), True poses = _extract_sequence_poses( tray.get_pose(), task._task.get_base().get_pose(), templates ) approach_poses = [] if already_grasped else poses[:2] retreat_poses = poses[2:] if already_grasped and retreat_poses: future_retreat_poses = [ pose for pose in retreat_poses if float(pose[2]) > current_height + 0.01 ] if future_retreat_poses: retreat_poses = future_retreat_poses elif current_height < _extract_height_threshold(templates): retreat_poses = [retreat_poses[-1]] else: retreat_poses = [] milestone_poses = approach_poses + retreat_poses milestone_collision = ([False] * len(approach_poses)) + ([True] * len(retreat_poses)) progress = _extraction_progress_score(current_height, templates) * 0.25 milestone_weight = 0.75 / max(1, len(milestone_poses)) for milestone_index, (pose, ignore_collisions) in enumerate( zip(milestone_poses, milestone_collision) ): path, length, reliability = _stable_plan( task._scene, "left", pose, ignore_collisions=ignore_collisions ) if path is None or not np.isfinite(length): return progress, False planner_score = reliability * math.exp(-length / DEFAULT_PATH_SCALE) progress += milestone_weight * planner_score if not _plan_is_reliable(reliability): return progress, False total_length += 
length path.set_to_end(disable_dynamics=True) task._pyrep.step() current_height = float(tray.get_position()[2]) progress = max(progress, _extraction_progress_score(current_height, templates) * 0.25) if (not already_grasped) and milestone_index == (len(approach_poses) - 1): robot.left_gripper.grasp(tray) already_grasped = True if not already_grasped: robot.left_gripper.grasp(tray) final_height = float(tray.get_position()[2]) success = final_height >= _extract_height_threshold(templates) score = max( progress, math.exp(-total_length / (DEFAULT_PATH_SCALE * 2.5)), _extraction_progress_score(final_height, templates) if success else 0.0, ) return score, bool(success) finally: _restore_snapshot(task, snapshot) def _wait_branch(task, steps: int = 5) -> None: for _ in range(steps): task._scene.step() def _open_more_branch(task, templates: MotionTemplates) -> None: joint = Joint("oven_door_joint") current = joint.get_joint_position() joint.set_joint_position(current - templates.open_more_delta, disable_dynamics=True) for _ in range(3): task._pyrep.step() def _hold_open_branch(task, templates: MotionTemplates) -> None: joint = Joint("oven_door_joint") current = joint.get_joint_position() joint.set_joint_position(min(current, templates.hold_open_angle), disable_dynamics=True) for _ in range(3): task._pyrep.step() def _frame_metrics( episode_dir: Path, demo: Demo, frame_state: ReplayState, templates: MotionTemplates, ) -> Dict[str, float]: grasp_points = _sample_grasp_points(templates, frame_state.tray_pose) full_tray_points = _sample_full_tray_points(frame_state.tray_pose) camera_values: Dict[str, Dict[str, float]] = {} for camera_name in FULL_CAMERA_SET: mask = _load_mask(episode_dir, frame_state.frame_index, camera_name) extrinsics = demo[frame_state.frame_index].misc[f"{camera_name}_camera_extrinsics"] intrinsics = demo[frame_state.frame_index].misc[f"{camera_name}_camera_intrinsics"] camera_values[camera_name] = { "grasp_visibility": _mask_visibility_ratio( grasp_points, 
mask, templates.mask_handle_ids, extrinsics, intrinsics ), "tray_visibility": _mask_visibility_ratio( full_tray_points, mask, templates.mask_handle_ids, extrinsics, intrinsics ), } values: Dict[str, float] = {} for name, camera_set in {"three_view": THREE_VIEW_SET, "full_view": FULL_CAMERA_SET}.items(): values[f"{name}_visibility"] = _union_visibility( camera_values[camera_name]["grasp_visibility"] for camera_name in camera_set ) values[f"{name}_whole_tray_visibility"] = _union_visibility( camera_values[camera_name]["tray_visibility"] for camera_name in camera_set ) return values def _compute_frame_row_isolated( episode_dir: Path, demo: Demo, templates: MotionTemplates, checkpoint_stride: int, frame_index: int, ) -> Dict[str, float]: rows = _compute_frame_rows_sequential( episode_dir=episode_dir, demo=demo, templates=templates, checkpoint_stride=checkpoint_stride, frame_indices=[frame_index], ) if not rows: raise RuntimeError(f"failed to compute frame {frame_index}") return rows[0] def _compute_frame_rows_sequential( episode_dir: Path, demo: Demo, templates: MotionTemplates, checkpoint_stride: int, frame_indices: Sequence[int], ) -> List[Dict[str, float]]: env = _launch_replay_env() try: task = env.get_task(BimanualTakeTrayOutOfOven) cache = ReplayCache(task, demo, checkpoint_stride=checkpoint_stride) cache.reset() rows: List[Dict[str, float]] = [] for frame_index in sorted({int(index) for index in frame_indices}): cache.step_to(frame_index) frame_snapshot = cache.snapshot() state = cache.current_state() visibility = _frame_metrics(episode_dir, demo, state, templates) pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance( np.asarray(state.left_gripper_pose, dtype=np.float64), np.asarray(state.tray_pose, dtype=np.float64), templates, ) p_pre, y_pre = _pregrasp_score_and_success(task, templates) p_ext, y_ext = _extract_score_and_success(task, templates) rows.append( { "frame_index": frame_index, "time_norm": frame_index / max(1, len(demo) - 1), 
"door_angle": state.door_angle, "right_gripper_open": state.right_gripper_open, "left_gripper_open": state.left_gripper_open, "pregrasp_progress": pregrasp_progress, "pregrasp_distance": pregrasp_distance, "p_pre": p_pre, "p_ext": p_ext, "y_pre_raw": float(bool(y_pre)), "y_ext_raw": float(bool(y_ext)), "y_pre": float(bool(y_pre)), "y_ext": float(bool(y_ext)), **visibility, } ) cache.restore(frame_snapshot) return rows finally: env.shutdown() def _safe_auc(y_true: np.ndarray, y_score: np.ndarray) -> float: if len(np.unique(y_true)) < 2: return float("nan") return float(roc_auc_score(y_true, y_score)) def _safe_auprc(y_true: np.ndarray, y_score: np.ndarray) -> float: if len(np.unique(y_true)) < 2: return float("nan") return float(average_precision_score(y_true, y_score)) def _first_crossing(values: np.ndarray, threshold: float) -> int: above = np.flatnonzero(values >= threshold) return int(above[0]) if len(above) else -1 def _transition_count(binary_values: np.ndarray) -> Tuple[int, int]: diffs = np.diff(binary_values.astype(int)) return int(np.sum(diffs == 1)), int(np.sum(diffs == -1)) def _keyframe_subset(frame_df: pd.DataFrame, keyframes: Sequence[int]) -> pd.DataFrame: key_df = frame_df.iloc[list(keyframes)].copy() key_df["keyframe_ordinal"] = np.arange(len(key_df)) return key_df def _persistent_rise_mask(binary_values: np.ndarray, window: int) -> np.ndarray: binary_values = np.asarray(binary_values, dtype=bool) rises = np.zeros(len(binary_values), dtype=bool) if window <= 0: rises[:] = binary_values return rises for index in range(len(binary_values)): segment = binary_values[index : index + window] if len(segment) == window and np.all(segment): rises[index] = True return rises def _monotone_after_first(binary_values: np.ndarray) -> np.ndarray: binary_values = np.asarray(binary_values, dtype=bool) monotone = np.zeros(len(binary_values), dtype=bool) if np.any(binary_values): first_true = int(np.flatnonzero(binary_values)[0]) monotone[first_true:] = True return 
monotone def _annotate_phase_columns(frame_df: pd.DataFrame) -> pd.DataFrame: door_speed = np.gradient(frame_df["door_angle"].to_numpy(dtype=float), DEMO_DT) frame_df["door_speed_abs"] = np.abs(door_speed) y_ext_raw = ( frame_df["y_ext_raw"].to_numpy(dtype=bool) if "y_ext_raw" in frame_df else frame_df["y_ext"].to_numpy(dtype=bool) ) pregrasp_progress = ( frame_df["pregrasp_progress"].to_numpy(dtype=float) if "pregrasp_progress" in frame_df else frame_df["p_pre"].to_numpy(dtype=float) ) pregrasp_distance = ( frame_df["pregrasp_distance"].to_numpy(dtype=float) if "pregrasp_distance" in frame_df else 1.0 - pregrasp_progress ) pregrasp_speed = -np.gradient(pregrasp_distance, DEMO_DT) frame_df["pregrasp_speed"] = pregrasp_speed y_pre_seed = pregrasp_progress >= DEFAULT_PREGRASP_LABEL_PROGRESS_TAU y_pre_binary = _monotone_after_first( _persistent_rise_mask(y_pre_seed, DEFAULT_PREGRASP_PERSISTENCE) ) y_ext_binary = _monotone_after_first( _persistent_rise_mask(y_ext_raw, DEFAULT_READY_PERSISTENCE) ) frame_df["y_pre_progress_seed"] = y_pre_seed.astype(float) frame_df["y_pre"] = y_pre_binary.astype(float) frame_df["y_ext"] = y_ext_binary.astype(float) frame_df["phase_score"] = np.clip( 0.7 * pregrasp_progress + 0.3 * frame_df["p_pre"].to_numpy(dtype=float), 0.0, 1.0, ) approach_active = ( (pregrasp_progress >= DEFAULT_APPROACH_PROGRESS_TAU) & (pregrasp_speed >= DEFAULT_APPROACH_SPEED_TAU) ) frame_df["approach_active"] = approach_active.astype(float) retrieve_onset = _persistent_rise_mask( approach_active & y_pre_binary, DEFAULT_RETRIEVE_PERSISTENCE ) frame_df["y_retrieve"] = _monotone_after_first(retrieve_onset).astype(float) ready_seed = np.zeros(len(frame_df), dtype=bool) for index in range(len(frame_df)): window = y_ext_binary[index : index + DEFAULT_READY_PERSISTENCE] if ( len(window) == DEFAULT_READY_PERSISTENCE and np.all(window) and frame_df.iloc[index]["door_speed_abs"] <= DEFAULT_DOOR_SPEED_TAU ): ready_seed[index] = True frame_df["y_ready"] = 
_monotone_after_first(ready_seed).astype(float) phase_seed = _persistent_rise_mask( frame_df["phase_score"].to_numpy(dtype=float) >= DEFAULT_PHASE_SCORE_TAU, DEFAULT_RETRIEVE_PERSISTENCE, ) frame_df["phase_switch"] = _monotone_after_first(phase_seed).astype(float) return frame_df def _episode_metrics_from_frames( frame_df: pd.DataFrame, key_df: pd.DataFrame, episode_name: str, description: str, interventions: Dict[str, float], ) -> Dict[str, object]: y_pre_arr = frame_df["y_pre"].to_numpy(dtype=int) y_ext_arr = frame_df["y_ext"].to_numpy(dtype=int) y_retrieve_arr = frame_df["y_retrieve"].to_numpy(dtype=int) y_ready_arr = frame_df["y_ready"].to_numpy(dtype=int) p_pre_arr = frame_df["p_pre"].to_numpy(dtype=float) p_ext_arr = frame_df["p_ext"].to_numpy(dtype=float) phase_arr = frame_df["phase_switch"].to_numpy(dtype=int) whole_vis = frame_df["full_view_whole_tray_visibility"].to_numpy(dtype=float) door_angle_arr = frame_df["door_angle"].to_numpy(dtype=float) time_arr = frame_df["time_norm"].to_numpy(dtype=float) ppre_cross = _first_crossing(p_pre_arr, DEFAULT_PPRE_TAU) pext_cross = _first_crossing(p_ext_arr, DEFAULT_PEXT_TAU) phase_cross = _first_crossing(frame_df["phase_switch"].to_numpy(dtype=float), 0.5) retrieve_cross = _first_crossing(y_retrieve_arr.astype(float), 0.5) ready_cross = _first_crossing(y_ready_arr.astype(float), 0.5) phase_rises, phase_falls = _transition_count(phase_arr) key_phase_cross = _first_crossing(key_df["phase_switch"].to_numpy(dtype=float), 0.5) key_retrieve_cross = _first_crossing(key_df["y_retrieve"].to_numpy(dtype=float), 0.5) key_ready_cross = _first_crossing(key_df["y_ready"].to_numpy(dtype=float), 0.5) return { "episode_name": episode_name, "description": description, "num_dense_frames": int(len(frame_df)), "num_keyframes": int(len(key_df)), "phase_switch_rises": int(phase_rises), "phase_switch_falls": int(phase_falls), "ppre_cross_frame": int(ppre_cross), "pext_cross_frame": int(pext_cross), "phase_cross_frame": int(phase_cross), 
"retrieve_cross_frame": int(retrieve_cross), "ready_cross_frame": int(ready_cross), "ordering_ok": bool(ppre_cross == -1 or pext_cross == -1 or ppre_cross <= pext_cross), "dense_boundary_error_to_retrieve_frames": float(abs(phase_cross - retrieve_cross)) if phase_cross >= 0 and retrieve_cross >= 0 else float("nan"), "dense_boundary_error_frames": float(abs(phase_cross - ready_cross)) if phase_cross >= 0 and ready_cross >= 0 else float("nan"), "dense_boundary_error_fraction": float(abs(phase_cross - ready_cross) / len(frame_df)) if phase_cross >= 0 and ready_cross >= 0 else float("nan"), "key_boundary_error_to_retrieve_keyframes": float(abs(key_phase_cross - key_retrieve_cross)) if key_phase_cross >= 0 and key_retrieve_cross >= 0 else float("nan"), "key_boundary_error_keyframes": float(abs(key_phase_cross - key_ready_cross)) if key_phase_cross >= 0 and key_ready_cross >= 0 else float("nan"), "auroc_vret_ypre_three": _safe_auc(y_pre_arr, frame_df["three_view_visibility"].to_numpy(dtype=float)), "auprc_vret_ypre_three": _safe_auprc(y_pre_arr, frame_df["three_view_visibility"].to_numpy(dtype=float)), "auroc_vret_ypre_full": _safe_auc(y_pre_arr, frame_df["full_view_visibility"].to_numpy(dtype=float)), "auprc_vret_ypre_full": _safe_auprc(y_pre_arr, frame_df["full_view_visibility"].to_numpy(dtype=float)), "auroc_ppre_ypre": _safe_auc(y_pre_arr, p_pre_arr), "auprc_ppre_ypre": _safe_auprc(y_pre_arr, p_pre_arr), "auroc_pext_yext": _safe_auc(y_ext_arr, p_ext_arr), "auprc_pext_yext": _safe_auprc(y_ext_arr, p_ext_arr), "auroc_phase_yretrieve": _safe_auc(y_retrieve_arr, frame_df["phase_score"].to_numpy(dtype=float)), "auprc_phase_yretrieve": _safe_auprc(y_retrieve_arr, frame_df["phase_score"].to_numpy(dtype=float)), "f1_phase_yretrieve": float(f1_score(y_retrieve_arr, phase_arr)) if np.any(y_retrieve_arr) and np.any(phase_arr) else float("nan"), "auroc_phase_yready": _safe_auc(y_ready_arr, frame_df["phase_score"].to_numpy(dtype=float)), "auprc_phase_yready": 
_safe_auprc(y_ready_arr, frame_df["phase_score"].to_numpy(dtype=float)), "f1_phase_yready": float(f1_score(y_ready_arr, phase_arr)) if np.any(y_ready_arr) and np.any(phase_arr) else float("nan"), "baseline_auroc_door_yext": _safe_auc(y_ext_arr, door_angle_arr), "baseline_auprc_door_yext": _safe_auprc(y_ext_arr, door_angle_arr), "baseline_auroc_time_yext": _safe_auc(y_ext_arr, time_arr), "baseline_auprc_time_yext": _safe_auprc(y_ext_arr, time_arr), "baseline_auroc_whole_vis_yext": _safe_auc(y_ext_arr, whole_vis), "baseline_auprc_whole_vis_yext": _safe_auprc(y_ext_arr, whole_vis), **interventions, } def _isolated_intervention_outcomes( demo: Demo, templates: MotionTemplates, frame_index: int, checkpoint_stride: int, ) -> Dict[str, Tuple[float, bool]]: env = _launch_replay_env() try: task = env.get_task(BimanualTakeTrayOutOfOven) cache = ReplayCache(task, demo, checkpoint_stride=checkpoint_stride) cache.reset() cache.step_to(frame_index) snapshot = cache.snapshot() base = _extract_score_and_success(task, templates) _restore_snapshot(task, snapshot) _open_more_branch(task, templates) open_more = _extract_score_and_success(task, templates) _restore_snapshot(task, snapshot) _hold_open_branch(task, templates) hold_open = _extract_score_and_success(task, templates) _restore_snapshot(task, snapshot) _wait_branch(task) wait = _extract_score_and_success(task, templates) return { "base": base, "open_more": open_more, "hold_open": hold_open, "wait": wait, } finally: env.shutdown() def _interventional_validity( demo: Demo, templates: MotionTemplates, frame_df: pd.DataFrame, checkpoint_stride: int, ) -> Dict[str, float]: ready_indices = np.flatnonzero(frame_df["y_ready"].to_numpy(dtype=bool)) ready_onset = int(ready_indices[0]) if len(ready_indices) else len(frame_df) // 2 pre_ready_indices = sorted( { max(0, ready_onset - 20), max(0, ready_onset - 10), } ) post_ready_indices = sorted( { ready_onset, min(len(frame_df) - 1, ready_onset + 20), } ) stats = { 
"pre_ready_open_more_increases_pext": 0, "pre_ready_open_more_trials": 0, "pre_ready_hold_open_increases_pext": 0, "pre_ready_hold_open_trials": 0, "pre_ready_extract_success": 0, "pre_ready_extract_trials": 0, "pre_ready_wait_extract_success": 0, "pre_ready_wait_trials": 0, "post_ready_extract_success": 0, "post_ready_extract_trials": 0, "post_ready_open_more_low_gain": 0, "post_ready_open_more_trials": 0, "post_ready_hold_open_low_gain": 0, "post_ready_hold_open_trials": 0, } for frame_index in [*pre_ready_indices, *post_ready_indices]: outcomes = _isolated_intervention_outcomes( demo=demo, templates=templates, frame_index=frame_index, checkpoint_stride=checkpoint_stride, ) base_pext, base_extract_success = outcomes["base"] pre_ready = frame_index in pre_ready_indices open_pext, _ = outcomes["open_more"] hold_pext, _ = outcomes["hold_open"] _, wait_extract_success = outcomes["wait"] if pre_ready: stats["pre_ready_open_more_trials"] += 1 stats["pre_ready_hold_open_trials"] += 1 stats["pre_ready_extract_trials"] += 1 stats["pre_ready_wait_trials"] += 1 if open_pext > base_pext: stats["pre_ready_open_more_increases_pext"] += 1 if hold_pext > base_pext: stats["pre_ready_hold_open_increases_pext"] += 1 if base_extract_success: stats["pre_ready_extract_success"] += 1 if wait_extract_success: stats["pre_ready_wait_extract_success"] += 1 else: stats["post_ready_extract_trials"] += 1 stats["post_ready_open_more_trials"] += 1 stats["post_ready_hold_open_trials"] += 1 if base_extract_success: stats["post_ready_extract_success"] += 1 if (open_pext - base_pext) <= 0.05: stats["post_ready_open_more_low_gain"] += 1 if (hold_pext - base_pext) <= 0.05: stats["post_ready_hold_open_low_gain"] += 1 return { key: float(value) for key, value in stats.items() } def _analyze_episode( dataset_root: Path, episode_dir: Path, templates: MotionTemplates, template_frames: Dict[str, int], checkpoint_stride: int = 16, max_frames: Optional[int] = None, independent_replay: bool = False, ) -> 
EpisodeArtifacts: demo = _load_demo(episode_dir) descriptions = _load_descriptions(episode_dir) env = _launch_replay_env() try: task = env.get_task(BimanualTakeTrayOutOfOven) cache = ReplayCache(task, demo, checkpoint_stride=checkpoint_stride) cache.reset() num_frames = len(demo) if max_frames is None else min(len(demo), max_frames) rows: List[Dict[str, float]] = [] initial_snapshot = cache.checkpoints[0] if independent_replay else None for frame_index in range(num_frames): if independent_replay: cache.restore_to_index(initial_snapshot, 0) cache.step_to(frame_index) frame_snapshot = cache.snapshot() if not independent_replay else None state = cache.current_state() visibility = _frame_metrics(episode_dir, demo, state, templates) pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance( np.asarray(state.left_gripper_pose, dtype=np.float64), np.asarray(state.tray_pose, dtype=np.float64), templates, ) p_pre, y_pre = _pregrasp_score_and_success(task, templates) p_ext, y_ext = _extract_score_and_success(task, templates) rows.append( { "frame_index": frame_index, "time_norm": frame_index / max(1, num_frames - 1), "door_angle": state.door_angle, "right_gripper_open": state.right_gripper_open, "left_gripper_open": state.left_gripper_open, "pregrasp_progress": pregrasp_progress, "pregrasp_distance": pregrasp_distance, "p_pre": p_pre, "p_ext": p_ext, "y_pre_raw": float(bool(y_pre)), "y_ext_raw": float(bool(y_ext)), "y_pre": float(bool(y_pre)), "y_ext": float(bool(y_ext)), **visibility, } ) if frame_snapshot is not None: cache.restore(frame_snapshot) if (frame_index + 1) % 25 == 0 or (frame_index + 1) == num_frames: print( f"[{episode_dir.name}] analyzed {frame_index + 1}/{num_frames} dense frames", flush=True, ) frame_df = pd.DataFrame(rows) frame_df = _annotate_phase_columns(frame_df) keyframes = [index for index in _keypoint_discovery(demo) if index < num_frames] key_df = _keyframe_subset(frame_df, keyframes) finally: env.shutdown() interventions = 
_interventional_validity( demo=demo, templates=templates, frame_df=frame_df, checkpoint_stride=checkpoint_stride, ) metrics = _episode_metrics_from_frames( frame_df=frame_df, key_df=key_df, episode_name=episode_dir.name, description=descriptions[0], interventions=interventions, ) return EpisodeArtifacts( episode_name=episode_dir.name, dense=frame_df, keyframes=key_df, metrics=metrics, template_frames=template_frames, ) def _aggregate_summary(episode_metrics: List[Dict[str, object]]) -> Dict[str, object]: frame = pd.DataFrame(episode_metrics) numeric = frame.select_dtypes(include=[np.number]) summary = { "num_episodes": int(len(frame)), "mean_metrics": numeric.mean(numeric_only=True).to_dict(), "median_metrics": numeric.median(numeric_only=True).to_dict(), "single_switch_rate": float((frame["phase_switch_rises"] == 1).mean()) if len(frame) else float("nan"), "reversion_rate": float((frame["phase_switch_falls"] > 0).mean()) if len(frame) else float("nan"), "ordering_ok_rate": float(frame["ordering_ok"].mean()) if len(frame) else float("nan"), } return summary def run_study( dataset_root: str, result_dir: str, max_episodes: Optional[int] = None, checkpoint_stride: int = 16, max_frames: Optional[int] = None, episode_offset: int = 0, template_episode_index: int = 0, episode_indices: Optional[Sequence[int]] = None, independent_replay: bool = False, ) -> Dict[str, object]: dataset_path = Path(dataset_root) result_path = Path(result_dir) result_path.mkdir(parents=True, exist_ok=True) all_episode_dirs = _episode_dirs(dataset_path) if not all_episode_dirs: raise RuntimeError(f"no episodes available under {dataset_root}") if not (0 <= template_episode_index < len(all_episode_dirs)): raise ValueError( f"template_episode_index {template_episode_index} outside available range 0..{len(all_episode_dirs) - 1}" ) selected_episode_indices: List[int] if episode_indices is not None: selected_episode_indices = [] seen_episode_indices = set() for raw_index in episode_indices: 
episode_index = int(raw_index) if not (0 <= episode_index < len(all_episode_dirs)): raise ValueError( f"episode index {episode_index} outside available range 0..{len(all_episode_dirs) - 1}" ) if episode_index in seen_episode_indices: continue selected_episode_indices.append(episode_index) seen_episode_indices.add(episode_index) episode_dirs = [all_episode_dirs[index] for index in selected_episode_indices] else: episode_dirs = all_episode_dirs[episode_offset:] if max_episodes is not None: episode_dirs = episode_dirs[:max_episodes] selected_episode_indices = [ all_episode_dirs.index(episode_dir) for episode_dir in episode_dirs ] if not episode_dirs: raise RuntimeError( f"no episodes selected under {dataset_root} with offset={episode_offset} max_episodes={max_episodes} episode_indices={episode_indices}" ) template_episode_dir = all_episode_dirs[template_episode_index] templates, template_frames = _derive_templates(dataset_path, template_episode_dir) with result_path.joinpath("templates.json").open("w", encoding="utf-8") as handle: json.dump( { "templates": templates.to_json(), "template_episode": template_episode_dir.name, "template_frames": template_frames, "episode_offset": episode_offset, "selected_episode_indices": selected_episode_indices, }, handle, indent=2, ) episode_metrics: List[Dict[str, object]] = [] for episode_dir in episode_dirs: artifacts = _analyze_episode( dataset_path, episode_dir, templates, template_frames, checkpoint_stride=checkpoint_stride, max_frames=max_frames, independent_replay=independent_replay, ) artifacts.dense.to_csv(result_path.joinpath(f"{episode_dir.name}.dense.csv"), index=False) artifacts.keyframes.to_csv( result_path.joinpath(f"{episode_dir.name}.keyframes.csv"), index=False ) with result_path.joinpath(f"{episode_dir.name}.metrics.json").open( "w", encoding="utf-8" ) as handle: json.dump(artifacts.metrics, handle, indent=2) episode_metrics.append(artifacts.metrics) summary = _aggregate_summary(episode_metrics) with 
result_path.joinpath("summary.json").open("w", encoding="utf-8") as handle: json.dump(summary, handle, indent=2) return summary def main(argv: Optional[Sequence[str]] = None) -> int: def _parse_episode_indices(value: str) -> List[int]: indices: List[int] = [] for chunk in value.split(","): chunk = chunk.strip() if not chunk: continue indices.append(int(chunk)) if not indices: raise argparse.ArgumentTypeError("episode-indices must contain at least one integer") return indices parser = argparse.ArgumentParser() parser.add_argument( "--dataset-root", default="/workspace/data/bimanual_take_tray_out_of_oven_train_128", ) parser.add_argument( "--result-dir", default="/workspace/reveal_retrieve_label_study/results/oven_first_pass", ) parser.add_argument("--max-episodes", type=int, default=1) parser.add_argument("--checkpoint-stride", type=int, default=16) parser.add_argument("--max-frames", type=int) parser.add_argument("--episode-offset", type=int, default=0) parser.add_argument("--template-episode-index", type=int, default=0) parser.add_argument("--episode-indices", type=_parse_episode_indices) parser.add_argument("--independent-replay", action="store_true") args = parser.parse_args(argv) summary = run_study( dataset_root=args.dataset_root, result_dir=args.result_dir, max_episodes=args.max_episodes, checkpoint_stride=args.checkpoint_stride, max_frames=args.max_frames, episode_offset=args.episode_offset, template_episode_index=args.template_episode_index, episode_indices=args.episode_indices, independent_replay=args.independent_replay, ) print(json.dumps(summary, indent=2)) return 0 if __name__ == "__main__": raise SystemExit(main())