# Timsty's picture
# Upload folder using huggingface_hub
# e94400c verified
import json
from accelerate.logging import get_logger
import numpy as np
from torch.utils.data import DataLoader
import torch.distributed as dist
from pathlib import Path
from starVLA.dataloader.vlm_datasets import make_vlm_dataloader
logger = get_logger(__name__)
def _is_main_process() -> bool:
return (not dist.is_initialized()) or dist.get_rank() == 0
def save_dataset_statistics(dataset_statistics, run_dir):
    """Save `dataset_statistics` as a `dataset_statistics.json` file under *run_dir*.

    Args:
        dataset_statistics: Mapping of dataset name -> stats dict. Stats may
            contain numpy arrays/scalars (the original code handled them under
            "action", "proprio", "num_trajectories" and "num_transitions");
            they are serialized as plain lists/numbers.
        run_dir: Directory (a ``pathlib.Path``) the JSON file is written into.
    """

    def _to_jsonable(obj):
        # json.dump invokes this only for objects it cannot serialize natively,
        # so native lists/dicts/ints/floats pass through untouched.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):  # numpy scalar, e.g. np.int64 / np.float32
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    out_path = run_dir / "dataset_statistics.json"
    # Convert numpy values lazily during serialization instead of rewriting the
    # dict in place: the caller's statistics are no longer mutated as a side
    # effect, and numpy values nested anywhere (not just the four hard-coded
    # keys) are handled.
    with open(out_path, "w") as f_json:
        json.dump(dataset_statistics, f_json, indent=2, default=_to_jsonable)
    logger.info(f"Saved dataset statistics file at path {out_path}")
def build_dataloader(cfg, dataset_py="lerobot_datasets_oxe"):  # TODO: currently only dataset construction lives here; move remaining dataloader setup here too
    """Build a training dataloader for the requested dataset flavor.

    Args:
        cfg: Experiment config. The lerobot path reads ``cfg.datasets.vla_data``
            (including ``per_device_batch_size``) and ``cfg.output_dir``; the
            vlm path is forwarded wholesale to ``make_vlm_dataloader``.
        dataset_py: Dataset builder to use; one of ``"lerobot_datasets"`` or
            ``"vlm_datasets"``.
            NOTE(review): the default ``"lerobot_datasets_oxe"`` matches neither
            branch and falls through to ValueError — confirm intended default.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding training batches.

    Raises:
        ValueError: If ``dataset_py`` is not a supported builder name.
    """
    if dataset_py == "lerobot_datasets":
        from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

        vla_dataset_cfg = cfg.datasets.vla_data
        vla_dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
        vla_train_dataloader = DataLoader(
            vla_dataset,
            # Reuse the cfg handle bound above instead of re-navigating cfg.
            batch_size=vla_dataset_cfg.per_device_batch_size,
            collate_fn=collate_fn,
            num_workers=16,
            prefetch_factor=20,
            shuffle=True,
            persistent_workers=True,  # keep workers alive to avoid restart overhead
            pin_memory=True,  # speed up host-to-GPU transfers
            drop_last=True,  # drop the last incomplete batch to avoid stalls
            timeout=30,  # bound worker waits so a stuck worker cannot block forever
        )
        # Only the main process writes statistics, avoiding concurrent writes
        # to the same file in multi-process training.
        if _is_main_process():
            output_dir = Path(cfg.output_dir)
            vla_dataset.save_dataset_statistics(output_dir / "dataset_statistics.json")
        return vla_train_dataloader

    if dataset_py == "vlm_datasets":
        vlm_data_module = make_vlm_dataloader(cfg)
        return vlm_data_module["train_dataloader"]

    raise ValueError(
        f"Unsupported dataset builder `{dataset_py}`. "
        "Expected one of: `lerobot_datasets`, `vlm_datasets`."
    )