RoboMME_Interactive_Demo_cpu / scripts /dev /eval-dataset-offline-rpy.py
HongzeFu's picture
HF Space: code-only (no binary assets)
06c11b0
raw
history blame
8.13 kB
import os
import sys
import json
import h5py
import numpy as np
import argparse
from pathlib import Path
from typing import Any
from robomme.robomme_env.utils.rpy_util import summarize_and_print_rpy_sequence
def _write_split_rpy_summaries_json(
    path: str,
    demo_summaries: list[dict[str, Any]],
    non_demo_summaries: list[dict[str, Any]],
) -> None:
    """Write the demo and non-demo summary lists to one JSON file.

    The file holds a single object with the keys ``"demo_summaries"`` and
    ``"non_demo_summaries"`` (in that order), encoded as UTF-8 with a
    two-space indent and non-ASCII characters kept verbatim.

    Args:
        path: Destination file. Its parent directory is created when the path
            has a directory component.
        demo_summaries: Entries stored under ``"demo_summaries"``.
        non_demo_summaries: Entries stored under ``"non_demo_summaries"``.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # A bare file name has no directory component; nothing to create then.
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, "w", encoding="utf-8") as out_file:
        json.dump(
            {
                "demo_summaries": demo_summaries,
                "non_demo_summaries": non_demo_summaries,
            },
            out_file,
            ensure_ascii=False,
            indent=2,
        )
def _read_is_video_demo(ts_group: h5py.Group) -> bool:
    """Read ``info/is_video_demo`` from a timestep group.

    Text flags (``bytes`` or ``str``) count as true only when they are exactly
    ``True``, ``true`` or ``1``; every other stored value is converted with
    ``bool()``.

    Args:
        ts_group: Timestep group that may contain an ``info`` sub-group.

    Returns:
        The flag value, or ``False`` when ``info`` or ``is_video_demo`` is
        missing.
    """
    info_grp = ts_group.get("info")
    if info_grp is None or "is_video_demo" not in info_grp:
        return False

    val = info_grp["is_video_demo"][()]

    # A flag stored with shape (1,) comes back as an array; unwrap it so text
    # values are parsed as text rather than by array truthiness (which would
    # make any non-empty string, including "False", count as true).
    if isinstance(val, np.ndarray) and val.size == 1:
        val = val.reshape(-1)[0]

    if isinstance(val, (bytes, np.bytes_)):
        # Undecodable bytes cannot match a token, so they stay False.
        val = val.decode("utf-8", errors="replace")
    if isinstance(val, str):
        # bool("False") is True, so strings must be compared against tokens.
        return val in ("True", "true", "1")
    return bool(val)
def _extract_rpy_from_timestep(ts_group: h5py.Group) -> list[np.ndarray]:
    """Extract the RPY vectors stored at ``action/eef_action_raw/rpy``.

    The value is read in full, converted to ``float64`` and flattened into
    rows along its last axis: a 1-D value yields one row, an N-D value yields
    one row per leading index.

    Args:
        ts_group: Timestep group to read from.

    Returns:
        A list of independent length-3 ``float64`` arrays, one per row. The
        list is empty when any level of the path is missing, when the value is
        a scalar, or when its last axis does not have length 3.
    """
    if "action" not in ts_group:
        return []
    action_grp = ts_group["action"]
    if "eef_action_raw" not in action_grp:
        return []
    raw_grp = action_grp["eef_action_raw"]
    if "rpy" not in raw_grp:
        return []

    rpy_arr = np.asarray(raw_grp["rpy"][()], dtype=np.float64)

    # A 0-d value has no last axis (shape[-1] would raise IndexError) and a
    # zero-length last axis cannot be reshaped with -1; neither can hold RPY
    # triples, so treat both like any other non-3-wide value.
    if rpy_arr.ndim == 0 or rpy_arr.shape[-1] != 3:
        return []

    # Copy each row so callers do not keep the whole source array alive.
    return [row.copy() for row in rpy_arr.reshape(-1, 3)]
def _discover_hdf5_files(input_path: Path) -> list[Path]:
    """Return the HDF5 files designated by ``input_path``.

    A file is accepted only when its suffix is ``.h5`` or ``.hdf5``. A
    directory contributes its direct ``*.h5`` children followed by its direct
    ``*.hdf5`` children, each group sorted by path.
    """
    if input_path.is_file():
        return [input_path] if input_path.suffix in (".h5", ".hdf5") else []
    if input_path.is_dir():
        return sorted(input_path.glob("*.h5")) + sorted(input_path.glob("*.hdf5"))
    return []


def _episode_sort_key(episode_key: str) -> tuple[int, int, str]:
    """Return a sort key that orders ``episode_<N>`` names by ``N``.

    Names without a decimal ``N`` sort after all numeric ones, by name. The
    key is a homogeneous tuple so numeric and non-numeric names can be sorted
    together without comparing ``int`` to ``str``.
    """
    parts = episode_key.split("_")
    if len(parts) > 1 and parts[1].isdecimal():
        return (0, int(parts[1]), "")
    return (1, 0, episode_key)


def _timestep_index(timestep_key: str) -> int:
    """Return ``N`` from a ``record_timestep_<N>`` name, or -1 if unparsable."""
    try:
        return int(timestep_key.split("_")[2])
    except (IndexError, ValueError):
        return -1


def _collect_episode_rpy(episode_group: Any) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """Split an episode's RPY rows into (demo, non-demo) sequences.

    Timesteps are visited in ascending ``record_timestep_<N>`` order; each
    timestep's rows go to the demo sequence when its ``is_video_demo`` flag is
    set and to the non-demo sequence otherwise.
    """
    timestep_keys = sorted(
        (key for key in episode_group.keys() if key.startswith("record_timestep_")),
        key=_timestep_index,
    )
    demo_rpy_seq: list[np.ndarray] = []
    non_demo_rpy_seq: list[np.ndarray] = []
    for ts_key in timestep_keys:
        ts_group = episode_group[ts_key]
        rpy_rows = _extract_rpy_from_timestep(ts_group)
        if not rpy_rows:
            continue
        target = demo_rpy_seq if _read_is_video_demo(ts_group) else non_demo_rpy_seq
        target.extend(rpy_rows)
    return demo_rpy_seq, non_demo_rpy_seq


def _summarize_dataset_file(
    dataset_path: Path,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Summarize every ``env_*``/``episode_*`` group of one HDF5 file.

    Returns the (demo, non-demo) summary lists. If reading fails part-way, the
    error and traceback are printed and the summaries gathered so far are
    returned, so one bad file does not stop the run.
    """
    demo_summaries: list[dict[str, Any]] = []
    non_demo_summaries: list[dict[str, Any]] = []
    try:
        with h5py.File(dataset_path, "r") as f:
            # Iterate through environments (e.g., env_PickXtimes...)
            env_groups = sorted(key for key in f.keys() if key.startswith("env_"))
            if not env_groups:
                print(f"Warning: No 'env_*' groups found in {dataset_path.name}")

            for env_group_name in env_groups:
                env_group = f[env_group_name]
                print(f"Processing environment group: {env_group_name}")
                env_id = env_group_name.removeprefix("env_")

                episode_keys = sorted(
                    (key for key in env_group.keys() if key.startswith("episode_")),
                    key=_episode_sort_key,
                )
                for episode_key in episode_keys:
                    print(f"  Processing {episode_key}...")
                    episode_group = env_group[episode_key]
                    try:
                        episode_idx = int(episode_key.split("_")[1])
                    except (IndexError, ValueError):
                        episode_idx = -1

                    demo_rpy_seq, non_demo_rpy_seq = _collect_episode_rpy(episode_group)

                    # Demo and non-demo portions are summarized the same way
                    # and only differ in label and destination list.
                    for tag, rpy_seq, bucket in (
                        ("demo", demo_rpy_seq, demo_summaries),
                        ("non-demo", non_demo_rpy_seq, non_demo_summaries),
                    ):
                        if not rpy_seq:
                            continue
                        summary = summarize_and_print_rpy_sequence(
                            rpy_seq,
                            label=f"[{env_id}] episode {episode_idx} ({tag})",
                        )
                        bucket.append({
                            "order_index": len(bucket),
                            "env_id": env_id,
                            "episode": episode_idx,
                            "action_space": "eef_pose",
                            "summary": summary,
                        })
    except Exception as e:
        # Top-level boundary for one file: report and keep going.
        print(f"An error occurred while reading {dataset_path.name}: {e}")
        import traceback
        traceback.print_exc()
    return demo_summaries, non_demo_summaries


def main(argv: "list[str] | None" = None) -> None:
    """Summarize the RPY sequences of generated HDF5 datasets into JSON files.

    For every HDF5 file found at ``--dataset_path`` (a single file or a
    directory), the per-episode RPY rows are split by the ``is_video_demo``
    flag, summarized, and written to ``<output_dir>/<stem>_rpy_summary.json``.

    Args:
        argv: Command-line arguments to parse; ``None`` uses ``sys.argv[1:]``.

    Raises:
        SystemExit: With code 1 when the dataset path does not exist, and with
            code 0 when it contains no HDF5 files.
    """
    default_dataset_dir = Path("/data/hongzefu/dataset_generate")
    default_output_dir = Path("/data/hongzefu/dataset_replay")

    parser = argparse.ArgumentParser(description="Read generated HDF5 dataset and verify RPY consistency.")
    parser.add_argument("--dataset_path", type=str, default=str(default_dataset_dir), help="Path to the HDF5 file or directory to verify.")
    parser.add_argument("--output_dir", type=str, default=str(default_output_dir), help="Directory that receives the <stem>_rpy_summary.json files.")
    args = parser.parse_args(argv)

    input_path = Path(args.dataset_path).resolve()
    if not input_path.exists():
        print(f"Error: Path not found: {input_path}")
        sys.exit(1)

    files_to_process = _discover_hdf5_files(input_path)
    if not files_to_process:
        print(f"No HDF5 files found in {input_path}")
        sys.exit(0)

    print(f"Found {len(files_to_process)} files to process in {input_path}")

    output_dir = Path(args.output_dir)
    for dataset_path in files_to_process:
        print(f"\n{'='*50}")
        print(f"Processing dataset: {dataset_path}")
        print(f"{'='*50}")

        # NOTE: the output name depends only on the stem, so "a.h5" and
        # "a.hdf5" in the same directory map to the same JSON file.
        output_json_path = output_dir / f"{dataset_path.stem}_rpy_summary.json"

        demo_summaries, non_demo_summaries = _summarize_dataset_file(dataset_path)

        if demo_summaries or non_demo_summaries:
            _write_split_rpy_summaries_json(str(output_json_path), demo_summaries, non_demo_summaries)
            print(f"Saved split RPY summaries to: {output_json_path}")
            print(f"  demo entries: {len(demo_summaries)}, non-demo entries: {len(non_demo_summaries)}")
        else:
            print(f"No summaries generated for {dataset_path.name}")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()