| |
| """ |
| Build a filtered training index from community_dataset_v3 on disk. |
| |
| Applies: |
| - Robot type filter (so100/so101 variants only) |
| - Schema filter (2 cameras at 480x640, 6-DOF, 30fps) |
| - Episode length filter (5s-60s) |
| - Per-task cap (default 200) |
| - Per-contributor cap (default 200) |
| - Excludes datasets with file count mismatches |
| |
| Outputs filtered_index.json with all info needed to train. |
| """ |
|
|
| import argparse |
| import glob |
| import json |
| import random |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
|
|
| def load_dataset_meta(dataset_root: Path) -> dict | None: |
| """Load and validate a single dataset's metadata.""" |
| info_path = dataset_root / "meta" / "info.json" |
| if not info_path.exists(): |
| return None |
|
|
| info = json.load(open(info_path)) |
|
|
| |
| robot = info.get("robot_type", "") |
| if robot not in ("so100", "so101", "so100_follower", "so101_follower"): |
| return None |
|
|
| |
| features = info.get("features", {}) |
| expected_keys = { |
| "action", "episode_index", "frame_index", "index", |
| "observation.images.image", "observation.images.image2", |
| "observation.state", "task_index", "timestamp", |
| } |
| if set(features.keys()) != expected_keys: |
| return None |
|
|
| |
| if features.get("action", {}).get("shape") != [6]: |
| return None |
| if features.get("observation.state", {}).get("shape") != [6]: |
| return None |
|
|
| |
| if info.get("fps") != 30: |
| return None |
|
|
| |
| for cam_key in ("observation.images.image", "observation.images.image2"): |
| shape = features.get(cam_key, {}).get("shape", []) |
| if len(shape) < 2 or shape[0] != 480 or shape[1] != 640: |
| return None |
|
|
| |
| tasks_path = dataset_root / "meta" / "tasks.jsonl" |
| tasks = {} |
| if tasks_path.exists(): |
| for line in open(tasks_path): |
| line = line.strip() |
| if line: |
| t = json.loads(line) |
| tasks[t["task_index"]] = t["task"] |
|
|
| |
| total_eps = info.get("total_episodes", 0) |
| vids = glob.glob(str(dataset_root / "videos" / "**" / "*.mp4"), recursive=True) |
| parquets = glob.glob(str(dataset_root / "data" / "**" / "*.parquet"), recursive=True) |
| expected_vids = total_eps * 2 |
| if len(vids) != expected_vids or len(parquets) != total_eps: |
| return None |
|
|
| |
| episodes = [] |
| ep_jsonl = dataset_root / "meta" / "episodes.jsonl" |
| if ep_jsonl.exists(): |
| for line in open(ep_jsonl): |
| line = line.strip() |
| if line: |
| episodes.append(json.loads(line)) |
|
|
| return { |
| "robot_type": robot, |
| "total_episodes": total_eps, |
| "total_frames": info.get("total_frames", 0), |
| "fps": info["fps"], |
| "tasks": tasks, |
| "episodes": episodes, |
| "features": {k: v.get("shape") for k, v in features.items()}, |
| } |
|
|
|
|
def _cap_per_group(episodes: list, key_pos: int, limit: int, rng: random.Random) -> tuple:
    """Shuffle each group of episodes and keep at most ``limit`` per group.

    Episodes are grouped by the tuple field at ``key_pos``. Groups are
    visited in first-seen order, so the sequence of RNG calls (and therefore
    the selection) is reproducible for a given seed.

    Returns:
        ``(kept_episodes, number_of_groups_larger_than_limit)``.
    """
    buckets = defaultdict(list)
    for ep in episodes:
        buckets[ep[key_pos]].append(ep)

    kept = []
    groups_capped = 0
    for group in buckets.values():
        rng.shuffle(group)
        if len(group) > limit:
            groups_capped += 1
        kept.extend(group[:limit])
    return kept, groups_capped


def build_index(
    data_root: Path,
    max_per_task: int = 200,
    max_per_contributor: int = 200,
    min_episode_frames: int = 150,
    max_episode_frames: int = 1800,
    seed: int = 42,
) -> dict:
    """Build filtered training index.

    Walks ``data_root/<contributor>/<dataset>/``, keeps datasets accepted by
    ``load_dataset_meta``, then keeps episodes whose parquet and both camera
    videos exist and whose frame count lies in
    ``[min_episode_frames, max_episode_frames]``. The surviving episodes are
    capped per task and then per contributor using a seeded shuffle.

    Args:
        data_root: Root directory holding one sub-directory per contributor.
            Contributor directories whose name starts with "." are skipped.
        max_per_task: Maximum episodes kept per task string.
        max_per_contributor: Maximum episodes kept per contributor; applied
            after the per-task cap.
        min_episode_frames: Minimum number of rows in the episode parquet.
        max_episode_frames: Maximum number of rows in the episode parquet.
        seed: Seed for the shuffles that decide which episodes survive caps.

    Returns:
        A JSON-serialisable dict with keys ``source_repo``, ``filters``,
        ``summary``, ``tasks``, ``datasets_used`` and ``episodes``.
    """
    rng = random.Random(seed)

    # Sorted so the traversal order, and hence the seeded sampling, is stable.
    contributors = sorted(
        d for d in data_root.iterdir()
        if d.is_dir() and not d.name.startswith(".")
    )

    # Each entry: (contributor, dataset_name, episode_index, task, num_frames).
    all_episodes = []
    datasets_passed = 0
    datasets_rejected = 0
    skipped_missing = 0

    for contrib_dir in contributors:
        contributor = contrib_dir.name

        for ds_dir in sorted(contrib_dir.iterdir()):
            if not ds_dir.is_dir():
                continue

            meta = load_dataset_meta(ds_dir)
            if meta is None:
                datasets_rejected += 1
                continue

            datasets_passed += 1
            dataset_name = f"{contributor}/{ds_dir.name}"

            # Datasets without tasks.jsonl entries get a placeholder task.
            tasks = meta["tasks"] or {0: "(no task)"}

            # episode_index -> task_index, built once per dataset instead of
            # rescanning the episode list for every episode. setdefault keeps
            # the first entry for a duplicated episode_index.
            # NOTE(review): entries without a "task_index" field map to
            # task 0 - confirm episodes.jsonl actually carries that field.
            task_by_episode = {}
            for ep_meta in meta["episodes"]:
                try:
                    task_by_episode.setdefault(
                        ep_meta.get("episode_index"),
                        ep_meta.get("task_index", 0),
                    )
                except TypeError:
                    # Unhashable index can never equal an int episode id.
                    continue

            # NOTE(review): paths are hard-coded to chunk-000, so episodes
            # stored in any other chunk directory are counted as missing.
            for ep_idx in range(meta["total_episodes"]):
                parquet_path = ds_dir / f"data/chunk-000/episode_{ep_idx:06d}.parquet"
                if not parquet_path.exists():
                    skipped_missing += 1
                    continue

                # The real episode length is the parquet row count; a single
                # column is read to keep this cheap.
                pf = pd.read_parquet(parquet_path, columns=["frame_index"])
                actual_length = len(pf)
                if actual_length < min_episode_frames or actual_length > max_episode_frames:
                    continue

                vid1 = ds_dir / f"videos/chunk-000/observation.images.image/episode_{ep_idx:06d}.mp4"
                vid2 = ds_dir / f"videos/chunk-000/observation.images.image2/episode_{ep_idx:06d}.mp4"
                if not vid1.exists() or not vid2.exists():
                    skipped_missing += 1
                    continue

                task_idx = task_by_episode.get(ep_idx, 0)
                task = tasks.get(task_idx, "(no task)")
                all_episodes.append((contributor, dataset_name, ep_idx, task, actual_length))

    print(f"Datasets: {datasets_passed} passed, {datasets_rejected} rejected")
    print(f"Episodes verified: {len(all_episodes)}, skipped (missing files): {skipped_missing}")
    print(f"Episodes before caps: {len(all_episodes)}")

    # Per-task cap (tuple field 3 is the task string).
    after_task_cap, tasks_capped = _cap_per_group(all_episodes, 3, max_per_task, rng)
    print(f"Episodes after per-task cap ({max_per_task}): {len(after_task_cap)} ({tasks_capped} tasks capped)")

    # Per-contributor cap (tuple field 0), applied to the task-capped set.
    final_episodes, contribs_capped = _cap_per_group(after_task_cap, 0, max_per_contributor, rng)
    print(f"Episodes after per-contributor cap ({max_per_contributor}): {len(final_episodes)} ({contribs_capped} contributors capped)")

    # Deterministic output order: by dataset name, then episode index.
    final_episodes.sort(key=lambda ep: (ep[1], ep[2]))

    # Global task vocabulary; task_index in the output refers to this list.
    unique_tasks = sorted({ep[3] for ep in final_episodes})
    task_to_idx = {t: i for i, t in enumerate(unique_tasks)}

    datasets_used = sorted({ep[1] for ep in final_episodes})

    entries = [
        {
            "dataset": dataset_name,
            "episode_index": ep_idx,
            "task": task,
            "task_index": task_to_idx[task],
            "num_frames": num_frames,
        }
        for _contributor, dataset_name, ep_idx, task, num_frames in final_episodes
    ]
    total_frames = sum(ep[4] for ep in final_episodes)

    return {
        "source_repo": "HuggingFaceVLA/community_dataset_v3",
        "filters": {
            "max_per_task": max_per_task,
            "max_per_contributor": max_per_contributor,
            "min_episode_frames": min_episode_frames,
            "max_episode_frames": max_episode_frames,
            "seed": seed,
        },
        "summary": {
            "datasets": len(datasets_used),
            "episodes": len(entries),
            "unique_tasks": len(unique_tasks),
            "total_frames": total_frames,
            # 30 frames per second, 3600 seconds per hour.
            "est_hours": total_frames / 30 / 3600,
        },
        "tasks": unique_tasks,
        "datasets_used": datasets_used,
        "episodes": entries,
    }
|
|
|
|
# Command-line entry point: build the index and write it as JSON.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", type=Path, default=Path.home() / "lap" / "community_dataset_v3")
    parser.add_argument("--output", type=Path, default=Path(__file__).parent / "filtered_index.json")
    parser.add_argument("--max-per-task", type=int, default=200)
    parser.add_argument("--max-per-contributor", type=int, default=200)
    # Episode-length bounds in frames; defaults match build_index's defaults.
    parser.add_argument("--min-episode-frames", type=int, default=150)
    parser.add_argument("--max-episode-frames", type=int, default=1800)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    index = build_index(
        args.data_root,
        max_per_task=args.max_per_task,
        max_per_contributor=args.max_per_contributor,
        min_episode_frames=args.min_episode_frames,
        max_episode_frames=args.max_episode_frames,
        seed=args.seed,
    )

    # Create the destination directory so a fresh --output path works.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)

    summary = index["summary"]
    print(f"\nSaved to {args.output}")
    print(f" Datasets: {summary['datasets']}")
    print(f" Episodes: {summary['episodes']}")
    print(f" Tasks: {summary['unique_tasks']}")
    print(f" Frames: {summary['total_frames']:,}")
    print(f" Est. hours: {summary['est_hours']:.1f}")
|
|