| import logging |
| from typing import List |
|
|
| import numpy as np |
| from rlbench.demo import Demo |
| import omegaconf |
|
|
|
|
| def _is_stopped(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].gripper_open |
| and obs.gripper_open == demo[i - 1].gripper_open |
| and demo[i - 2].gripper_open == demo[i - 1].gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
|
|
| def _is_stopped_right(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].right.gripper_open |
| and obs.gripper_open == demo[i - 1].right.gripper_open |
| and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
|
|
| def _is_stopped_left(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].left.gripper_open |
| and obs.gripper_open == demo[i - 1].left.gripper_open |
| and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
|
|
def _keypoint_discovery_bimanual(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Find keypoint frame indices for a two-arm demonstration.

    A frame other than the first becomes a keypoint when either gripper's
    ``gripper_open`` differs from the previous frame, when it is the final
    frame, or when both arms are stopped at once (see ``_is_stopped_right``
    and ``_is_stopped_left``). After a stop is registered, further stop
    detections are suppressed for the following four frames.

    If the last two keypoints found are adjacent frame indices, the earlier
    of the two is dropped. The result is also printed to stdout.

    Args:
        demo: Demonstration whose frames expose ``right`` and ``left``
            observations.
        stopping_delta: Joint-velocity tolerance passed to the stop checks.

    Returns:
        Keypoint frame indices in ascending order.
    """
    keypoints = []
    prev_right_open = demo[0].right.gripper_open
    prev_left_open = demo[0].left.gripper_open
    cooldown = 0  # while positive, stop detections are ignored

    for step, frame in enumerate(demo._observations):
        right_arm = frame.right
        left_arm = frame.left

        right_still = _is_stopped_right(demo, step, right_arm, stopping_delta)
        left_still = _is_stopped_left(demo, step, left_arm, stopping_delta)
        stop_event = (cooldown <= 0) and right_still and left_still
        if stop_event:
            cooldown = 4
        else:
            cooldown -= 1

        is_final = step == (len(demo) - 1)
        gripper_toggled = (
            right_arm.gripper_open != prev_right_open
            or left_arm.gripper_open != prev_left_open
        )
        if step != 0 and (gripper_toggled or is_final or stop_event):
            keypoints.append(step)

        prev_right_open = right_arm.gripper_open
        prev_left_open = left_arm.gripper_open

    # Two adjacent trailing keypoints: keep only the later one.
    if len(keypoints) > 1 and keypoints[-2] == keypoints[-1] - 1:
        del keypoints[-2]

    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
|
|
|
|
def _keypoint_discovery_unimanual(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Find keypoint frame indices for a single-arm demonstration.

    A frame other than the first becomes a keypoint when its
    ``gripper_open`` differs from the previous frame, when it is the final
    frame, or when the arm is stopped (see ``_is_stopped``). After a stop is
    registered, further stop detections are suppressed for the following
    four frames.

    If the last two keypoints found are adjacent frame indices, the earlier
    of the two is dropped. The result is also printed to stdout.

    Args:
        demo: Demonstration to scan, iterated frame by frame.
        stopping_delta: Joint-velocity tolerance passed to ``_is_stopped``.

    Returns:
        Keypoint frame indices in ascending order.
    """
    keypoints = []
    last_gripper_open = demo[0].gripper_open
    cooldown = 0  # while positive, stop detections are ignored

    for step, frame in enumerate(demo):
        arm_still = _is_stopped(demo, step, frame, stopping_delta)
        stop_event = (cooldown <= 0) and arm_still
        if stop_event:
            cooldown = 4
        else:
            cooldown -= 1

        is_final = step == (len(demo) - 1)
        gripper_toggled = frame.gripper_open != last_gripper_open
        if step != 0 and (gripper_toggled or is_final or stop_event):
            keypoints.append(step)
        last_gripper_open = frame.gripper_open

    # Two adjacent trailing keypoints: keep only the later one.
    if len(keypoints) > 1 and keypoints[-2] == keypoints[-1] - 1:
        del keypoints[-2]

    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
|
|
|
|
def _keypoint_discovery_heuristic(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Run the heuristic keypoint search that matches the demo's arm count.

    The first frame's ``is_bimanual`` flag selects between the bimanual and
    the unimanual implementation.

    Args:
        demo: Demonstration to scan.
        stopping_delta: Joint-velocity tolerance forwarded to the selected
            implementation.

    Returns:
        Keypoint frame indices produced by the selected implementation.
    """
    discover = (
        _keypoint_discovery_bimanual
        if demo[0].is_bimanual
        else _keypoint_discovery_unimanual
    )
    return discover(demo, stopping_delta)
|
|
|
|
def keypoint_discovery(
    demo: Demo, stopping_delta=0.1, method="heuristic", *, num_keypoints=20
) -> List[int]:
    """Select keypoint frame indices from a demonstration.

    Supported values of ``method``:

    * ``"heuristic"``: gripper-change / arm-stopped heuristic.
    * ``"random"``: ``num_keypoints`` distinct frame indices drawn with
      ``np.random.choice`` and returned in ascending order. Demos shorter
      than ``num_keypoints`` yield every frame index instead of raising.
    * ``"fixed_interval"``: every ``len(demo) // num_keypoints``-th frame
      starting at 0. The stride is clamped to at least 1 so that short
      demos no longer hit ``range()`` with a zero step. Because the stride
      is rounded down, the result can hold more than ``num_keypoints``
      indices.
    * an explicit sequence of indices (``list``, ``tuple`` or OmegaConf
      ``ListConfig``): returned as a plain list.

    Args:
        demo: Demonstration to scan.
        stopping_delta: Joint-velocity tolerance; only used by
            ``"heuristic"``.
        method: Strategy name or explicit index sequence (see above).
        num_keypoints: Target number of keypoints for ``"random"`` and
            ``"fixed_interval"``.

    Returns:
        Keypoint frame indices as a list.

    Raises:
        ValueError: If ``num_keypoints`` is smaller than 1 for a method
            that uses it.
        NotImplementedError: If ``method`` is not supported.
    """
    if method == "heuristic":
        return _keypoint_discovery_heuristic(demo, stopping_delta)

    if method in ("random", "fixed_interval"):
        if num_keypoints < 1:
            raise ValueError("num_keypoints must be a positive integer")
        n_frames = len(demo)

        if method == "random":
            # Sampling without replacement cannot draw more items than exist.
            sample_size = min(num_keypoints, n_frames)
            picked = np.random.choice(n_frames, size=sample_size, replace=False)
            # Return a plain sorted list, matching the other branches.
            return sorted(picked.tolist())

        # A zero stride would make range() raise for short demos.
        stride = max(1, n_frames // num_keypoints)
        return list(range(0, n_frames, stride))

    if isinstance(method, (list, tuple, omegaconf.listconfig.ListConfig)):
        return list(method)

    raise NotImplementedError(
        f"Unsupported keypoint discovery method: {method!r}"
    )
|
|
|
|
| |
def find_minimum_difference(lst):
    """Return the smallest consecutive difference in ``lst``, capped by its last element.

    The result is the minimum over ``lst[-1]`` and every ``lst[k] - lst[k - 1]``
    for ``k >= 1``. A single-element input therefore returns that element.

    Args:
        lst: Sized, indexable sequence of numbers.

    Returns:
        The smallest value among the last element and all consecutive
        differences.

    Raises:
        IndexError: If ``lst`` is an empty list.
    """
    # The last element seeds the search, so it is also a candidate result.
    candidates = [lst[-1]]
    candidates.extend(lst[k] - lst[k - 1] for k in range(1, len(lst)))
    return min(candidates)
|
|