# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved. # Modified by [Fangjing Wang/ SUST University] in [2025]. # Modification: [return raw data and suport multi-dataset mixture]. # Modified by [Jinhui YE/ HKUST University] in [2025]. # Modification: [suport topdowm processing, suport param from config]. from pathlib import Path from typing import Sequence from omegaconf import OmegaConf from starVLA.dataloader.gr00t_lerobot.datasets import LeRobotSingleDataset, LeRobotMixtureDataset from starVLA.dataloader.gr00t_lerobot.mixtures import DATASET_NAMED_MIXTURES from starVLA.dataloader.gr00t_lerobot.data_config import get_robot_type_config_map from starVLA.dataloader.gr00t_lerobot.embodiment_tags import ROBOT_TYPE_TO_EMBODIMENT_TAG, EmbodimentTag def collate_fn(batch): return batch def make_LeRobotSingleDataset( data_root_dir: Path | str, data_name: str, robot_type: str, delete_pause_frame: bool = False, data_cfg: dict | None = None, ) -> LeRobotSingleDataset: """ Make a LeRobotSingleDataset object. :param data_root_dir: The root directory of the dataset. :param data_name: The name of the dataset. :param robot_type: The robot type config to use. :param crop_obs_camera: Whether to crop the observation camera images. :return: A LeRobotSingleDataset object. """ chunk_size = data_cfg.get("chunk_size") state_use_action_chunk = data_cfg.get("state_use_action_chunk") num_history_steps = data_cfg.get("num_history_steps", 0) data_config = get_robot_type_config_map( chunk_size=chunk_size, state_use_action_chunk=state_use_action_chunk, num_history_steps=num_history_steps, )[robot_type] modality_config = data_config.modality_config() transforms = data_config.transform() dataset_path = data_root_dir / data_name if robot_type not in ROBOT_TYPE_TO_EMBODIMENT_TAG: print(f"Warning: Robot type {robot_type} not found in ROBOT_TYPE_TO_EMBODIMENT_TAG, using {EmbodimentTag.NEW_EMBODIMENT} as default") embodiment_tag = EmbodimentTag.NEW_EMBODIMENT else: embodiment_tag = ROBOT_TYPE_TO_EMBODIMENT_TAG[robot_type] video_backend = data_cfg.get("video_backend", "decord") if data_cfg else "decord" return LeRobotSingleDataset( dataset_path=dataset_path, modality_configs=modality_config, transforms=transforms, embodiment_tag=embodiment_tag, video_backend=video_backend, # decord is more efficiency | torchvision_av for video.av1 delete_pause_frame=delete_pause_frame, data_cfg=data_cfg, ) def get_vla_dataset( data_cfg: dict, mode: str = "train", balance_dataset_weights: bool = False, balance_trajectory_weights: bool = False, seed: int = 42, delete_pause_frame: bool = True, **kwargs: dict, ) -> LeRobotMixtureDataset: """ Get a LeRobotMixtureDataset object. """ data_root_dir = data_cfg.data_root_dir data_mix = data_cfg.data_mix mixture_spec = DATASET_NAMED_MIXTURES[data_mix] included_datasets, filtered_mixture_spec = set(), [] for d_name, d_weight, robot_type in mixture_spec: dataset_key = (d_name, robot_type) if dataset_key in included_datasets: print(f"Skipping Duplicate Dataset: `{(d_name, d_weight, robot_type)}`") continue included_datasets.add(dataset_key) filtered_mixture_spec.append((d_name, d_weight, robot_type)) dataset_mixture = [] for d_name, d_weight, robot_type in filtered_mixture_spec: dataset_mixture.append((make_LeRobotSingleDataset(Path(data_root_dir), d_name, robot_type, delete_pause_frame=delete_pause_frame, data_cfg=data_cfg), d_weight)) return LeRobotMixtureDataset( dataset_mixture, mode=mode, balance_dataset_weights=balance_dataset_weights, balance_trajectory_weights=balance_trajectory_weights, seed=seed, data_cfg=data_cfg, **kwargs, ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_behavior.yaml", help="Path to YAML config") args, clipargs = parser.parse_known_args() args.config_yaml = "examples/LIBERO/train_files/starvla_cotrain_libero.yaml" cfg = OmegaConf.load(args.config_yaml) vla_dataset_cfg = cfg.datasets.vla_data # vla_dataset_cfg.data_root_dir = "./playground/Datasets/behavior-1k" # vla_dataset_cfg.include_state = True # vla_dataset_cfg.data_mix = "BEHAVIOR_dual_base_depth" vla_dataset_cfg.task_id = 1 for task_id in ["all"]: # 11,26,36,37 # 5,11,13,26,36,27,43,44,45,46 # 2,3,5,11,13,25,26,27, # 3,5,11,13, / 14,15,16,17, / 19,20,23,25, / 26,27,30,34, / 36,37,38,39, 41,42,43,44,45,46,47,49 vla_dataset_cfg.task_id = task_id print(f"Testing Task ID: {task_id}") dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) # dataset from torch.utils.data import DataLoader train_dataloader = DataLoader( dataset, batch_size=2, num_workers=1, # For Debug collate_fn=collate_fn, ) from tqdm import tqdm count = 1 for batch in tqdm(train_dataloader, desc="Processing Batches"): # print(batch) # print(1) if count > 1: break count += 1 pass