from __future__ import annotations

import h5py
import numpy as np
import pytest

from tests._shared.dataset_generation import DatasetCase
| |
|
# Module-level ``pytestmark``: pytest applies the ``dataset`` marker to every
# test collected from this file.
pytestmark = pytest.mark.dataset
| |
|
| |
|
def _make_case(env_id: str, base_seed: int) -> DatasetCase:
    """Build the ``DatasetCase`` used by the parametrized test in this module.

    Only the environment id and the base seed vary between cases. Every other
    field is fixed: episode 0, ``"easy"`` difficulty, video saving enabled and
    the ``"waypoint_phase_isolation"`` mode tag.
    """
    fixed_settings = {
        "episode": 0,
        "difficulty": "easy",
        "save_video": True,
        "mode_tag": "waypoint_phase_isolation",
    }
    return DatasetCase(env_id=env_id, base_seed=base_seed, **fixed_settings)
| |
|
| |
|
| | def _decode_h5_string(raw) -> str: |
| | if isinstance(raw, np.ndarray): |
| | raw = raw.flatten()[0] |
| | if isinstance(raw, (bytes, np.bytes_)): |
| | return raw.decode("utf-8") |
| | return str(raw) |
| |
|
| |
|
def _collect_records(ep_group: h5py.Group) -> list[dict]:
    """Flatten the ``timestep_*`` children of an episode group into dict rows.

    Children whose name starts with ``"timestep_"`` are visited in ascending
    order of the integer that follows the first underscore. Each row holds:

    * ``"timestep"``: that integer index,
    * ``"is_demo"``: first element of ``info/is_video_demo`` as a bool,
    * ``"subgoal"``: decoded ``info/simple_subgoal_online`` string,
    * ``"waypoint_action"``: flattened ``action/waypoint_action`` array,
    * ``"is_finite_waypoint"``: True only when that array has shape ``(7,)``
      and every entry is finite.
    """
    # Parse the numeric suffix once per child so it can serve both as the sort
    # key and as the recorded timestep value.
    indexed_names = [
        (int(name.split("_")[1]), name)
        for name in ep_group.keys()
        if name.startswith("timestep_")
    ]
    # Stable sort on the index only: names with equal indices keep the order
    # in which the group listed them.
    indexed_names.sort(key=lambda pair: pair[0])

    records: list[dict] = []
    for step_index, name in indexed_names:
        step_group = ep_group[name]
        step_info = step_group["info"]
        waypoint = np.asarray(step_group["action"]["waypoint_action"][()]).flatten()
        demo_flag = np.reshape(np.asarray(step_info["is_video_demo"][()]), -1)[0]
        subgoal_text = _decode_h5_string(step_info["simple_subgoal_online"][()])
        has_finite_waypoint = bool(
            waypoint.shape == (7,) and np.all(np.isfinite(waypoint))
        )
        records.append(
            {
                "timestep": step_index,
                "is_demo": bool(demo_flag),
                "subgoal": subgoal_text,
                "waypoint_action": waypoint,
                "is_finite_waypoint": has_finite_waypoint,
            }
        )
    return records
| |
|
| |
|
| | def _find_demo_to_non_demo_boundary(records: list[dict]) -> int | None: |
| | for idx in range(1, len(records)): |
| | if records[idx - 1]["is_demo"] and not records[idx]["is_demo"]: |
| | return idx |
| | return None |
| |
|
| |
|
| | def _last_finite_demo_waypoint(records: list[dict], boundary_idx: int) -> np.ndarray | None: |
| | for idx in range(boundary_idx - 1, -1, -1): |
| | row = records[idx] |
| | if not row["is_demo"]: |
| | continue |
| | if row["is_finite_waypoint"]: |
| | return np.asarray(row["waypoint_action"]).flatten() |
| | return None |
| |
|
| |
|
| | def _unique_finite_waypoints(rows: list[dict]) -> list[np.ndarray]: |
| | uniques: list[np.ndarray] = [] |
| | prev: np.ndarray | None = None |
| | for row in rows: |
| | if not row["is_finite_waypoint"]: |
| | continue |
| | wa = np.asarray(row["waypoint_action"]).flatten() |
| | if prev is None or not np.array_equal(wa, prev): |
| | uniques.append(wa.copy()) |
| | prev = wa.copy() |
| | return uniques |
| |
|
| |
|
@pytest.mark.parametrize(
    "env_id,base_seed,assert_first_segment_midpoint",
    [
        ("PatternLock", 15001, False),
        ("RouteStick", 16000, True),
    ],
)
def test_waypoint_isolation_across_demo_phase(
    env_id: str,
    base_seed: int,
    assert_first_segment_midpoint: bool,
    dataset_factory,
):
    """Check that demo-phase waypoints do not carry over into the non-demo phase.

    The test builds a dataset for ``env_id`` through ``dataset_factory``
    (presumably a pytest fixture defined outside this file) and reads
    ``episode_0`` from the resulting HDF5 file. It then verifies that:

    * the episode has timesteps and contains a demo -> non-demo transition;
    * when both are finite, the first non-demo waypoint differs from the
      last finite demo waypoint;
    * the non-demo phase contains at least one finite waypoint;
    * optionally (``assert_first_segment_midpoint``), the rows sharing the
      first non-demo subgoal contain at least two distinct finite waypoints.
    """
    case = _make_case(env_id, base_seed)
    generated = dataset_factory(case)

    # Rows are plain dicts of NumPy data, so the file can be closed before
    # any assertion runs.
    with h5py.File(generated.raw_h5_path, "r") as h5f:
        records = _collect_records(h5f["episode_0"])

    assert records, f"{env_id}: episode_0 has no recorded timesteps."
    boundary_idx = _find_demo_to_non_demo_boundary(records)
    assert boundary_idx is not None, f"{env_id}: missing demo->non-demo boundary."

    boundary_row = records[boundary_idx]
    assert not boundary_row["is_demo"], (
        f"{env_id}: boundary row should be non-demo, got demo at timestep "
        f"{boundary_row['timestep']}."
    )

    # The comparison is skipped when either side has no finite waypoint.
    demo_tail_waypoint = _last_finite_demo_waypoint(records, boundary_idx)
    if boundary_row["is_finite_waypoint"] and demo_tail_waypoint is not None:
        boundary_waypoint = np.asarray(boundary_row["waypoint_action"]).flatten()
        assert not np.array_equal(boundary_waypoint, demo_tail_waypoint), (
            f"{env_id}: boundary non-demo step consumed demo-phase pending waypoint "
            f"(timestep {boundary_row['timestep']})."
        )

    playback_rows = [row for row in records if not row["is_demo"]]
    assert playback_rows, f"{env_id}: no non-demo rows."
    playback_waypoints = _unique_finite_waypoints(playback_rows)
    assert playback_waypoints, f"{env_id}: non-demo phase has no finite waypoints."

    if not assert_first_segment_midpoint:
        return

    # The first segment is the leading run of non-demo rows that share the
    # subgoal of the very first non-demo row.
    opening_subgoal = playback_rows[0]["subgoal"]
    segment_end = next(
        (
            position
            for position, row in enumerate(playback_rows)
            if row["subgoal"] != opening_subgoal
        ),
        len(playback_rows),
    )
    opening_segment = playback_rows[:segment_end]
    first_segment_unique = _unique_finite_waypoints(opening_segment)
    assert len(first_segment_unique) >= 2, (
        f"{env_id}: first non-demo subgoal segment lost midpoint waypoint; "
        f"expected >=2 unique finite waypoints, got {len(first_segment_unique)}."
    )
| |
|