VLAdaptorBench / external /peract_bimanual /helpers /demo_loading_utils.py
lsnu's picture
Add files using upload-large-folder tool
0d89eb9 verified
import logging
from typing import List
import numpy as np
from rlbench.demo import Demo
import omegaconf
def _is_stopped(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].gripper_open
and obs.gripper_open == demo[i - 1].gripper_open
and demo[i - 2].gripper_open == demo[i - 1].gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def _is_stopped_right(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].right.gripper_open
and obs.gripper_open == demo[i - 1].right.gripper_open
and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def _is_stopped_left(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].left.gripper_open
and obs.gripper_open == demo[i - 1].left.gripper_open
and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def _keypoint_discovery_bimanual(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Pick keypoint indices from a two-arm demonstration.

    A step (other than step 0) becomes a keypoint when either gripper's
    ``gripper_open`` differs from the previous step, when it is the final
    step, or when both arms are reported stopped by ``_is_stopped_right`` and
    ``_is_stopped_left``. After a "both stopped" keypoint, further "stopped"
    detections are suppressed for the next few steps via a countdown.

    Args:
        demo: Demonstration whose frames expose ``right`` and ``left``
            per-arm observations.
        stopping_delta: Velocity tolerance forwarded to the stop checks.

    Returns:
        The keypoint indices in ascending order. If the last two keypoints
        are consecutive steps, the earlier of the two is dropped.
    """
    keypoints = []
    prev_right_open = demo[0].right.gripper_open
    prev_left_open = demo[0].left.gripper_open
    final_index = len(demo) - 1
    cooldown = 0

    for step, frame in enumerate(demo._observations):
        right_arm = frame.right
        left_arm = frame.left

        right_idle = _is_stopped_right(demo, step, right_arm, stopping_delta)
        left_idle = _is_stopped_left(demo, step, left_arm, stopping_delta)
        both_idle = (cooldown <= 0) and right_idle and left_idle

        # Restart the countdown on a detected stop, otherwise keep ticking.
        if both_idle:
            cooldown = 4
        else:
            cooldown -= 1

        gripper_toggled = (
            right_arm.gripper_open != prev_right_open
            or left_arm.gripper_open != prev_left_open
        )
        # Keypoint on a gripper change, at the end of the episode, or on a stop.
        if step != 0 and (gripper_toggled or step == final_index or both_idle):
            keypoints.append(step)

        prev_right_open = right_arm.gripper_open
        prev_left_open = left_arm.gripper_open

    # Collapse a trailing pair of adjacent keypoints into the later one.
    if len(keypoints) > 1 and keypoints[-2] == keypoints[-1] - 1:
        del keypoints[-2]

    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
def _keypoint_discovery_unimanual(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Pick keypoint indices from a single-arm demonstration.

    A step (other than step 0) becomes a keypoint when ``gripper_open``
    differs from the previous step, when it is the final step, or when
    ``_is_stopped`` reports the arm at rest. After a "stopped" keypoint,
    further "stopped" detections are suppressed for the next few steps via a
    countdown.

    Args:
        demo: Demonstration whose frames expose ``gripper_open``.
        stopping_delta: Velocity tolerance forwarded to ``_is_stopped``.

    Returns:
        The keypoint indices in ascending order. If the last two keypoints
        are consecutive steps, the earlier of the two is dropped.
    """
    keypoints = []
    last_gripper_open = demo[0].gripper_open
    final_index = len(demo) - 1
    cooldown = 0

    for step, frame in enumerate(demo):
        at_rest = _is_stopped(demo, step, frame, stopping_delta)
        idle = (cooldown <= 0) and at_rest

        # Restart the countdown on a detected stop, otherwise keep ticking.
        if idle:
            cooldown = 4
        else:
            cooldown -= 1

        # Keypoint on a gripper change, at the end of the episode, or on a stop.
        if step != 0 and (
            frame.gripper_open != last_gripper_open
            or step == final_index
            or idle
        ):
            keypoints.append(step)

        last_gripper_open = frame.gripper_open

    # Collapse a trailing pair of adjacent keypoints into the later one.
    if len(keypoints) > 1 and keypoints[-2] == keypoints[-1] - 1:
        del keypoints[-2]

    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
def _keypoint_discovery_heuristic(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Run the heuristic keypoint search that matches the demo's arm setup.

    The ``is_bimanual`` flag of the first frame selects between the two-arm
    and the single-arm implementation.

    Args:
        demo: Demonstration to analyse; ``demo[0]`` must expose
            ``is_bimanual``.
        stopping_delta: Velocity tolerance forwarded to the chosen routine.

    Returns:
        The keypoint indices produced by the selected routine.
    """
    discover = (
        _keypoint_discovery_bimanual
        if demo[0].is_bimanual
        else _keypoint_discovery_unimanual
    )
    return discover(demo, stopping_delta)
def keypoint_discovery(
    demo: Demo, stopping_delta=0.1, method="heuristic", num_keypoints=20
) -> List[int]:
    """Select keypoint indices of ``demo`` using the requested strategy.

    Args:
        demo: Demonstration to analyse; must support ``len()``.
        stopping_delta: Velocity tolerance, used by the ``"heuristic"``
            strategy only.
        method: One of

            * ``"heuristic"`` - delegate to ``_keypoint_discovery_heuristic``;
            * ``"random"`` - draw distinct step indices uniformly at random
              (uses NumPy's global random state);
            * ``"fixed_interval"`` - take every ``len(demo) // num_keypoints``-th
              step starting at index 0;
            * an ``omegaconf`` ``ListConfig`` - use the given indices as-is.
        num_keypoints: Target count for ``"random"`` and the divisor for
            ``"fixed_interval"``. Defaults to 20, the previously hard-coded
            value.

    Returns:
        The keypoint indices as a list. ``"random"`` results are sorted in
        ascending order and hold at most ``len(demo)`` entries.
        ``"fixed_interval"`` may yield more than ``num_keypoints`` entries
        when ``len(demo)`` is not a multiple of it.

    Raises:
        ValueError: If ``num_keypoints`` is smaller than 1 for a strategy
            that uses it.
        NotImplementedError: If ``method`` is not a supported strategy.
    """
    if method == "heuristic":
        return _keypoint_discovery_heuristic(demo, stopping_delta)

    if method in ("random", "fixed_interval") and num_keypoints < 1:
        raise ValueError("num_keypoints must be at least 1, got %r" % (num_keypoints,))

    if method == "random":
        num_steps = len(demo)
        # Sampling without replacement cannot return more indices than exist,
        # so clamp the request for demos shorter than num_keypoints.
        sampled = np.random.choice(
            range(num_steps), size=min(num_keypoints, num_steps), replace=False
        )
        sampled.sort()
        # Return a plain list, as the annotation and the other strategies do.
        return sampled.tolist()

    if method == "fixed_interval":
        # Integer division yields 0 for short demos, and range() rejects a
        # zero step; fall back to visiting every step in that case.
        stride = max(1, len(demo) // num_keypoints)
        return list(range(0, len(demo), stride))

    if isinstance(method, omegaconf.listconfig.ListConfig):
        return list(method)

    raise NotImplementedError(
        "Unsupported keypoint discovery method: %r" % (method,)
    )
def find_minimum_difference(lst):
    """Return the smallest difference between consecutive elements of ``lst``.

    Each difference is computed as ``lst[i] - lst[i - 1]``. For a sequence in
    ascending order this equals the smallest gap between any two elements.

    Args:
        lst: Indexable sequence supporting ``len()`` whose elements can be
            subtracted from each other.

    Returns:
        The smallest consecutive difference. When ``lst`` holds a single
        element there is no pair to compare and that element is returned
        (kept for backward compatibility).

    Raises:
        IndexError: If ``lst`` is empty.
    """
    if len(lst) < 2:
        return lst[-1]
    # Take the true minimum rather than capping it at the last element, which
    # gave wrong answers when every gap was larger than ``lst[-1]``.
    return min(lst[i] - lst[i - 1] for i in range(1, len(lst)))