# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved.
# Modified by [Fangjing Wang / SUST University] in [2025].
# Modification: [return raw data and support multi-dataset mixtures].
# Modified by [Jinhui YE / HKUST University] in [2025].
# Modification: [support top-down processing; support params from config].
from pathlib import Path
from typing import Sequence
from omegaconf import OmegaConf
from starVLA.dataloader.gr00t_lerobot.datasets import LeRobotSingleDataset, LeRobotMixtureDataset
from starVLA.dataloader.gr00t_lerobot.mixtures import DATASET_NAMED_MIXTURES
from starVLA.dataloader.gr00t_lerobot.data_config import get_robot_type_config_map
from starVLA.dataloader.gr00t_lerobot.embodiment_tags import ROBOT_TYPE_TO_EMBODIMENT_TAG, EmbodimentTag
def collate_fn(batch):
    """Identity collate: return the list of raw samples without stacking."""
    return batch
def make_LeRobotSingleDataset(
data_root_dir: Path | str,
data_name: str,
robot_type: str,
delete_pause_frame: bool = False,
data_cfg: dict | None = None,
) -> LeRobotSingleDataset:
"""
Make a LeRobotSingleDataset object.
:param data_root_dir: The root directory of the dataset.
:param data_name: The name of the dataset.
:param robot_type: The robot type config to use.
:param crop_obs_camera: Whether to crop the observation camera images.
:return: A LeRobotSingleDataset object.
"""
    data_cfg = data_cfg or {}  # tolerate data_cfg=None
    chunk_size = data_cfg.get("chunk_size")
    state_use_action_chunk = data_cfg.get("state_use_action_chunk")
    num_history_steps = data_cfg.get("num_history_steps", 0)
data_config = get_robot_type_config_map(
chunk_size=chunk_size,
state_use_action_chunk=state_use_action_chunk,
num_history_steps=num_history_steps,
)[robot_type]
modality_config = data_config.modality_config()
transforms = data_config.transform()
    dataset_path = Path(data_root_dir) / data_name
if robot_type not in ROBOT_TYPE_TO_EMBODIMENT_TAG:
print(f"Warning: Robot type {robot_type} not found in ROBOT_TYPE_TO_EMBODIMENT_TAG, using {EmbodimentTag.NEW_EMBODIMENT} as default")
embodiment_tag = EmbodimentTag.NEW_EMBODIMENT
else:
embodiment_tag = ROBOT_TYPE_TO_EMBODIMENT_TAG[robot_type]
    video_backend = data_cfg.get("video_backend", "decord")
return LeRobotSingleDataset(
dataset_path=dataset_path,
modality_configs=modality_config,
transforms=transforms,
embodiment_tag=embodiment_tag,
        video_backend=video_backend,  # decord is more efficient; use torchvision_av for AV1-encoded video
delete_pause_frame=delete_pause_frame,
data_cfg=data_cfg,
)
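# A minimal usage sketch for make_LeRobotSingleDataset (commented out so the
# module stays import-safe). The paths and names below are hypothetical, and
# robot_type must be a key of get_robot_type_config_map(...):
#
#     ds = make_LeRobotSingleDataset(
#         data_root_dir="./playground/Datasets",  # hypothetical root
#         data_name="demo_task",                  # hypothetical dataset name
#         robot_type="franka",                    # hypothetical robot type key
#         data_cfg={"chunk_size": 16, "state_use_action_chunk": False},
#     )
#     sample = ds[0]  # one raw sample after modality transforms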
def get_vla_dataset(
data_cfg: dict,
mode: str = "train",
balance_dataset_weights: bool = False,
balance_trajectory_weights: bool = False,
seed: int = 42,
delete_pause_frame: bool = True,
**kwargs: dict,
) -> LeRobotMixtureDataset:
"""
Get a LeRobotMixtureDataset object.
"""
data_root_dir = data_cfg.data_root_dir
data_mix = data_cfg.data_mix
mixture_spec = DATASET_NAMED_MIXTURES[data_mix]
included_datasets, filtered_mixture_spec = set(), []
for d_name, d_weight, robot_type in mixture_spec:
dataset_key = (d_name, robot_type)
if dataset_key in included_datasets:
print(f"Skipping Duplicate Dataset: `{(d_name, d_weight, robot_type)}`")
continue
included_datasets.add(dataset_key)
filtered_mixture_spec.append((d_name, d_weight, robot_type))
dataset_mixture = []
for d_name, d_weight, robot_type in filtered_mixture_spec:
        dataset_mixture.append(
            (
                make_LeRobotSingleDataset(
                    Path(data_root_dir),
                    d_name,
                    robot_type,
                    delete_pause_frame=delete_pause_frame,
                    data_cfg=data_cfg,
                ),
                d_weight,
            )
        )
return LeRobotMixtureDataset(
dataset_mixture,
mode=mode,
balance_dataset_weights=balance_dataset_weights,
balance_trajectory_weights=balance_trajectory_weights,
seed=seed,
data_cfg=data_cfg,
**kwargs,
)
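# A minimal sketch of calling get_vla_dataset with a hand-built config
# (commented out; field values are hypothetical). OmegaConf objects support
# both the attribute access and the .get(...) used above, and data_mix must
# name an entry in DATASET_NAMED_MIXTURES:
#
#     cfg = OmegaConf.create(
#         {"data_root_dir": "./playground/Datasets", "data_mix": "SOME_MIXTURE"}
#     )
#     mixture = get_vla_dataset(data_cfg=cfg, mode="train", seed=42)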
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_behavior.yaml", help="Path to YAML config")
args, clipargs = parser.parse_known_args()
args.config_yaml = "examples/LIBERO/train_files/starvla_cotrain_libero.yaml"
cfg = OmegaConf.load(args.config_yaml)
vla_dataset_cfg = cfg.datasets.vla_data
# vla_dataset_cfg.data_root_dir = "./playground/Datasets/behavior-1k"
# vla_dataset_cfg.include_state = True
# vla_dataset_cfg.data_mix = "BEHAVIOR_dual_base_depth"
    vla_dataset_cfg.task_id = 1  # placeholder; overwritten by the loop below
for task_id in ["all"]:
# 11,26,36,37
# 5,11,13,26,36,27,43,44,45,46
# 2,3,5,11,13,25,26,27,
# 3,5,11,13, / 14,15,16,17, / 19,20,23,25, / 26,27,30,34, / 36,37,38,39, 41,42,43,44,45,46,47,49
vla_dataset_cfg.task_id = task_id
print(f"Testing Task ID: {task_id}")
dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
dataset,
batch_size=2,
num_workers=1, # For Debug
collate_fn=collate_fn,
)
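        # Note: collate_fn is an identity function, so each batch below is a
        # plain Python list of raw samples rather than stacked tensors.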
from tqdm import tqdm
        for step, batch in enumerate(tqdm(train_dataloader, desc="Processing Batches")):
            # print(batch)
            if step >= 1:  # smoke test: stop once the first batch loads
                break