""" |
|
|
Run this command to interactively debug: |
|
|
PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py |
|
|
|
|
|
Adapted from: |
|
|
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py |
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
import pickle |
|
|
import random |
|
|
import traceback |
|
|
import warnings |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
import imageio |
|
|
import numpy as np |
|
|
import torch |
|
|
from decord import VideoReader, cpu |
|
|
from einops import rearrange |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms as T |
|
|
from tqdm import tqdm |
|
|
|
|
|
from cosmos_predict1.diffusion.training.datasets.dataset_utils import ( |
|
|
Resize_Preprocess, |
|
|
ToTensorVideo, |
|
|
euler2rotm, |
|
|
rotm2euler, |
|
|
) |
|
|
|
|
|
|
|
|
class Dataset_3D(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
train_annotation_path, |
|
|
val_annotation_path, |
|
|
test_annotation_path, |
|
|
video_path, |
|
|
sequence_interval, |
|
|
num_frames, |
|
|
cam_ids, |
|
|
accumulate_action, |
|
|
video_size, |
|
|
val_start_frame_interval, |
|
|
debug=False, |
|
|
normalize=False, |
|
|
pre_encode=False, |
|
|
do_evaluate=False, |
|
|
load_t5_embeddings=False, |
|
|
load_action=True, |
|
|
mode="train", |
|
|
): |
|
|
"""Dataset class for loading 3D robot action-conditional data. |
|
|
|
|
|
        This dataset loads robot trajectories consisting of RGB video frames and robot
        states (arm poses and gripper states), and computes relative actions between
        consecutive frames.
|
|
|
|
|
Args: |
|
|
train_annotation_path (str): Path to training annotation files |
|
|
val_annotation_path (str): Path to validation annotation files |
|
|
test_annotation_path (str): Path to test annotation files |
|
|
video_path (str): Base path to video files |
|
|
sequence_interval (int): Interval between sampled frames in a sequence |
|
|
num_frames (int): Number of frames to load per sequence |
|
|
cam_ids (list): List of camera IDs to sample from |
|
|
accumulate_action (bool): Whether to accumulate actions relative to first frame |
|
|
video_size (list): Target size [H,W] for video frames |
|
|
val_start_frame_interval (int): Frame sampling interval for validation/test |
|
|
debug (bool, optional): If True, only loads subset of data. Defaults to False. |
|
|
normalize (bool, optional): Whether to normalize video frames. Defaults to False. |
|
|
pre_encode (bool, optional): Whether to pre-encode video frames. Defaults to False. |
|
|
do_evaluate (bool, optional): Whether in evaluation mode. Defaults to False. |
|
|
load_t5_embeddings (bool, optional): Whether to load T5 embeddings. Defaults to False. |
|
|
load_action (bool, optional): Whether to load actions. Defaults to True. |
|
|
mode (str, optional): Dataset mode - 'train', 'val' or 'test'. Defaults to 'train'. |
|
|
|
|
|
The dataset loads robot trajectories and computes: |
|
|
- RGB video frames from specified camera views |
|
|
- Robot arm states (xyz position + euler angles) |
|
|
        - Gripper states (taken from the continuous gripper state channel)
|
|
- Relative actions between consecutive frames |
|
|
|
|
|
        Actions are computed as relative transforms between frames (see the formulas below):
        - Translation: xyz offset expressed in the reference frame's coordinates
          (the previous frame, or the first frame when accumulate_action=True)
        - Rotation: euler angles of the relative rotation
        - Gripper: gripper state of the current frame
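
        For accumulate_action=False, the action from frame k-1 to frame k restates
        _get_actions below:
            R_prev  = euler2rotm(rpy_{k-1}),  R_curr = euler2rotm(rpy_k)
            rel_xyz = R_prev.T @ (xyz_k - xyz_{k-1})
            rel_rpy = rotm2euler(R_prev.T @ R_curr)
            action[k-1] = [rel_xyz, rel_rpy, gripper_k]
        For accumulate_action=True, frame 0 is used as the reference in place of frame k-1.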
|
|
|
|
|
        Returns dict with:
        - video: RGB frames tensor [C,T,H,W], uint8
        - action: Action tensor [T-1,7], only present if load_action=True
        - annotation_file: Path of the annotation file the sample came from
        - __key__: Episode identifier (episode_id, or original_path as a fallback)
        - t5_text_embeddings / t5_text_mask: T5 embedding and mask (a zero placeholder
          embedding is used when load_t5_embeddings=False)
        - fps, image_size, num_frames, padding_mask: auxiliary metadata

        Note: pre_encode=True is not supported and raises NotImplementedError.
|
|
""" |
|
|
|
|
|
super().__init__() |
|
|
if mode == "train": |
|
|
self.data_path = train_annotation_path |
|
|
self.start_frame_interval = 1 |
|
|
elif mode == "val": |
|
|
self.data_path = val_annotation_path |
|
|
self.start_frame_interval = val_start_frame_interval |
|
|
        elif mode == "test":
            self.data_path = test_annotation_path
            self.start_frame_interval = val_start_frame_interval
        else:
            raise ValueError(f"Unknown mode: {mode}. Expected 'train', 'val' or 'test'.")
|
|
self.video_path = video_path |
|
|
self.sequence_interval = sequence_interval |
|
|
self.mode = mode |
|
|
self.sequence_length = num_frames |
|
|
self.normalize = normalize |
|
|
self.pre_encode = pre_encode |
|
|
self.load_t5_embeddings = load_t5_embeddings |
|
|
self.load_action = load_action |
|
|
|
|
|
self.cam_ids = cam_ids |
|
|
self.accumulate_action = accumulate_action |
|
|
|
|
|
self.action_dim = 7 |
|
|
self.c_act_scaler = [20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0] |
|
|
self.c_act_scaler = np.array(self.c_act_scaler, dtype=float) |
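        # The scaler above multiplies the six pose dims (xyz + rpy) by 20 and leaves the
        # gripper dim unscaled; it is applied to the action tensor in __getitem__.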
|
|
self.ann_files = self._init_anns(self.data_path) |
|
|
|
|
|
self.samples = self._init_sequences(self.ann_files) |
|
|
|
|
|
self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0])) |
|
|
if debug and not do_evaluate: |
|
|
self.samples = self.samples[0:10] |
|
|
self.wrong_number = 0 |
|
|
self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]) |
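        # Note: self.transform is not used in this file; normalization (when
        # self.normalize is True) is applied through self.preprocess below.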
|
|
self.training = False |
|
|
self.preprocess = T.Compose( |
|
|
[ |
|
|
ToTensorVideo(), |
|
|
Resize_Preprocess(tuple(video_size)), |
|
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
] |
|
|
) |
|
|
self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) |
|
|
|
|
|
def __str__(self): |
|
|
        return f"{len(self.samples)} samples ({len(self.ann_files)} annotation files) from {self.data_path}"
|
|
|
|
|
def _init_anns(self, data_dir): |
|
|
ann_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")] |
|
|
return ann_files |
|
|
|
|
|
def _init_sequences(self, ann_files): |
|
|
samples = [] |
|
|
with ThreadPoolExecutor(32) as executor: |
|
|
future_to_ann_file = { |
|
|
executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files |
|
|
} |
|
|
for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)): |
|
|
samples.extend(future.result()) |
|
|
return samples |
|
|
|
|
|
def _load_and_process_ann_file(self, ann_file): |
|
|
samples = [] |
|
|
with open(ann_file, "r") as f: |
|
|
ann = json.load(f) |
|
|
|
|
|
n_frames = len(ann["state"]) |
|
|
for frame_i in range(0, n_frames, self.start_frame_interval): |
|
|
sample = dict() |
|
|
sample["ann_file"] = ann_file |
|
|
sample["frame_ids"] = [] |
|
|
curr_frame_i = frame_i |
|
|
while True: |
|
|
if curr_frame_i > (n_frames - 1): |
|
|
break |
|
|
sample["frame_ids"].append(curr_frame_i) |
|
|
if len(sample["frame_ids"]) == self.sequence_length: |
|
|
break |
|
|
curr_frame_i += self.sequence_interval |
|
|
|
|
|
if len(sample["frame_ids"]) == self.sequence_length: |
|
|
samples.append(sample) |
|
|
return samples |
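
    # Windowing example for _load_and_process_ann_file above: with n_frames=5,
    # sequence_length=2, sequence_interval=1 and start_frame_interval=1, the emitted
    # frame_ids are [0,1], [1,2], [2,3], [3,4]; the trailing window starting at frame 4
    # is dropped because it cannot reach sequence_length.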
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples) |
|
|
|
|
|
def _load_video(self, video_path, frame_ids): |
|
|
vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) |
|
|
assert (np.array(frame_ids) < len(vr)).all() |
|
|
assert (np.array(frame_ids) >= 0).all() |
|
|
vr.seek(0) |
|
|
frame_data = vr.get_batch(frame_ids).asnumpy() |
|
|
return frame_data |
|
|
|
|
|
def _get_frames(self, label, frame_ids, cam_id, pre_encode): |
|
|
if pre_encode: |
|
|
raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") |
|
|
else: |
|
|
video_path = label["videos"][cam_id]["video_path"] |
|
|
video_path = os.path.join(self.video_path, video_path) |
|
|
frames = self._load_video(video_path, frame_ids) |
|
|
frames = frames.astype(np.uint8) |
|
|
frames = torch.from_numpy(frames).permute(0, 3, 1, 2) |
|
|
|
|
|
            # Unused debug helper: dumps normalized frames to an mp4 for visual inspection.
            def printvideo(videos, filename):
                t_videos = rearrange(videos, "f c h w -> f h w c")
                t_videos = (
                    ((t_videos / 2.0 + 0.5).clamp(0, 1) * 255).detach().to(dtype=torch.uint8).cpu().contiguous().numpy()
                )
                print(t_videos.shape)
                writer = imageio.get_writer(filename, fps=4)
                for frame in t_videos:
                    writer.append_data(frame)
                writer.close()  # close the writer so the file is flushed to disk
|
|
|
|
|
if self.normalize: |
|
|
frames = self.preprocess(frames) |
|
|
else: |
|
|
frames = self.not_norm_preprocess(frames) |
|
|
frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) |
|
|
return frames |
|
|
|
|
|
def _get_obs(self, label, frame_ids, cam_id, pre_encode): |
|
|
if cam_id is None: |
|
|
temp_cam_id = random.choice(self.cam_ids) |
|
|
else: |
|
|
temp_cam_id = cam_id |
|
|
frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode) |
|
|
return frames, temp_cam_id |
|
|
|
|
|
def _get_robot_states(self, label, frame_ids): |
|
|
all_states = np.array(label["state"]) |
|
|
all_cont_gripper_states = np.array(label["continuous_gripper_state"]) |
|
|
states = all_states[frame_ids] |
|
|
cont_gripper_states = all_cont_gripper_states[frame_ids] |
|
|
arm_states = states[:, :6] |
|
|
assert arm_states.shape[0] == self.sequence_length |
|
|
assert cont_gripper_states.shape[0] == self.sequence_length |
|
|
return arm_states, cont_gripper_states |
|
|
|
|
|
def _get_all_robot_states(self, label, frame_ids): |
|
|
all_states = np.array(label["state"]) |
|
|
all_cont_gripper_states = np.array(label["continuous_gripper_state"]) |
|
|
states = all_states[frame_ids] |
|
|
cont_gripper_states = all_cont_gripper_states[frame_ids] |
|
|
arm_states = states[:, :6] |
|
|
return arm_states, cont_gripper_states |
|
|
|
|
|
def _get_all_actions(self, arm_states, gripper_states, accumulate_action): |
|
|
action_num = arm_states.shape[0] - 1 |
|
|
action = np.zeros((action_num, self.action_dim)) |
|
|
if accumulate_action: |
|
|
first_xyz = arm_states[0, 0:3] |
|
|
first_rpy = arm_states[0, 3:6] |
|
|
first_rotm = euler2rotm(first_rpy) |
|
|
for k in range(1, action_num + 1): |
|
|
curr_xyz = arm_states[k, 0:3] |
|
|
curr_rpy = arm_states[k, 3:6] |
|
|
curr_gripper = gripper_states[k] |
|
|
curr_rotm = euler2rotm(curr_rpy) |
|
|
rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) |
|
|
rel_rotm = first_rotm.T @ curr_rotm |
|
|
rel_rpy = rotm2euler(rel_rotm) |
|
|
action[k - 1, 0:3] = rel_xyz |
|
|
action[k - 1, 3:6] = rel_rpy |
|
|
action[k - 1, 6] = curr_gripper |
|
|
else: |
|
|
for k in range(1, action_num + 1): |
|
|
prev_xyz = arm_states[k - 1, 0:3] |
|
|
prev_rpy = arm_states[k - 1, 3:6] |
|
|
prev_rotm = euler2rotm(prev_rpy) |
|
|
curr_xyz = arm_states[k, 0:3] |
|
|
curr_rpy = arm_states[k, 3:6] |
|
|
curr_gripper = gripper_states[k] |
|
|
curr_rotm = euler2rotm(curr_rpy) |
|
|
rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) |
|
|
rel_rotm = prev_rotm.T @ curr_rotm |
|
|
rel_rpy = rotm2euler(rel_rotm) |
|
|
action[k - 1, 0:3] = rel_xyz |
|
|
action[k - 1, 3:6] = rel_rpy |
|
|
action[k - 1, 6] = curr_gripper |
|
|
return torch.from_numpy(action) |
|
|
|
|
|
def _get_actions(self, arm_states, gripper_states, accumulate_action): |
|
|
action = np.zeros((self.sequence_length - 1, self.action_dim)) |
|
|
if accumulate_action: |
|
|
first_xyz = arm_states[0, 0:3] |
|
|
first_rpy = arm_states[0, 3:6] |
|
|
first_rotm = euler2rotm(first_rpy) |
|
|
for k in range(1, self.sequence_length): |
|
|
curr_xyz = arm_states[k, 0:3] |
|
|
curr_rpy = arm_states[k, 3:6] |
|
|
curr_gripper = gripper_states[k] |
|
|
curr_rotm = euler2rotm(curr_rpy) |
|
|
rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) |
|
|
rel_rotm = first_rotm.T @ curr_rotm |
|
|
rel_rpy = rotm2euler(rel_rotm) |
|
|
action[k - 1, 0:3] = rel_xyz |
|
|
action[k - 1, 3:6] = rel_rpy |
|
|
action[k - 1, 6] = curr_gripper |
|
|
else: |
|
|
for k in range(1, self.sequence_length): |
|
|
prev_xyz = arm_states[k - 1, 0:3] |
|
|
prev_rpy = arm_states[k - 1, 3:6] |
|
|
prev_rotm = euler2rotm(prev_rpy) |
|
|
curr_xyz = arm_states[k, 0:3] |
|
|
curr_rpy = arm_states[k, 3:6] |
|
|
curr_gripper = gripper_states[k] |
|
|
curr_rotm = euler2rotm(curr_rpy) |
|
|
rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) |
|
|
rel_rotm = prev_rotm.T @ curr_rotm |
|
|
rel_rpy = rotm2euler(rel_rotm) |
|
|
action[k - 1, 0:3] = rel_xyz |
|
|
action[k - 1, 3:6] = rel_rpy |
|
|
action[k - 1, 6] = curr_gripper |
|
|
return torch.from_numpy(action) |
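
    # Example (illustrative, assuming a dataset built with num_frames=2): for two frames
    # with identical arm pose, euler2rotm gives the identity rotation, so _get_actions
    # returns zero translation/rotation plus the second frame's gripper value, e.g.
    # tensor([[0., 0., 0., 0., 0., 0., 1.]]) for gripper_states=[0., 1.].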
|
|
|
|
|
def __getitem__(self, index, cam_id=None, return_video=False): |
|
|
if self.mode != "train": |
|
|
np.random.seed(index) |
|
|
random.seed(index) |
|
|
|
|
|
try: |
|
|
sample = self.samples[index] |
|
|
ann_file = sample["ann_file"] |
|
|
frame_ids = sample["frame_ids"] |
|
|
with open(ann_file, "r") as f: |
|
|
label = json.load(f) |
|
|
arm_states, gripper_states = self._get_robot_states(label, frame_ids) |
|
|
actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) |
|
|
actions *= self.c_act_scaler |
|
|
|
|
|
data = dict() |
|
|
if self.load_action: |
|
|
data["action"] = actions.float() |
|
|
|
|
|
if self.pre_encode: |
|
|
raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") |
|
|
else: |
|
|
video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) |
|
|
video = video.permute(1, 0, 2, 3) |
|
|
data["video"] = video.to(dtype=torch.uint8) |
|
|
|
|
|
data["annotation_file"] = ann_file |
|
|
|
|
|
|
|
|
if "episode_id" in label: |
|
|
data["__key__"] = label["episode_id"] |
|
|
else: |
|
|
data["__key__"] = label["original_path"] |
|
|
|
|
|
|
|
|
if self.load_t5_embeddings: |
|
|
t5_embedding_path = ann_file.replace(".json", ".pickle") |
|
|
with open(t5_embedding_path, "rb") as f: |
|
|
data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) |
|
|
else: |
|
|
data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) |
|
|
data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) |
|
|
data["fps"] = 4 |
|
|
data["image_size"] = 256 * torch.ones(4) |
|
|
data["num_frames"] = self.sequence_length |
|
|
data["padding_mask"] = torch.zeros(1, 256, 256) |
|
|
|
|
|
return data |
|
|
except Exception: |
|
|
warnings.warn( |
|
|
f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " |
|
|
f"(by randomly sampling another sample in the same dataset)." |
|
|
) |
|
|
warnings.warn("FULL TRACEBACK:") |
|
|
warnings.warn(traceback.format_exc()) |
|
|
self.wrong_number += 1 |
|
|
print(self.wrong_number) |
|
|
return self[np.random.randint(len(self.samples))] |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
dataset = Dataset_3D( |
|
|
train_annotation_path="datasets/bridge/annotation/train", |
|
|
val_annotation_path="datasets/bridge/annotation/val", |
|
|
test_annotation_path="datasets/bridge/annotation/test", |
|
|
video_path="datasets/bridge/", |
|
|
sequence_interval=1, |
|
|
num_frames=2, |
|
|
cam_ids=[0], |
|
|
accumulate_action=False, |
|
|
video_size=[256, 320], |
|
|
val_start_frame_interval=1, |
|
|
mode="train", |
|
|
load_t5_embeddings=True, |
|
|
) |
|
|
|
|
|
    indices = [0, 13, 200, -1]
    for idx in indices:
        data = dataset[idx]
        print(
            (
                f"{idx=} "
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data['annotation_file']=}\n"
                f"{data['action'].sum()=}\n"
                "---"
            )
        )
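
    # Hedged sketch (not part of the original script): wrap the dataset in a standard
    # PyTorch DataLoader to check that individual samples collate into batches.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(f"{batch['video'].shape=}")   # expected [B, C, T, H, W]
    print(f"{batch['action'].shape=}")  # expected [B, T-1, 7]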
|
|
|
|
|
from IPython import embed |
|
|
|
|
|
embed() |
|
|
|