# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved.
# Modified by [Fangjing Wang / SUST University] in [2025].
# Modification: [return raw data and support multi-dataset mixture].
# Modified by [Jinhui YE / HKUST University] in [2025].
# Modification: [support top-down processing; support params from config].
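"""Dataloader utilities for LeRobot-format datasets: build a single
LeRobotSingleDataset or a weighted multi-dataset mixture (LeRobotMixtureDataset)
for starVLA training. Samples are returned raw; batching is deferred to the
caller via the identity ``collate_fn`` below.
"""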
from pathlib import Path
from typing import Sequence
from omegaconf import OmegaConf
from starVLA.dataloader.gr00t_lerobot.datasets import LeRobotSingleDataset, LeRobotMixtureDataset
from starVLA.dataloader.gr00t_lerobot.mixtures import DATASET_NAMED_MIXTURES
from starVLA.dataloader.gr00t_lerobot.data_config import get_robot_type_config_map
from starVLA.dataloader.gr00t_lerobot.embodiment_tags import ROBOT_TYPE_TO_EMBODIMENT_TAG, EmbodimentTag
def collate_fn(batch):
    # Identity collate: return the list of raw samples unchanged; any
    # model-specific batching is deferred to the consumer.
    return batch

def make_LeRobotSingleDataset(
data_root_dir: Path | str,
data_name: str,
robot_type: str,
delete_pause_frame: bool = False,
data_cfg: dict | None = None,
) -> LeRobotSingleDataset:
"""
Make a LeRobotSingleDataset object.
:param data_root_dir: The root directory of the dataset.
:param data_name: The name of the dataset.
:param robot_type: The robot type config to use.
:param crop_obs_camera: Whether to crop the observation camera images.
:return: A LeRobotSingleDataset object.
"""
    chunk_size = data_cfg.get("chunk_size") if data_cfg else None
    state_use_action_chunk = data_cfg.get("state_use_action_chunk") if data_cfg else None
    data_config = get_robot_type_config_map(
        chunk_size=chunk_size, state_use_action_chunk=state_use_action_chunk
    )[robot_type]
modality_config = data_config.modality_config()
transforms = data_config.transform()
    dataset_path = Path(data_root_dir) / data_name  # accept str or Path, per the signature
if robot_type not in ROBOT_TYPE_TO_EMBODIMENT_TAG:
print(f"Warning: Robot type {robot_type} not found in ROBOT_TYPE_TO_EMBODIMENT_TAG, using {EmbodimentTag.NEW_EMBODIMENT} as default")
embodiment_tag = EmbodimentTag.NEW_EMBODIMENT
else:
embodiment_tag = ROBOT_TYPE_TO_EMBODIMENT_TAG[robot_type]
video_backend = data_cfg.get("video_backend", "decord") if data_cfg else "decord"
return LeRobotSingleDataset(
dataset_path=dataset_path,
modality_configs=modality_config,
transforms=transforms,
embodiment_tag=embodiment_tag,
        video_backend=video_backend,  # decord is more efficient; use torchvision_av for AV1-encoded video
delete_pause_frame=delete_pause_frame,
data_cfg=data_cfg,
)
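
# Example usage (a minimal sketch; the dataset name, robot type, and config
# values below are hypothetical placeholders, not fixed names from this repo):
#
#   dataset = make_LeRobotSingleDataset(
#       data_root_dir="./playground/Datasets",
#       data_name="my_lerobot_dataset",
#       robot_type="new_embodiment",
#       data_cfg={"chunk_size": 16, "state_use_action_chunk": False},
#   )
#   sample = dataset[0]  # raw, untransformed-collate sample dict
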
def get_vla_dataset(
data_cfg: dict,
mode: str = "train",
balance_dataset_weights: bool = False,
balance_trajectory_weights: bool = False,
seed: int = 42,
delete_pause_frame: bool = True,
**kwargs: dict,
) -> LeRobotMixtureDataset:
"""
Get a LeRobotMixtureDataset object.
"""
data_root_dir = data_cfg.data_root_dir
data_mix = data_cfg.data_mix
mixture_spec = DATASET_NAMED_MIXTURES[data_mix]
included_datasets, filtered_mixture_spec = set(), []
for d_name, d_weight, robot_type in mixture_spec:
dataset_key = (d_name, robot_type)
if dataset_key in included_datasets:
print(f"Skipping Duplicate Dataset: `{(d_name, d_weight, robot_type)}`")
continue
included_datasets.add(dataset_key)
filtered_mixture_spec.append((d_name, d_weight, robot_type))
    dataset_mixture = []
    for d_name, d_weight, robot_type in filtered_mixture_spec:
        dataset_mixture.append(
            (
                make_LeRobotSingleDataset(
                    Path(data_root_dir),
                    d_name,
                    robot_type,
                    delete_pause_frame=delete_pause_frame,
                    data_cfg=data_cfg,
                ),
                d_weight,
            )
        )
return LeRobotMixtureDataset(
dataset_mixture,
mode=mode,
balance_dataset_weights=balance_dataset_weights,
balance_trajectory_weights=balance_trajectory_weights,
seed=seed,
data_cfg=data_cfg,
**kwargs,
)
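
# For reference, ``data_cfg`` is expected to expose at least the fields read in
# this module (the values below are illustrative, not repo defaults):
#
#   data_root_dir: ./playground/Datasets/behavior-1k   # root folder holding the datasets
#   data_mix: BEHAVIOR_dual_base_depth                 # key into DATASET_NAMED_MIXTURES
#   chunk_size: 16                                     # action-chunk length
#   state_use_action_chunk: false
#   video_backend: decord                              # or torchvision_av for AV1 video
#   task_id: all                                       # used by the smoke test below
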
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_behavior.yaml", help="Path to YAML config")
args, clipargs = parser.parse_known_args()
    # Debug override (uncomment to pin a specific config instead of honoring --config_yaml):
    # args.config_yaml = "examples/LIBERO/train_files/starvla_cotrain_libero.yaml"
cfg = OmegaConf.load(args.config_yaml)
vla_dataset_cfg = cfg.datasets.vla_data
# vla_dataset_cfg.data_root_dir = "./playground/Datasets/behavior-1k"
# vla_dataset_cfg.include_state = True
# vla_dataset_cfg.data_mix = "BEHAVIOR_dual_base_depth"
    # Candidate task IDs exercised during debugging:
    # 2, 3, 5, 11, 13, 14, 15, 16, 17, 19, 20, 23, 25, 26, 27, 30, 34,
    # 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 49
    for task_id in ["all"]:
        vla_dataset_cfg.task_id = task_id
print(f"Testing Task ID: {task_id}")
dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
dataset,
batch_size=2,
            num_workers=1,  # single worker for easier debugging
collate_fn=collate_fn,
)
from tqdm import tqdm
        # Smoke test: process the first batch, then stop when the second arrives.
        for count, batch in enumerate(tqdm(train_dataloader, desc="Processing Batches"), start=1):
            if count >= 2:
                break
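
# Run the smoke test directly, pointing at a training YAML, e.g.:
#   python lerobot_datasets.py --config_yaml ./starVLA/config/training/starvla_cotrain_behavior.yaml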