""" |
|
|
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format. |
|
|
|
|
|
Usage: |
|
|
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data |
|
|
|
|
|
If you want to push your dataset to the Hugging Face Hub, you can use the following command: |
|
|
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub |
|
|
|
|
|
The resulting dataset will get saved to the $LEROBOT_HOME directory. |
|
|
""" |
import copy
import glob
import json
import shutil
from collections import defaultdict
from pathlib import Path

import cv2
import h5py
import numpy as np
import tyro
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME, LeRobotDataset
from PIL import Image
from tqdm import tqdm

REPO_NAME = "your_hf_username/my_droid_dataset"
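# Replace REPO_NAME above with "<your_hf_username>/<your_dataset_name>" before running, especially
# if you plan to push the dataset to the Hub.
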
def resize_image(image, size):
    image = Image.fromarray(image)
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    # Remove any existing dataset at the output location before converting.
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

    # Create an empty LeRobot dataset with the features that are written for every frame below.
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,
        features={
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )
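    # The conversion loop below passes exactly these feature keys to dataset.add_frame(), plus a
    # per-frame "task" string carrying the language instruction.
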
    # Load the DROID language annotations; they map an episode id to natural-language instructions.
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Each raw episode is stored as a trajectory.h5 file somewhere under the data directory.
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load the raw trajectory: robot state and actions from HDF5, camera frames from the MP4 recordings.
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # The episode id is encoded in the metadata filename; use it to look up the language
        # instruction and fall back to a placeholder if the episode was not annotated.
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        for step in trajectory:
            # Camera type 0 is the wrist-mounted camera; everything else is an exterior camera.
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Frames are decoded with OpenCV (BGR), so flip the channel order to RGB before resizing.
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Actions are the 7 joint velocities followed by the 1-D gripper position (8 dims total).
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push the dataset to the Hugging Face Hub.
    if push_to_hub:
        dataset.push_to_hub(
            tags=["droid", "panda"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


# The helpers below re-implement just enough of the raw DROID data loading (HDF5 trajectory files
# plus per-camera MP4 recordings) to iterate over episodes for conversion.
camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}
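# Camera type 0 is the wrist-mounted ("hand") camera, 1 a movable exterior ("varied") camera, and
# 2 a fixed exterior camera; main() uses these integer types to separate wrist and exterior views.
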
def get_camera_type(cam_id):
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]


class MP4Reader:
    """Sequential frame reader for a single DROID camera MP4 recording."""

    def __init__(self, filepath, serial_number):
        self.serial_number = serial_number
        self._index = 0

        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 File")

    def set_reading_parameters(
        self,
        image=True,
        concatenate_images=False,
        resolution=(0, 0),
        resize_func=None,
    ):
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = cv2.resize
        self.skip_reading = not image
        if self.skip_reading:
            return

    def get_frame_resolution(self):
        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        # Seek backwards directly; otherwise advance frame by frame until the target index is reached.
        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):
        if self.skip_reading:
            return {}

        success, frame = self._mp4_reader.read()
        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        data_dict = {}

        # Stereo recordings store the left and right views side by side in a single frame.
        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()


class RecordedMultiCameraWrapper:
    """Reads the per-camera MP4 recordings of an episode frame by frame."""

    def __init__(self, recording_folderpath, camera_kwargs={}):
        self.camera_kwargs = camera_kwargs

        # One MP4 recording per camera; the filename (minus the extension) is the camera serial number.
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
        all_filepaths = mp4_filepaths

        self.camera_dict = {}
        for f in all_filepaths:
            serial_number = f.split("/")[-1][:-4]
            cam_type = get_camera_type(serial_number)
            camera_kwargs.get(cam_type, {})

            if f.endswith(".mp4"):
                Reader = MP4Reader
            else:
                raise ValueError

            self.camera_dict[serial_number] = Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):
        full_obs_dict = defaultdict(dict)

        all_cam_ids = list(self.camera_dict.keys())

        for cam_id in all_cam_ids:
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type {cam_id} not found in camera_type_dict")
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # If any camera fails to return a frame, treat the whole timestep as failed.
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict


def get_hdf5_length(hdf5_file, keys_to_ignore=[]):
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError

        if length is None:
            length = curr_length
        assert curr_length == length

    return length
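# Every dataset stored in a DROID trajectory.h5 file is expected to share the same leading (time)
# dimension; the assert above enforces this when computing the trajectory length.
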
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError

    return data_dict


class TrajectoryReader:
    def __init__(self, filepath, read_images=True):
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Videos are read separately by RecordedMultiCameraWrapper, so skip the "videos" group here.
        keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        self._index += 1

        return timestep

    def close(self):
        self._hdf5_file.close()


def load_trajectory(
    filepath=None,
    read_cameras=True,
    recording_folderpath=None,
    camera_kwargs={},
    remove_skipped_steps=False,
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Optionally subsample the trajectory; otherwise keep every timestep.
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    for i in indices_to_save:
        timestep = traj_reader.read_timestep(index=i)

        # Attach the camera frames that correspond to this timestep.
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            camera_failed = camera_obs is None

            # Stop reading the episode early if any camera stream ran out of frames.
            if camera_failed:
                break
            timestep["observation"].update(camera_obs)

        # Optionally drop timesteps where the teleoperator paused the robot.
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        delete_skipped_step = step_skipped and remove_skipped_steps

        if delete_skipped_step:
            del timestep
        else:
            timestep_list.append(timestep)

    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    traj_reader.close()

    return timestep_list


if __name__ == "__main__":
    tyro.cli(main)