import json
from pathlib import Path

import numpy as np
import torch.distributed as dist
from torch.utils.data import DataLoader

from accelerate.logging import get_logger

from starVLA.dataloader.vlm_datasets import make_vlm_dataloader

logger = get_logger(__name__)


def _is_main_process() -> bool:
    """Return True when distributed is not initialized or this is rank 0."""
    return (not dist.is_initialized()) or dist.get_rank() == 0


def _to_jsonable(value):
    """Recursively convert numpy arrays/scalars to plain JSON-friendly types.

    Builds a converted copy; never mutates the input structure.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):  # np.int64, np.float32, ... -> Python scalar
        return value.item()
    if isinstance(value, dict):
        return {k: _to_jsonable(v) for k, v in value.items()}
    return value


def save_dataset_statistics(dataset_statistics, run_dir):
    """Save a ``dataset_statistics.json`` file under *run_dir*.

    Args:
        dataset_statistics: Mapping of dataset name -> statistics dict
            (``action`` / ``proprio`` stats, trajectory/transition counts).
            NumPy values anywhere in the structure are converted to plain
            Python types before serialization. Unlike the previous
            implementation, the input mapping is NOT mutated in place.
        run_dir: Directory (``str`` or ``Path``) in which to write the file.
    """
    out_path = Path(run_dir) / "dataset_statistics.json"
    serializable = _to_jsonable(dataset_statistics)
    with open(out_path, "w") as f_json:
        json.dump(serializable, f_json, indent=2)
    logger.info(f"Saved dataset statistics file at path {out_path}")


def build_dataloader(cfg, dataset_py="lerobot_datasets_oxe"):
    """Build the training ``DataLoader`` selected by *dataset_py*.

    Args:
        cfg: Experiment config; must provide ``datasets.vla_data`` (and
            ``output_dir``) for the lerobot branch, or whatever
            ``make_vlm_dataloader`` expects for the VLM branch.
        dataset_py: Which dataset builder to use. Supported values:
            ``"lerobot_datasets"`` and ``"vlm_datasets"``.
            NOTE(review): the default ``"lerobot_datasets_oxe"`` matches no
            branch, so calling with defaults raises ``ValueError`` — confirm
            this is intended.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding training batches.

    Raises:
        ValueError: If *dataset_py* names an unsupported builder.
    """
    # TODO: currently this only fetches the dataset; move full dataloader
    # construction into this function.
    if dataset_py == "lerobot_datasets":
        from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

        vla_dataset_cfg = cfg.datasets.vla_data
        vla_dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
        vla_train_dataloader = DataLoader(
            vla_dataset,
            batch_size=cfg.datasets.vla_data.per_device_batch_size,
            collate_fn=collate_fn,
            num_workers=16,
            prefetch_factor=20,
            shuffle=True,
            persistent_workers=True,  # keep workers alive to avoid restart overhead
            pin_memory=True,  # speed up host-to-GPU transfers
            drop_last=True,  # drop the last incomplete batch instead of waiting
            timeout=30,  # avoid hanging indefinitely on a blocked worker
        )
        if _is_main_process():
            output_dir = Path(cfg.output_dir)
            # NOTE(review): this passes a *file* path while the module-level
            # save_dataset_statistics() takes a directory — confirm the
            # dataset's own method expects the full path.
            vla_dataset.save_dataset_statistics(output_dir / "dataset_statistics.json")
        return vla_train_dataloader

    if dataset_py == "vlm_datasets":
        vlm_data_module = make_vlm_dataloader(cfg)
        return vlm_data_module["train_dataloader"]

    raise ValueError(
        f"Unsupported dataset builder `{dataset_py}`. "
        "Expected one of: `lerobot_datasets`, `vlm_datasets`."
    )