# RoboMME / tests/dataset/test_waypoint_phase_isolation.py
# Committed by HongzeFu
# Commit message: HF Space: code-only (no binary assets)
# Commit: 06c11b0
from __future__ import annotations
import h5py
import numpy as np
import pytest
from tests._shared.dataset_generation import DatasetCase
# A module-level ``pytestmark`` applies the ``dataset`` marker to every test in
# this file (standard pytest behaviour). Presumably this lets the slow
# dataset-generation tests be selected/deselected via ``-m dataset`` -- confirm
# against the project's pytest configuration.
pytestmark = pytest.mark.dataset
def _make_case(env_id: str, base_seed: int) -> DatasetCase:
    """Build the single-episode dataset case used by the isolation test.

    Every case records ``episode=0`` at ``difficulty="easy"`` with
    ``save_video=True`` and ``mode_tag="waypoint_phase_isolation"``; only the
    environment id and the base seed vary between parametrizations.
    """
    fixed_settings = {
        "episode": 0,
        "difficulty": "easy",
        "save_video": True,
        "mode_tag": "waypoint_phase_isolation",
    }
    return DatasetCase(env_id=env_id, base_seed=base_seed, **fixed_settings)
def _decode_h5_string(raw) -> str:
    """Convert a value read from an HDF5 string field into a Python ``str``.

    A NumPy array is reduced to its first element (in flattened order).
    ``bytes`` / ``np.bytes_`` values are decoded as UTF-8, and anything else
    is passed through ``str()``.
    """
    value = raw.flatten()[0] if isinstance(raw, np.ndarray) else raw
    if not isinstance(value, (bytes, np.bytes_)):
        return str(value)
    return value.decode("utf-8")
def _collect_records(ep_group: h5py.Group) -> list[dict]:
    """Flatten an episode's ``timestep_<N>`` groups into a list of row dicts.

    Only keys starting with ``"timestep_"`` are used. They are visited in
    ascending numeric order of ``<N>``; all other keys are ignored. Each row
    contains:

    * ``"timestep"`` -- the integer ``<N>`` parsed from the key;
    * ``"is_demo"`` -- first element of ``info/is_video_demo``, as ``bool``;
    * ``"subgoal"`` -- ``info/simple_subgoal_online`` decoded to ``str``;
    * ``"waypoint_action"`` -- ``action/waypoint_action`` flattened to 1-D;
    * ``"is_finite_waypoint"`` -- True iff that vector has exactly 7 elements
      and every element is finite.
    """

    def _step_index(name: str) -> int:
        # "timestep_12" -> 12
        return int(name.split("_")[1])

    ordered_keys = sorted(
        [name for name in ep_group.keys() if name.startswith("timestep_")],
        key=_step_index,
    )

    rows: list[dict] = []
    for name in ordered_keys:
        step = ep_group[name]
        step_info = step["info"]
        waypoint = np.asarray(step["action"]["waypoint_action"][()]).flatten()
        demo_flags = np.reshape(np.asarray(step_info["is_video_demo"][()]), -1)
        is_demo = bool(demo_flags[0])
        subgoal = _decode_h5_string(step_info["simple_subgoal_online"][()])
        # Anything other than a fully finite 7-vector is treated as "no waypoint".
        finite = waypoint.shape == (7,) and np.all(np.isfinite(waypoint))
        rows.append(
            {
                "timestep": _step_index(name),
                "is_demo": is_demo,
                "subgoal": subgoal,
                "waypoint_action": waypoint,
                "is_finite_waypoint": bool(finite),
            }
        )
    return rows
def _find_demo_to_non_demo_boundary(records: list[dict]) -> int | None:
    """Return the index of the first non-demo row that directly follows a demo row.

    Returns ``None`` when no demo -> non-demo transition exists, which
    includes empty and single-row input.
    """
    adjacent_pairs = zip(records, records[1:])
    for position, (before, current) in enumerate(adjacent_pairs, start=1):
        if before["is_demo"] and not current["is_demo"]:
            return position
    return None
def _last_finite_demo_waypoint(records: list[dict], boundary_idx: int) -> np.ndarray | None:
    """Return the most recent finite demo waypoint recorded before ``boundary_idx``.

    Scans ``records[boundary_idx - 1]`` down to ``records[0]``, newest first.
    The first row that is both a demo row and flagged as having a finite
    waypoint wins, and a flattened copy of its ``"waypoint_action"`` is
    returned. Returns ``None`` if no such row precedes the boundary.
    """
    for idx in reversed(range(boundary_idx)):
        candidate = records[idx]
        if candidate["is_demo"] and candidate["is_finite_waypoint"]:
            return np.asarray(candidate["waypoint_action"]).flatten()
    return None
def _unique_finite_waypoints(rows: list[dict]) -> list[np.ndarray]:
    """Return the finite waypoints of ``rows`` with consecutive repeats collapsed.

    Rows whose ``"is_finite_waypoint"`` flag is false are skipped. A remaining
    waypoint is kept whenever it differs (``np.array_equal``) from the
    previously kept one, so ``A, A, B, A`` yields ``[A, B, A]``. Here "unique"
    means "distinct from its predecessor", not globally unique. Each returned
    array is an independent 1-D copy.
    """
    finite_waypoints = (
        np.asarray(row["waypoint_action"]).flatten()  # flatten() always copies
        for row in rows
        if row["is_finite_waypoint"]
    )
    collapsed: list[np.ndarray] = []
    for waypoint in finite_waypoints:
        if collapsed and np.array_equal(waypoint, collapsed[-1]):
            continue
        collapsed.append(waypoint)
    return collapsed
@pytest.mark.parametrize(
    "env_id,base_seed,assert_first_segment_midpoint",
    [
        ("PatternLock", 15001, False),
        ("RouteStick", 16000, True),
    ],
)
def test_waypoint_isolation_across_demo_phase(
    env_id: str,
    base_seed: int,
    assert_first_segment_midpoint: bool,
    dataset_factory,
):
    """Generate one episode and check waypoints stay isolated across phases.

    Checks, in order:

    1. ``episode_0`` has recorded timesteps and contains a demo -> non-demo
       transition.
    2. If the boundary (first non-demo) row has a finite waypoint and a finite
       demo waypoint exists before it, the two must differ. Otherwise the
       execution phase consumed a waypoint left pending from the demo phase.
       This check is skipped when either side is unavailable.
    3. The non-demo rows contain at least one finite waypoint.
    4. When ``assert_first_segment_midpoint`` is set, the leading run of
       non-demo rows sharing the first subgoal must contain at least two
       consecutive-distinct finite waypoints (i.e. the midpoint survived).
    """
    generated = dataset_factory(_make_case(env_id, base_seed))
    with h5py.File(generated.raw_h5_path, "r") as h5_file:
        rows = _collect_records(h5_file["episode_0"])
    assert rows, f"{env_id}: episode_0 has no recorded timesteps."

    boundary = _find_demo_to_non_demo_boundary(rows)
    assert boundary is not None, f"{env_id}: missing demo->non-demo boundary."

    boundary_row = rows[boundary]
    # Redundant with the boundary helper's contract; kept as an explicit sanity check.
    assert not boundary_row["is_demo"], (
        f"{env_id}: boundary row should be non-demo, got demo at timestep "
        f"{boundary_row['timestep']}."
    )

    demo_waypoint = _last_finite_demo_waypoint(rows, boundary)
    can_compare = boundary_row["is_finite_waypoint"] and demo_waypoint is not None
    if can_compare:
        boundary_waypoint = np.asarray(boundary_row["waypoint_action"]).flatten()
        assert not np.array_equal(boundary_waypoint, demo_waypoint), (
            f"{env_id}: boundary non-demo step consumed demo-phase pending waypoint "
            f"(timestep {boundary_row['timestep']})."
        )

    non_demo = [row for row in rows if not row["is_demo"]]
    assert non_demo, f"{env_id}: no non-demo rows."
    assert _unique_finite_waypoints(non_demo), (
        f"{env_id}: non-demo phase has no finite waypoints."
    )

    if not assert_first_segment_midpoint:
        return

    # The first segment runs until the first row whose subgoal changes.
    leading_subgoal = non_demo[0]["subgoal"]
    segment_end = next(
        (pos for pos, row in enumerate(non_demo) if row["subgoal"] != leading_subgoal),
        len(non_demo),
    )
    segment_waypoints = _unique_finite_waypoints(non_demo[:segment_end])
    assert len(segment_waypoints) >= 2, (
        f"{env_id}: first non-demo subgoal segment lost midpoint waypoint; "
        f"expected >=2 unique finite waypoints, got {len(segment_waypoints)}."
    )