| from __future__ import annotations |
|
|
| import h5py |
| import numpy as np |
| import pytest |
|
|
| from tests._shared.dataset_generation import DatasetCase |
|
|
| pytestmark = pytest.mark.dataset |
|
|
|
|
| def _make_case(env_id: str, base_seed: int) -> DatasetCase: |
| return DatasetCase( |
| env_id=env_id, |
| episode=0, |
| base_seed=base_seed, |
| difficulty="easy", |
| save_video=True, |
| mode_tag="waypoint_phase_isolation", |
| ) |
|
|
|
|
| def _decode_h5_string(raw) -> str: |
| if isinstance(raw, np.ndarray): |
| raw = raw.flatten()[0] |
| if isinstance(raw, (bytes, np.bytes_)): |
| return raw.decode("utf-8") |
| return str(raw) |
|
|
|
|
| def _collect_records(ep_group: h5py.Group) -> list[dict]: |
| timestep_keys = sorted( |
| (k for k in ep_group.keys() if k.startswith("timestep_")), |
| key=lambda k: int(k.split("_")[1]), |
| ) |
| out: list[dict] = [] |
| for key in timestep_keys: |
| ts = ep_group[key] |
| info = ts["info"] |
| waypoint_action = np.asarray(ts["action"]["waypoint_action"][()]).flatten() |
| out.append( |
| { |
| "timestep": int(key.split("_")[1]), |
| "is_demo": bool(np.reshape(np.asarray(info["is_video_demo"][()]), -1)[0]), |
| "subgoal": _decode_h5_string(info["simple_subgoal_online"][()]), |
| "waypoint_action": waypoint_action, |
| "is_finite_waypoint": bool( |
| waypoint_action.shape == (7,) and np.all(np.isfinite(waypoint_action)) |
| ), |
| } |
| ) |
| return out |
|
|
|
|
| def _find_demo_to_non_demo_boundary(records: list[dict]) -> int | None: |
| for idx in range(1, len(records)): |
| if records[idx - 1]["is_demo"] and not records[idx]["is_demo"]: |
| return idx |
| return None |
|
|
|
|
| def _last_finite_demo_waypoint(records: list[dict], boundary_idx: int) -> np.ndarray | None: |
| for idx in range(boundary_idx - 1, -1, -1): |
| row = records[idx] |
| if not row["is_demo"]: |
| continue |
| if row["is_finite_waypoint"]: |
| return np.asarray(row["waypoint_action"]).flatten() |
| return None |
|
|
|
|
| def _unique_finite_waypoints(rows: list[dict]) -> list[np.ndarray]: |
| uniques: list[np.ndarray] = [] |
| prev: np.ndarray | None = None |
| for row in rows: |
| if not row["is_finite_waypoint"]: |
| continue |
| wa = np.asarray(row["waypoint_action"]).flatten() |
| if prev is None or not np.array_equal(wa, prev): |
| uniques.append(wa.copy()) |
| prev = wa.copy() |
| return uniques |
|
|
|
|
| @pytest.mark.parametrize( |
| "env_id,base_seed,assert_first_segment_midpoint", |
| [ |
| ("PatternLock", 15001, False), |
| ("RouteStick", 16000, True), |
| ], |
| ) |
| def test_waypoint_isolation_across_demo_phase( |
| env_id: str, |
| base_seed: int, |
| assert_first_segment_midpoint: bool, |
| dataset_factory, |
| ): |
| generated = dataset_factory(_make_case(env_id, base_seed)) |
|
|
| with h5py.File(generated.raw_h5_path, "r") as h5f: |
| records = _collect_records(h5f["episode_0"]) |
|
|
| assert records, f"{env_id}: episode_0 has no recorded timesteps." |
| boundary_idx = _find_demo_to_non_demo_boundary(records) |
| assert boundary_idx is not None, f"{env_id}: missing demo->non-demo boundary." |
|
|
| first_non_demo = records[boundary_idx] |
| assert not first_non_demo["is_demo"], ( |
| f"{env_id}: boundary row should be non-demo, got demo at timestep " |
| f"{first_non_demo['timestep']}." |
| ) |
|
|
| last_demo_waypoint = _last_finite_demo_waypoint(records, boundary_idx) |
| if first_non_demo["is_finite_waypoint"] and last_demo_waypoint is not None: |
| assert not np.array_equal( |
| np.asarray(first_non_demo["waypoint_action"]).flatten(), |
| last_demo_waypoint, |
| ), ( |
| f"{env_id}: boundary non-demo step consumed demo-phase pending waypoint " |
| f"(timestep {first_non_demo['timestep']})." |
| ) |
|
|
| non_demo_rows = [row for row in records if not row["is_demo"]] |
| assert non_demo_rows, f"{env_id}: no non-demo rows." |
| all_non_demo_unique = _unique_finite_waypoints(non_demo_rows) |
| assert all_non_demo_unique, f"{env_id}: non-demo phase has no finite waypoints." |
|
|
| if assert_first_segment_midpoint: |
| first_subgoal = non_demo_rows[0]["subgoal"] |
| first_segment_rows: list[dict] = [] |
| for row in non_demo_rows: |
| if row["subgoal"] != first_subgoal: |
| break |
| first_segment_rows.append(row) |
| first_segment_unique = _unique_finite_waypoints(first_segment_rows) |
| assert len(first_segment_unique) >= 2, ( |
| f"{env_id}: first non-demo subgoal segment lost midpoint waypoint; " |
| f"expected >=2 unique finite waypoints, got {len(first_segment_unique)}." |
| ) |
|
|