# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| Run this command to interactively debug: | |
| PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py | |
| Adapted from: | |
| https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py | |
| """ | |
import json
import os
import pickle
import random
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import imageio
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from cosmos_predict1.diffusion.training.datasets.dataset_utils import (
    Resize_Preprocess,
    ToTensorVideo,
    euler2rotm,
    rotm2euler,
)
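
# NOTE (added for clarity): expected annotation layout, inferred from the
# accessors in this file; the exact schema is an assumption and may differ
# between dataset releases:
#   <episode>.json
#     "state":                     per-frame robot state; the first 6 dims are
#                                  ee xyz (3) + ee euler angles (3)
#     "continuous_gripper_state":  per-frame scalar gripper state
#     "videos":                    list of camera entries, each holding a
#                                  "video_path" relative to `video_path`
#     "episode_id" / "original_path": unique episode identifier
#   <episode>.pickle               pickled T5 text embeddings, read only when
#                                  load_t5_embeddings=True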


class Dataset_3D(Dataset):
    def __init__(
        self,
        train_annotation_path,
        val_annotation_path,
        test_annotation_path,
        video_path,
        sequence_interval,
        num_frames,
        cam_ids,
        accumulate_action,
        video_size,
        val_start_frame_interval,
        debug=False,
        normalize=False,
        pre_encode=False,
        do_evaluate=False,
        load_t5_embeddings=False,
        load_action=True,
        mode="train",
    ):
| """Dataset class for loading 3D robot action-conditional data. | |
| This dataset loads robot trajectories consisting of RGB video frames, robot states (arm positions and gripper states), | |
| and computes relative actions between consecutive frames. | |
| Args: | |
| train_annotation_path (str): Path to training annotation files | |
| val_annotation_path (str): Path to validation annotation files | |
| test_annotation_path (str): Path to test annotation files | |
| video_path (str): Base path to video files | |
| sequence_interval (int): Interval between sampled frames in a sequence | |
| num_frames (int): Number of frames to load per sequence | |
| cam_ids (list): List of camera IDs to sample from | |
| accumulate_action (bool): Whether to accumulate actions relative to first frame | |
| video_size (list): Target size [H,W] for video frames | |
| val_start_frame_interval (int): Frame sampling interval for validation/test | |
| debug (bool, optional): If True, only loads subset of data. Defaults to False. | |
| normalize (bool, optional): Whether to normalize video frames. Defaults to False. | |
| pre_encode (bool, optional): Whether to pre-encode video frames. Defaults to False. | |
| do_evaluate (bool, optional): Whether in evaluation mode. Defaults to False. | |
| load_t5_embeddings (bool, optional): Whether to load T5 embeddings. Defaults to False. | |
| load_action (bool, optional): Whether to load actions. Defaults to True. | |
| mode (str, optional): Dataset mode - 'train', 'val' or 'test'. Defaults to 'train'. | |
| The dataset loads robot trajectories and computes: | |
| - RGB video frames from specified camera views | |
| - Robot arm states (xyz position + euler angles) | |
| - Gripper states (binary open/closed) | |
| - Relative actions between consecutive frames | |
| Actions are computed as relative transforms between frames: | |
| - Translation: xyz offset in previous frame's coordinate frame | |
| - Rotation: euler angles of relative rotation | |
| - Gripper: binary gripper state | |
| Returns dict with: | |
| - video: RGB frames tensor [T,C,H,W] | |
| - action: Action tensor [T-1,7] | |
| - video_name: Dict with episode/frame metadata | |
| - latent: Pre-encoded video features if pre_encode=True | |
| """ | |
        super().__init__()
        if mode == "train":
            self.data_path = train_annotation_path
            self.start_frame_interval = 1
        elif mode == "val":
            self.data_path = val_annotation_path
            self.start_frame_interval = val_start_frame_interval
        elif mode == "test":
            self.data_path = test_annotation_path
            self.start_frame_interval = val_start_frame_interval
        else:
            raise ValueError(f"Unknown mode: {mode}. Expected 'train', 'val' or 'test'.")
        self.video_path = video_path
        self.sequence_interval = sequence_interval
        self.mode = mode
        self.sequence_length = num_frames
        self.normalize = normalize
        self.pre_encode = pre_encode
        self.load_t5_embeddings = load_t5_embeddings
        self.load_action = load_action
        self.cam_ids = cam_ids
        self.accumulate_action = accumulate_action
        self.action_dim = 7  # ee xyz (3) + ee euler (3) + gripper (1)
        self.c_act_scaler = [20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0]
        self.c_act_scaler = np.array(self.c_act_scaler, dtype=float)
        self.ann_files = self._init_anns(self.data_path)
        self.samples = self._init_sequences(self.ann_files)
        self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0]))
        if debug and not do_evaluate:
            self.samples = self.samples[0:10]
        self.wrong_number = 0
        self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)])
        self.training = False
        self.preprocess = T.Compose(
            [
                ToTensorVideo(),
                Resize_Preprocess(tuple(video_size)),  # e.g. (288, 512)
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])
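
        # With normalize=True, ToTensorVideo converts uint8 frames to float in
        # [0, 1] and Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
        # With normalize=False, _get_frames rescales the resized frames back to
        # uint8 in [0, 255]. (Standard torchvision semantics; noted here for
        # clarity, not part of the original comments.)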

    def __str__(self):
        return f"{len(self.ann_files)} annotation files from {self.data_path}"

    def _init_anns(self, data_dir):
        ann_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")]
        return ann_files

    def _init_sequences(self, ann_files):
        samples = []
        with ThreadPoolExecutor(32) as executor:
            future_to_ann_file = {
                executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files
            }
            for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)):
                samples.extend(future.result())
        return samples

    def _load_and_process_ann_file(self, ann_file):
        samples = []
        with open(ann_file, "r") as f:
            ann = json.load(f)
        n_frames = len(ann["state"])
        for frame_i in range(0, n_frames, self.start_frame_interval):
            sample = dict()
            sample["ann_file"] = ann_file
            sample["frame_ids"] = []
            curr_frame_i = frame_i
            while True:
                if curr_frame_i > (n_frames - 1):
                    break
                sample["frame_ids"].append(curr_frame_i)
                if len(sample["frame_ids"]) == self.sequence_length:
                    break
                curr_frame_i += self.sequence_interval
            # Keep only complete sequences of exactly sequence_length frames
            if len(sample["frame_ids"]) == self.sequence_length:
                samples.append(sample)
        return samples
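
    # Example of the sliding-window sampling above (values are illustrative):
    # with n_frames=6, start_frame_interval=1, sequence_interval=2 and
    # sequence_length=3, the generated frame_ids windows are
    #   [0, 2, 4], [1, 3, 5]
    # while windows starting at frame 2 or later are dropped as incomplete.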

    def __len__(self):
        return len(self.samples)

    def _load_video(self, video_path, frame_ids):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        assert (np.array(frame_ids) < len(vr)).all()
        assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        return frame_data

    def _get_frames(self, label, frame_ids, cam_id, pre_encode):
        if pre_encode:
            raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")
        else:
            video_path = label["videos"][cam_id]["video_path"]
            video_path = os.path.join(self.video_path, video_path)
            frames = self._load_video(video_path, frame_ids)
            frames = frames.astype(np.uint8)
            frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (l, c, h, w)

            def printvideo(videos, filename):
                # Debug helper (not called in the pipeline): writes a normalized
                # [-1, 1] video tensor out to a file for visual inspection.
                t_videos = rearrange(videos, "f c h w -> f h w c")
                t_videos = (
                    ((t_videos / 2.0 + 0.5).clamp(0, 1) * 255).detach().to(dtype=torch.uint8).cpu().contiguous().numpy()
                )
                print(t_videos.shape)
                writer = imageio.get_writer(filename, fps=4)  # fps: output frame rate
                for frame in t_videos:
                    writer.append_data(frame)
                writer.close()

            if self.normalize:
                frames = self.preprocess(frames)
            else:
                frames = self.not_norm_preprocess(frames)
                frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)
        return frames

    def _get_obs(self, label, frame_ids, cam_id, pre_encode):
        if cam_id is None:
            temp_cam_id = random.choice(self.cam_ids)
        else:
            temp_cam_id = cam_id
        frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode)
        return frames, temp_cam_id

    def _get_robot_states(self, label, frame_ids):
        all_states = np.array(label["state"])
        all_cont_gripper_states = np.array(label["continuous_gripper_state"])
        states = all_states[frame_ids]
        cont_gripper_states = all_cont_gripper_states[frame_ids]
        arm_states = states[:, :6]
        assert arm_states.shape[0] == self.sequence_length
        assert cont_gripper_states.shape[0] == self.sequence_length
        return arm_states, cont_gripper_states

    def _get_all_robot_states(self, label, frame_ids):
        all_states = np.array(label["state"])
        all_cont_gripper_states = np.array(label["continuous_gripper_state"])
        states = all_states[frame_ids]
        cont_gripper_states = all_cont_gripper_states[frame_ids]
        arm_states = states[:, :6]
        return arm_states, cont_gripper_states

    def _get_all_actions(self, arm_states, gripper_states, accumulate_action):
        action_num = arm_states.shape[0] - 1
        action = np.zeros((action_num, self.action_dim))
        if accumulate_action:
            # Express every pose relative to the first frame
            first_xyz = arm_states[0, 0:3]
            first_rpy = arm_states[0, 3:6]
            first_rotm = euler2rotm(first_rpy)
            for k in range(1, action_num + 1):
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
                rel_rotm = first_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        else:
            # Express every pose relative to the previous frame
            for k in range(1, action_num + 1):
                prev_xyz = arm_states[k - 1, 0:3]
                prev_rpy = arm_states[k - 1, 3:6]
                prev_rotm = euler2rotm(prev_rpy)
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz)
                rel_rotm = prev_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        return torch.from_numpy(action)  # (l - 1, act_dim)

    def _get_actions(self, arm_states, gripper_states, accumulate_action):
        action = np.zeros((self.sequence_length - 1, self.action_dim))
        if accumulate_action:
            # Express every pose relative to the first frame
            first_xyz = arm_states[0, 0:3]
            first_rpy = arm_states[0, 3:6]
            first_rotm = euler2rotm(first_rpy)
            for k in range(1, self.sequence_length):
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
                rel_rotm = first_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        else:
            # Express every pose relative to the previous frame
            for k in range(1, self.sequence_length):
                prev_xyz = arm_states[k - 1, 0:3]
                prev_rpy = arm_states[k - 1, 3:6]
                prev_rotm = euler2rotm(prev_rpy)
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz)
                rel_rotm = prev_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        return torch.from_numpy(action)  # (l - 1, act_dim)
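
    # Worked example of the relative-action math above (a sanity check, not
    # part of the pipeline): with the previous end-effector at the origin,
    # rotated 90 degrees about z (R_prev = Rz(90)), and the current
    # end-effector at (1, 0, 0) with the same orientation:
    #   rel_xyz  = R_prev.T @ ([1, 0, 0] - [0, 0, 0]) = (0, -1, 0)
    #   rel_rotm = R_prev.T @ R_curr = I  =>  rel_rpy = (0, 0, 0)
    # i.e. translations are expressed in the previous frame's own coordinates.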

    def __getitem__(self, index, cam_id=None, return_video=False):
        if self.mode != "train":
            np.random.seed(index)
            random.seed(index)
        try:
            sample = self.samples[index]
            ann_file = sample["ann_file"]
            frame_ids = sample["frame_ids"]
            with open(ann_file, "r") as f:
                label = json.load(f)
            arm_states, gripper_states = self._get_robot_states(label, frame_ids)
            actions = self._get_actions(arm_states, gripper_states, self.accumulate_action)
            actions *= self.c_act_scaler

            data = dict()
            if self.load_action:
                data["action"] = actions.float()
            if self.pre_encode:
                raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")
            else:
                video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False)
                video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
                data["video"] = video.to(dtype=torch.uint8)
            data["annotation_file"] = ann_file
            # NOTE: __key__ is used to uniquely identify the sample, required for callback functions
            if "episode_id" in label:
                data["__key__"] = label["episode_id"]
            else:
                data["__key__"] = label["original_path"]

            # These fields are added to fit the training interface
            if self.load_t5_embeddings:
                t5_embedding_path = ann_file.replace(".json", ".pickle")
                with open(t5_embedding_path, "rb") as f:
                    data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0])
            else:
                data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16)
            data["t5_text_mask"] = torch.ones(512, dtype=torch.int64)
            data["fps"] = 4
            data["image_size"] = 256 * torch.ones(4)  # TODO: Does this matter?
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 256, 256)
            return data
        except Exception:
            warnings.warn(
                f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped "
                f"(by randomly sampling another sample in the same dataset).\n"
                f"FULL TRACEBACK:\n{traceback.format_exc()}"
            )
            self.wrong_number += 1
            print(f"Invalid samples encountered so far: {self.wrong_number}")
            return self[np.random.randint(len(self.samples))]


if __name__ == "__main__":
    dataset = Dataset_3D(
        train_annotation_path="datasets/bridge/annotation/train",
        val_annotation_path="datasets/bridge/annotation/val",
        test_annotation_path="datasets/bridge/annotation/test",
        video_path="datasets/bridge/",
        sequence_interval=1,
        num_frames=2,
        cam_ids=[0],
        accumulate_action=False,
        video_size=[256, 320],
        val_start_frame_interval=1,
        mode="train",
        load_t5_embeddings=True,
    )
    indices = [0, 13, 200, -1]
    for idx in indices:
        data = dataset[idx]  # fetch once; each indexing call decodes video
        print(
            (
                f"{idx=} "
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data['__key__']=}\n"
                f"{data['action'].sum()=}\n"
                "---"
            )
        )
    from IPython import embed

    embed()
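
    # A minimal DataLoader sketch (illustrative addition, not part of the
    # original script; assumes torch's default collate can stack these dict
    # samples, which holds since the values are tensors, strings, and ints):
    #
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=True)
    #   batch = next(iter(loader))
    #   print(batch["video"].shape)  # (2, C, T, H, W)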