| | import os |
| | import sys |
| | import json |
| | import h5py |
| | import numpy as np |
| | import argparse |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | from robomme.robomme_env.utils.rpy_util import summarize_and_print_rpy_sequence |
| |
|
| | def _write_split_rpy_summaries_json( |
| | path: str, |
| | demo_summaries: list[dict[str, Any]], |
| | non_demo_summaries: list[dict[str, Any]], |
| | ) -> None: |
| | """ |
| | Summarize both demo and non-demo parts and write to JSON. |
| | """ |
| | if os.path.dirname(path): |
| | os.makedirs(os.path.dirname(path), exist_ok=True) |
| | payload = { |
| | "demo_summaries": demo_summaries, |
| | "non_demo_summaries": non_demo_summaries, |
| | } |
| | with open(path, "w", encoding="utf-8") as f: |
| | json.dump(payload, f, ensure_ascii=False, indent=2) |
| |
|
| |
|
| | def _read_is_video_demo(ts_group: h5py.Group) -> bool: |
| | """Read info/is_video_demo from timestep group, default to False if missing.""" |
| | info_grp = ts_group.get("info") |
| | if info_grp is not None and "is_video_demo" in info_grp: |
| | val = info_grp["is_video_demo"][()] |
| | if isinstance(val, (bytes, np.bytes_)): |
| | return val in (b"True", b"true", b"1") |
| | return bool(val) |
| | return False |
| |
|
| |
|
| | def _extract_rpy_from_timestep(ts_group: h5py.Group) -> list[np.ndarray]: |
| | """Extract RPY vector list from a single timestep.""" |
| | if ( |
| | "action" in ts_group |
| | and "eef_action_raw" in ts_group["action"] |
| | and "rpy" in ts_group["action"]["eef_action_raw"] |
| | ): |
| | rpy_data = ts_group["action"]["eef_action_raw"]["rpy"][()] |
| | rpy_arr = np.asarray(rpy_data, dtype=np.float64) |
| | if rpy_arr.ndim == 1: |
| | rpy_arr = rpy_arr.reshape(1, -1) |
| | else: |
| | rpy_arr = rpy_arr.reshape(-1, rpy_arr.shape[-1]) |
| | if rpy_arr.shape[-1] == 3: |
| | return [row.copy() for row in rpy_arr] |
| | return [] |
| |
|
| |
|
def _collect_hdf5_files(input_path: Path) -> list[Path]:
    """Return HDF5 files to process: the file itself, or all *.h5/*.hdf5 in a directory."""
    files: list[Path] = []
    if input_path.is_file():
        if input_path.suffix in ('.h5', '.hdf5'):
            files.append(input_path)
    elif input_path.is_dir():
        files.extend(sorted(input_path.glob("*.h5")))
        files.extend(sorted(input_path.glob("*.hdf5")))
    return files


def _episode_sort_key(key: str):
    """Numeric index of an 'episode_<n>' key; falls back to the raw string."""
    parts = key.split('_')
    if len(parts) > 1 and parts[1].isdigit():
        return int(parts[1])
    return key


def _timestep_sort_key(key: str) -> int:
    """Numeric index of a 'record_timestep_<n>' key, or -1 when unparsable."""
    parts = key.split('_')
    try:
        return int(parts[2])
    except (IndexError, ValueError):
        return -1


def _process_dataset(dataset_path: Path, output_dir: Path) -> None:
    """
    Summarize the RPY sequences of one HDF5 dataset and write the split
    demo/non-demo JSON report to ``output_dir``.

    Expected layout: env_*/episode_*/record_timestep_*; each timestep
    contributes its RPY rows to either the demo or non-demo sequence of its
    episode, depending on info/is_video_demo.
    """
    print(f"\n{'='*50}")
    print(f"Processing dataset: {dataset_path}")
    print(f"{'='*50}")

    output_json_path = output_dir / f"{dataset_path.stem}_rpy_summary.json"

    demo_summaries: list[dict[str, Any]] = []
    non_demo_summaries: list[dict[str, Any]] = []

    try:
        with h5py.File(dataset_path, "r") as f:
            env_groups = sorted(key for key in f.keys() if key.startswith("env_"))
            if not env_groups:
                print(f"Warning: No 'env_*' groups found in {dataset_path.name}")

            for env_group_name in env_groups:
                env_group = f[env_group_name]
                print(f"Processing environment group: {env_group_name}")
                env_id = env_group_name[4:]  # strip the 'env_' prefix

                episode_keys = sorted(
                    (key for key in env_group.keys() if key.startswith("episode_")),
                    key=_episode_sort_key,
                )

                for episode_key in episode_keys:
                    print(f"  Processing {episode_key}...")
                    episode_group = env_group[episode_key]
                    try:
                        episode_idx = int(episode_key.split('_')[1])
                    except (IndexError, ValueError):
                        episode_idx = -1  # sentinel for unparsable episode names

                    timestep_keys = sorted(
                        (key for key in episode_group.keys() if key.startswith("record_timestep_")),
                        key=_timestep_sort_key,
                    )

                    # Split the episode's RPY rows by the per-timestep demo flag.
                    demo_rpy_seq: list[np.ndarray] = []
                    non_demo_rpy_seq: list[np.ndarray] = []
                    for ts_key in timestep_keys:
                        ts_group = episode_group[ts_key]
                        rpy_rows = _extract_rpy_from_timestep(ts_group)
                        if not rpy_rows:
                            continue
                        if _read_is_video_demo(ts_group):
                            demo_rpy_seq.extend(rpy_rows)
                        else:
                            non_demo_rpy_seq.extend(rpy_rows)

                    if demo_rpy_seq:
                        demo_summary = summarize_and_print_rpy_sequence(
                            demo_rpy_seq,
                            label=f"[{env_id}] episode {episode_idx} (demo)",
                        )
                        demo_summaries.append({
                            "order_index": len(demo_summaries),
                            "env_id": env_id,
                            "episode": episode_idx,
                            "action_space": "eef_pose",
                            "summary": demo_summary,
                        })

                    if non_demo_rpy_seq:
                        non_demo_summary = summarize_and_print_rpy_sequence(
                            non_demo_rpy_seq,
                            label=f"[{env_id}] episode {episode_idx} (non-demo)",
                        )
                        non_demo_summaries.append({
                            "order_index": len(non_demo_summaries),
                            "env_id": env_id,
                            "episode": episode_idx,
                            "action_space": "eef_pose",
                            "summary": non_demo_summary,
                        })
    except Exception as e:
        # Best-effort per-file processing: report the failure and still emit
        # whatever summaries were collected before the error.
        print(f"An error occurred while reading {dataset_path.name}: {e}")
        import traceback
        traceback.print_exc()

    if demo_summaries or non_demo_summaries:
        _write_split_rpy_summaries_json(str(output_json_path), demo_summaries, non_demo_summaries)
        print(f"Saved split RPY summaries to: {output_json_path}")
        print(f"  demo entries: {len(demo_summaries)}, non-demo entries: {len(non_demo_summaries)}")
    else:
        print(f"No summaries generated for {dataset_path.name}")


def main():
    """
    CLI entry point: locate HDF5 dataset files and verify/summarize their
    RPY (roll/pitch/yaw) action sequences, split by demo vs non-demo steps.

    Exits 1 when the input path does not exist, 0 when no HDF5 files are
    found.
    """
    default_dataset_dir = Path("/data/hongzefu/dataset_generate")
    default_output_dir = Path("/data/hongzefu/dataset_replay")

    parser = argparse.ArgumentParser(description="Read generated HDF5 dataset and verify RPY consistency.")
    parser.add_argument("--dataset_path", type=str, default=str(default_dataset_dir), help="Path to the HDF5 file or directory to verify.")
    # New, backward compatible: the report directory used to be hard-coded.
    parser.add_argument("--output_dir", type=str, default=str(default_output_dir), help="Directory where the *_rpy_summary.json reports are written.")
    args = parser.parse_args()

    input_path = Path(args.dataset_path).resolve()
    if not input_path.exists():
        print(f"Error: Path not found: {input_path}")
        sys.exit(1)

    files_to_process = _collect_hdf5_files(input_path)
    if not files_to_process:
        print(f"No HDF5 files found in {input_path}")
        sys.exit(0)

    print(f"Found {len(files_to_process)} files to process in {input_path}")
    output_dir = Path(args.output_dir)
    for dataset_path in files_to_process:
        _process_dataset(dataset_path, output_dir)
| |
|
if __name__ == "__main__":
    # Run only when executed as a script, so the module can be imported
    # (e.g. for its helper functions) without side effects.
    main()
| |
|