"""Compare how v3 and v4 replay pipelines read multi_choice actions.

v3 source:
- EpisodeDatasetResolver.get_step("multi_choice", step)

v4-noresolver source:
- scripts.dataset_replay._build_action_sequence(..., "multi_choice")
- then _parse_oracle_command() in replay loop
"""
| |
|
| | import argparse |
| | import importlib.util |
| | import json |
| | import re |
| | import sys |
| | from pathlib import Path |
| | from typing import Any, Optional |
| |
|
| | import h5py |
| | import numpy as np |
| |
|
# Make both the repo root and its src/ directory importable so project
# modules resolve no matter where this script is launched from.
# NOTE(review): parents[2] assumes this file sits two directories below the
# repo root — confirm if the script is ever moved.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
SRC_ROOT = REPO_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
| |
|
| |
|
def _load_episode_dataset_resolver_cls():
    """Load EpisodeDatasetResolver directly from its source file.

    Importing the file via importlib (instead of the package) keeps this
    script's import footprint minimal; raises RuntimeError when the module
    cannot be loaded or does not define the class.
    """
    resolver_path = SRC_ROOT / "robomme" / "env_record_wrapper" / "episode_dataset_resolver.py"
    module_spec = importlib.util.spec_from_file_location(
        "episode_dataset_resolver_direct",
        resolver_path,
    )
    if module_spec is None or module_spec.loader is None:
        raise RuntimeError(f"Failed to load resolver module from {resolver_path}")
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    resolver_class = getattr(loaded_module, "EpisodeDatasetResolver", None)
    if resolver_class is None:
        raise RuntimeError(f"EpisodeDatasetResolver not found in {resolver_path}")
    return resolver_class
| |
|
| |
|
# Resolver class used by the v3 read path, loaded straight from its file.
EpisodeDatasetResolver = _load_episode_dataset_resolver_cls()

# CLI defaults; override with --env_id / --dataset_root.
DEFAULT_ENV_ID = "PatternLock"
DEFAULT_DATASET_ROOT = "/data/hongzefu/data_0226-test"
| |
|
| |
|
| | def _parse_oracle_command_v4(choice_action: Optional[Any]) -> Optional[dict[str, Any]]: |
| | """Exact validation logic used in evaluate_dataset_replay-parallelv4-noresolver.py.""" |
| | if not isinstance(choice_action, dict): |
| | return None |
| | choice = choice_action.get("choice") |
| | if not isinstance(choice, str) or not choice.strip(): |
| | return None |
| | point = choice_action.get("point") |
| | if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2: |
| | return None |
| | return choice_action |
| |
|
| |
|
| | def _is_video_demo_v4(ts: h5py.Group) -> bool: |
| | info = ts.get("info") |
| | if info is None or "is_video_demo" not in info: |
| | return False |
| | return bool(np.reshape(np.asarray(info["is_video_demo"][()]), -1)[0]) |
| |
|
| |
|
| | def _is_subgoal_boundary_v4(ts: h5py.Group) -> bool: |
| | info = ts.get("info") |
| | if info is None or "is_subgoal_boundary" not in info: |
| | return False |
| | return bool(np.reshape(np.asarray(info["is_subgoal_boundary"][()]), -1)[0]) |
| |
|
| |
|
| | def _decode_h5_str_v4(raw: Any) -> str: |
| | if isinstance(raw, np.ndarray): |
| | raw = raw.flatten()[0] |
| | if isinstance(raw, (bytes, np.bytes_)): |
| | raw = raw.decode("utf-8") |
| | return raw |
| |
|
| |
|
def _build_multi_choice_sequence_v4(episode_data: "h5py.Group") -> list[Any]:
    """
    Re-implementation of dataset_replay._build_action_sequence(..., "multi_choice")
    without importing cv2/imageio/torch dependencies.

    Walks the episode's timestep_* groups in numeric order and collects the
    JSON-decoded "choice_action" of every subgoal-boundary, non-video step.
    """
    ordered_keys = sorted(
        (name for name in episode_data.keys() if name.startswith("timestep_")),
        key=lambda name: int(name.split("_")[1]),
    )

    commands: list[Any] = []
    for name in ordered_keys:
        timestep = episode_data[name]
        # Video-demo frames carry no replayable action.
        if _is_video_demo_v4(timestep):
            continue
        action_group = timestep.get("action")
        if action_group is None:
            continue
        # Only subgoal boundaries contribute a multi_choice command.
        if not _is_subgoal_boundary_v4(timestep):
            continue
        if "choice_action" not in action_group:
            continue
        payload = _decode_h5_str_v4(action_group["choice_action"][()])
        try:
            commands.append(json.loads(payload))
        except (TypeError, ValueError, json.JSONDecodeError):
            # Malformed JSON is silently dropped, matching the v4 replay loop.
            continue
    return commands
| |
|
| |
|
| | def _resolve_h5_path(env_id: str, dataset_root: Optional[str], h5_path: Optional[str]) -> Path: |
| | if h5_path: |
| | return Path(h5_path) |
| | if not dataset_root: |
| | raise ValueError("Either --h5_path or --dataset_root must be provided") |
| | return Path(dataset_root) / f"record_dataset_{env_id}.h5" |
| |
|
| |
|
| | def _episode_indices(data: h5py.File) -> list[int]: |
| | return sorted( |
| | int(m.group(1)) |
| | for key in data.keys() |
| | for m in [re.match(r"episode_(\d+)$", key)] |
| | if m |
| | ) |
| |
|
| |
|
| | def _parse_episode_filter(raw: Optional[str], all_eps: list[int]) -> list[int]: |
| | if not raw: |
| | return all_eps |
| |
|
| | selected: set[int] = set() |
| | for token in [x.strip() for x in raw.split(",") if x.strip()]: |
| | if "-" in token: |
| | lo_s, hi_s = token.split("-", 1) |
| | lo = int(lo_s) |
| | hi = int(hi_s) |
| | if lo > hi: |
| | lo, hi = hi, lo |
| | selected.update(range(lo, hi + 1)) |
| | else: |
| | selected.add(int(token)) |
| |
|
| | return [ep for ep in all_eps if ep in selected] |
| |
|
| |
|
| | def _canonical_command(cmd: Any) -> str: |
| | """Stable string form for diffing and readable output.""" |
| | try: |
| | return json.dumps(cmd, ensure_ascii=False, sort_keys=True) |
| | except TypeError: |
| | if isinstance(cmd, dict): |
| | safe = { |
| | str(k): (v.tolist() if isinstance(v, np.ndarray) else v) |
| | for k, v in cmd.items() |
| | } |
| | return json.dumps(safe, ensure_ascii=False, sort_keys=True) |
| | return repr(cmd) |
| |
|
| |
|
def _read_v4_commands(episode_group: "h5py.Group") -> tuple[list[Any], list[dict[str, Any]], int]:
    """Read the v4 multi_choice stream for one episode.

    Returns (raw actions, actions surviving validation, count skipped by
    validation).
    """
    raw_actions = _build_multi_choice_sequence_v4(episode_group)
    effective: list[dict[str, Any]] = []
    dropped = 0
    for action in raw_actions:
        validated = _parse_oracle_command_v4(action)
        if validated is not None:
            effective.append(validated)
        else:
            dropped += 1
    return raw_actions, effective, dropped
| |
|
| |
|
def _read_v3_commands(env_id: str, episode: int, dataset_ref: str) -> list[dict[str, Any]]:
    """Drain the v3 resolver's multi_choice steps until it returns None.

    Non-dict step values are skipped but still advance the step counter,
    matching the original loop.
    """
    commands: list[dict[str, Any]] = []
    with EpisodeDatasetResolver(
        env_id=env_id,
        episode=episode,
        dataset_directory=dataset_ref,
    ) as resolver:
        step_idx = 0
        while (cmd := resolver.get_step("multi_choice", step_idx)) is not None:
            if isinstance(cmd, dict):
                commands.append(cmd)
            step_idx += 1
    return commands
| |
|
| |
|
def compare_episode(
    env_id: str,
    episode: int,
    episode_group: "h5py.Group",
    dataset_ref: str,
    max_show: int,
) -> None:
    """Print a side-by-side comparison of v4 vs v3 multi_choice reads for one episode."""
    v4_raw, v4_effective, v4_skipped = _read_v4_commands(episode_group)
    v3_resolver = _read_v3_commands(env_id=env_id, episode=episode, dataset_ref=dataset_ref)

    print(f"\n=== episode_{episode} ===")
    print(
        "counts: "
        f"v4_raw={len(v4_raw)}, "
        f"v4_effective={len(v4_effective)} (skipped_by_parse={v4_skipped}), "
        f"v3_resolver={len(v3_resolver)}"
    )

    # Canonical (sorted-key JSON) forms make the element-wise diff stable.
    canon_v4 = [_canonical_command(c) for c in v4_effective]
    canon_v3 = [_canonical_command(c) for c in v3_resolver]

    if canon_v4 == canon_v3:
        print("effective sequence compare: SAME")
    else:
        print("effective sequence compare: DIFFERENT")
        total = max(len(canon_v4), len(canon_v3))
        printed = 0
        for pos in range(total):
            lhs = canon_v4[pos] if pos < len(canon_v4) else "<MISSING>"
            rhs = canon_v3[pos] if pos < len(canon_v3) else "<MISSING>"
            if lhs == rhs:
                continue
            print(f" idx={pos}")
            print(f" v4_effective: {lhs}")
            print(f" v3_resolver : {rhs}")
            printed += 1
            if printed >= max_show:
                # Cap the diff output; report how many positions were not shown.
                remaining = total - pos - 1
                if remaining > 0:
                    print(f" ... more differences omitted ({remaining} remaining positions)")
                break

    print(f"sample v4_raw (first {max_show}):")
    for pos, cmd in enumerate(v4_raw[:max_show]):
        print(f" [{pos}] {_canonical_command(cmd)}")

    print(f"sample v4_effective (first {max_show}):")
    for pos, cmd in enumerate(v4_effective[:max_show]):
        print(f" [{pos}] {_canonical_command(cmd)}")

    print(f"sample v3_resolver (first {max_show}):")
    for pos, cmd in enumerate(v3_resolver[:max_show]):
        print(f" [{pos}] {_canonical_command(cmd)}")
| |
|
| |
|
def main() -> None:
    """CLI entry point: open the h5 dataset and compare each selected episode."""
    parser = argparse.ArgumentParser(
        description=(
            "Compare multi_choice read results between "
            "evaluate_dataset_replay-parallelv3 and parallelv4-noresolver."
        )
    )
    parser.add_argument(
        "--env_id",
        type=str,
        default=DEFAULT_ENV_ID,
        help=f"Task/env id. Default: {DEFAULT_ENV_ID}",
    )
    parser.add_argument(
        "--dataset_root",
        type=str,
        default=DEFAULT_DATASET_ROOT,
        help=(
            "Directory that contains record_dataset_<env_id>.h5. "
            f"Default: {DEFAULT_DATASET_ROOT}"
        ),
    )
    parser.add_argument(
        "--h5_path",
        type=str,
        default=None,
        help="Direct path to .h5 file (overrides --dataset_root)",
    )
    parser.add_argument(
        "--episodes",
        type=str,
        # BUG FIX: default was the int 0 despite type=str; it only "worked"
        # because 0 is falsy in _parse_episode_filter. None matches type=str
        # and the documented "all episodes" default.
        default=None,
        help="Episode filter, e.g. '0,3,8-10'. Default: all episodes in h5",
    )
    parser.add_argument(
        "--max_show",
        type=int,
        default=50,
        help="Max number of diff/sample rows per episode",
    )
    args = parser.parse_args()

    h5_file = _resolve_h5_path(args.env_id, args.dataset_root, args.h5_path)
    if not h5_file.exists():
        raise FileNotFoundError(f"h5 file not found: {h5_file}")

    # The v3 resolver accepts either a direct .h5 path or its containing directory.
    dataset_ref = str(h5_file) if h5_file.suffix == ".h5" else str(h5_file.parent)

    print(f"env_id={args.env_id}")
    print(f"h5={h5_file}")

    with h5py.File(h5_file, "r") as data:
        all_eps = _episode_indices(data)
        selected_eps = _parse_episode_filter(args.episodes, all_eps)

        if not selected_eps:
            print("No episodes selected.")
            return

        print(f"episodes={selected_eps}")
        for ep in selected_eps:
            key = f"episode_{ep}"
            if key not in data:
                print(f"\n=== episode_{ep} ===")
                print("missing in h5, skip")
                continue
            compare_episode(
                env_id=args.env_id,
                episode=ep,
                episode_group=data[key],
                dataset_ref=dataset_ref,
                max_show=args.max_show,
            )
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|