import os
import sys
import json
import h5py
import numpy as np
import argparse
from pathlib import Path
from typing import Any
from robomme.robomme_env.utils.rpy_util import summarize_and_print_rpy_sequence
def _write_split_rpy_summaries_json(
path: str,
demo_summaries: list[dict[str, Any]],
non_demo_summaries: list[dict[str, Any]],
) -> None:
"""
Summarize both demo and non-demo parts and write to JSON.
"""
if os.path.dirname(path):
os.makedirs(os.path.dirname(path), exist_ok=True)
payload = {
"demo_summaries": demo_summaries,
"non_demo_summaries": non_demo_summaries,
}
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
def _read_is_video_demo(ts_group: h5py.Group) -> bool:
"""Read info/is_video_demo from timestep group, default to False if missing."""
info_grp = ts_group.get("info")
if info_grp is not None and "is_video_demo" in info_grp:
val = info_grp["is_video_demo"][()]
if isinstance(val, (bytes, np.bytes_)):
return val in (b"True", b"true", b"1")
return bool(val)
return False
def _extract_rpy_from_timestep(ts_group: h5py.Group) -> list[np.ndarray]:
"""Extract RPY vector list from a single timestep."""
if (
"action" in ts_group
and "eef_action_raw" in ts_group["action"]
and "rpy" in ts_group["action"]["eef_action_raw"]
):
rpy_data = ts_group["action"]["eef_action_raw"]["rpy"][()]
rpy_arr = np.asarray(rpy_data, dtype=np.float64)
if rpy_arr.ndim == 1:
rpy_arr = rpy_arr.reshape(1, -1)
else:
rpy_arr = rpy_arr.reshape(-1, rpy_arr.shape[-1])
if rpy_arr.shape[-1] == 3:
return [row.copy() for row in rpy_arr]
return []
def main():
    """
    Walk one or more generated HDF5 dataset files, split each episode's
    end-effector RPY actions by the info/is_video_demo flag, and write a
    per-file JSON summary of the demo and non-demo portions.

    Exits with status 1 when the input path does not exist and status 0
    (after a message) when no HDF5 files are found.
    """
    # Hardcoded dataset directory as requested
    DATASET_DIR = Path("/data/hongzefu/dataset_generate")
    parser = argparse.ArgumentParser(description="Read generated HDF5 dataset and verify RPY consistency.")
    parser.add_argument("--dataset_path", type=str, default=str(DATASET_DIR), help="Path to the HDF5 file or directory to verify.")
    args = parser.parse_args()
    input_path = Path(args.dataset_path).resolve()
    if not input_path.exists():
        print(f"Error: Path not found: {input_path}")
        sys.exit(1)
    # Determine files to process
    files_to_process = []
    if input_path.is_file():
        # A single file is accepted only with a recognized HDF5 suffix.
        if input_path.suffix in ['.h5', '.hdf5']:
            files_to_process.append(input_path)
    elif input_path.is_dir():
        # Non-recursive: only top-level *.h5 / *.hdf5 files are collected.
        files_to_process.extend(sorted(input_path.glob("*.h5")))
        files_to_process.extend(sorted(input_path.glob("*.hdf5")))
    if not files_to_process:
        print(f"No HDF5 files found in {input_path}")
        sys.exit(0)
    print(f"Found {len(files_to_process)} files to process in {input_path}")
    for dataset_path in files_to_process:
        print(f"\n{'='*50}")
        print(f"Processing dataset: {dataset_path}")
        print(f"{'='*50}")
        # Generate output JSON path
        # NOTE(review): output directory is hardcoded and differs from the
        # input directory — confirm this is intentional.
        output_json_path = Path("/data/hongzefu/dataset_replay") / f"{dataset_path.stem}_rpy_summary.json"
        demo_summaries: list[dict[str, Any]] = []
        non_demo_summaries: list[dict[str, Any]] = []
        try:
            with h5py.File(dataset_path, "r") as f:
                # Iterate through environments (e.g., env_PickXtimes...)
                env_groups = [key for key in f.keys() if key.startswith("env_")]
                env_groups.sort()
                if not env_groups:
                    print(f"Warning: No 'env_*' groups found in {dataset_path.name}")
                for env_group_name in env_groups:
                    env_group = f[env_group_name]
                    print(f"Processing environment group: {env_group_name}")
                    # Extract env_id from group name (remove 'env_' prefix)
                    env_id = env_group_name[4:]
                    # Iterate through episodes
                    episode_keys = [key for key in env_group.keys() if key.startswith("episode_")]
                    # Sort numerically by episode ID
                    # (falls back to lexicographic order for non-numeric names)
                    episode_keys.sort(key=lambda x: int(x.split('_')[1]) if '_' in x and x.split('_')[1].isdigit() else x)
                    for episode_key in episode_keys:
                        print(f"  Processing {episode_key}...")
                        episode_group = env_group[episode_key]
                        try:
                            episode_idx = int(episode_key.split('_')[1])
                        except (IndexError, ValueError):
                            # -1 marks an episode whose key did not parse.
                            episode_idx = -1
                        # Iterate through timesteps to reconstruct sequence
                        timestep_keys = [key for key in episode_group.keys() if key.startswith("record_timestep_")]
                        def get_timestep_idx(key):
                            # Numeric index from "record_timestep_<N>"; -1 when malformed.
                            parts = key.split('_')
                            try:
                                return int(parts[2])
                            except (IndexError, ValueError):
                                return -1
                        timestep_keys.sort(key=get_timestep_idx)
                        # Separate RPY sequences by is_video_demo flag
                        demo_rpy_seq: list[np.ndarray] = []
                        non_demo_rpy_seq: list[np.ndarray] = []
                        for ts_key in timestep_keys:
                            ts_group = episode_group[ts_key]
                            rpy_rows = _extract_rpy_from_timestep(ts_group)
                            if rpy_rows:
                                if _read_is_video_demo(ts_group):
                                    demo_rpy_seq.extend(rpy_rows)
                                else:
                                    non_demo_rpy_seq.extend(rpy_rows)
                        # Summarize demo portion
                        if demo_rpy_seq:
                            demo_summary = summarize_and_print_rpy_sequence(
                                demo_rpy_seq,
                                label=f"[{env_id}] episode {episode_idx} (demo)",
                            )
                            demo_summaries.append({
                                "order_index": len(demo_summaries),
                                "env_id": env_id,
                                "episode": episode_idx,
                                "action_space": "eef_pose",
                                "summary": demo_summary,
                            })
                        # Summarize non-demo portion
                        if non_demo_rpy_seq:
                            non_demo_summary = summarize_and_print_rpy_sequence(
                                non_demo_rpy_seq,
                                label=f"[{env_id}] episode {episode_idx} (non-demo)",
                            )
                            non_demo_summaries.append({
                                "order_index": len(non_demo_summaries),
                                "env_id": env_id,
                                "episode": episode_idx,
                                "action_space": "eef_pose",
                                "summary": non_demo_summary,
                            })
        except Exception as e:
            # Best-effort per file: report and continue with the next dataset.
            print(f"An error occurred while reading {dataset_path.name}: {e}")
            import traceback
            traceback.print_exc()
        # Write summary to JSON
        if demo_summaries or non_demo_summaries:
            _write_split_rpy_summaries_json(str(output_json_path), demo_summaries, non_demo_summaries)
            print(f"Saved split RPY summaries to: {output_json_path}")
            print(f"  demo entries: {len(demo_summaries)}, non-demo entries: {len(non_demo_summaries)}")
        else:
            print(f"No summaries generated for {dataset_path.name}")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()