Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Run this command to interactively debug: | |
| PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py | |
| Adapted from: | |
| https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py | |
| """ | |
| import json | |
| import pickle | |
| import random | |
| import traceback | |
| import warnings | |
| import numpy as np | |
| import torch | |
| from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D | |
| from cosmos_predict1.utils import log | |
| class Dataset_3DBinary(Dataset_3D): | |
| def __init__( | |
| self, | |
| train_annotation_path, | |
| val_annotation_path, | |
| test_annotation_path, | |
| video_path, | |
| sequence_interval, | |
| num_frames, | |
| cam_ids, | |
| accumulate_action, | |
| video_size, | |
| val_start_frame_interval, | |
| debug=False, | |
| normalize=False, | |
| pre_encode=False, | |
| do_evaluate=False, | |
| load_t5_embeddings=False, | |
| load_action=True, | |
| mode="train", | |
| ): | |
| """Dataset class for loading 3D robot action-conditional data. | |
| This dataset loads robot trajectories consisting of RGB video frames, robot states | |
| (arm positions and binary gripper states), and computes relative actions between | |
| consecutive frames. | |
| """ | |
| super().__init__( | |
| train_annotation_path=train_annotation_path, | |
| val_annotation_path=val_annotation_path, | |
| test_annotation_path=test_annotation_path, | |
| video_path=video_path, | |
| sequence_interval=sequence_interval, | |
| num_frames=num_frames, | |
| cam_ids=cam_ids, | |
| accumulate_action=accumulate_action, | |
| video_size=video_size, | |
| val_start_frame_interval=val_start_frame_interval, | |
| debug=debug, | |
| normalize=normalize, | |
| pre_encode=pre_encode, | |
| do_evaluate=do_evaluate, | |
| load_t5_embeddings=load_t5_embeddings, | |
| load_action=load_action, | |
| mode=mode, | |
| ) | |
| log.info("Dataset_3DBinary: in this dataset, we binarize the gripper state to 0 or 1.") | |
| def _get_json_action(self, label, frame_ids): | |
| all_action = np.array(label["action"]) | |
| actions = all_action[frame_ids[:-1]] | |
| return torch.from_numpy(actions) | |
| def __getitem__(self, index, cam_id=None, return_video=False): | |
| if self.mode != "train": | |
| np.random.seed(index) | |
| random.seed(index) | |
| try: | |
| sample = self.samples[index] | |
| ann_file = sample["ann_file"] | |
| frame_ids = sample["frame_ids"] | |
| with open(ann_file, "r") as f: | |
| label = json.load(f) | |
| arm_states, gripper_states = self._get_robot_states(label, frame_ids) | |
| actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) | |
| actions *= self.c_act_scaler | |
| data = dict() | |
| if self.load_action: | |
| data["action"] = actions.float() | |
| json_action = self._get_json_action(label, frame_ids).float() | |
| json_action[:, :6] = data["action"][:, :6] | |
| data["action"] = json_action | |
| if self.pre_encode: | |
| raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") | |
| else: | |
| video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) | |
| video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] | |
| data["video"] = video.to(dtype=torch.uint8) | |
| data["annotation_file"] = ann_file | |
| if "episode_id" in label: | |
| data["__key__"] = label["episode_id"] | |
| else: | |
| data["__key__"] = label["original_path"] | |
| # Just add these to fit the interface | |
| if self.load_t5_embeddings: | |
| t5_embedding_path = ann_file.replace(".json", ".pickle") | |
| with open(t5_embedding_path, "rb") as f: | |
| data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) | |
| else: | |
| data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) | |
| data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) | |
| data["fps"] = 4 | |
| data["image_size"] = 256 * torch.ones(4) # TODO: Does this matter? | |
| data["num_frames"] = self.sequence_length | |
| data["padding_mask"] = torch.zeros(1, 256, 256) | |
| return data | |
| except Exception: | |
| warnings.warn( | |
| f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " | |
| f"(by randomly sampling another sample in the same dataset)." | |
| ) | |
| warnings.warn("FULL TRACEBACK:") | |
| warnings.warn(traceback.format_exc()) | |
| self.wrong_number += 1 | |
| print(self.wrong_number) | |
| return self[np.random.randint(len(self.samples))] | |
| if __name__ == "__main__": | |
| dataset = Dataset_3DBinary( | |
| train_annotation_path="datasets/bridge/annotation/train", | |
| val_annotation_path="datasets/bridge/annotation/val", | |
| test_annotation_path="datasets/bridge/annotation/test", | |
| video_path="datasets/bridge/", | |
| sequence_interval=1, | |
| num_frames=2, | |
| cam_ids=[0], | |
| accumulate_action=False, | |
| video_size=[256, 320], | |
| val_start_frame_interval=1, | |
| mode="train", | |
| load_t5_embeddings=True, | |
| ) | |
| indices = [0, 13, 200, -1] | |
| for idx in indices: | |
| print( | |
| ( | |
| f"{idx=} " | |
| f"{dataset[idx]['video'].sum()=}\n" | |
| f"{dataset[idx]['video'].shape=}\n" | |
| f"{dataset[idx]['video_name']=}\n" | |
| f"{dataset[idx]['action'].sum()=}\n" | |
| f"{dataset[idx]['json_action'].sum()=}\n" | |
| "---" | |
| ) | |
| ) | |
| from IPython import embed | |
| embed() | |