from __future__ import annotations

from pathlib import Path

# Tests for ``OracleSession.load_episode`` in the ``oracle_logic`` module.
#
# The classes below are lightweight stand-ins that get monkeypatched over
# names looked up on ``oracle_logic`` (``BenchmarkEnvBuilder`` and the two
# motion-planning solver classes).


class _DummyPlanner:
    """Planner stand-in that only records its constructor arguments."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class _FakeRobot:
    """Robot stand-in exposing an opaque ``pose`` attribute."""

    def __init__(self):
        self.pose = object()


class _FakeAgent:
    """Agent stand-in that owns a ``_FakeRobot``."""

    def __init__(self):
        self.robot = _FakeRobot()


class _FakeUnwrapped:
    """Stand-in for the object exposed as ``env.unwrapped``."""

    def __init__(self):
        self.agent = _FakeAgent()
        self.segmentation_id_map = {}

    def evaluate(self, solve_complete_eval=False):
        """Report a result that is neither a success nor a failure."""
        return {"success": False, "fail": False}


class _FakeEnv:
    """Environment stand-in whose demonstration data is a dict."""

    def __init__(self):
        self.unwrapped = _FakeUnwrapped()
        self.demonstration_data = {
            "language goal": "test goal",
            "frames": ["f1", "f2"],
        }
        self.non_demonstration_task_length = 7
        self.frames = []
        self.wrist_frames = []
        self.closed = False

    def reset(self):
        """Do nothing and return ``None``."""
        return None

    def close(self):
        """Flag the environment as closed."""
        self.closed = True


class _FakeEnvTupleDemo(_FakeEnv):
    """Environment stand-in whose demonstration data is a 5-tuple."""

    def __init__(self):
        super().__init__()
        # First slot carries the frames, last slot carries the goal strings;
        # the three middle slots are ``None`` placeholders.
        first_slot = {"front_rgb_list": ["tuple_f1", "tuple_f2"]}
        last_slot = {"task_goal": ["tuple goal", "backup goal"]}
        self.demonstration_data = (first_slot, None, None, None, last_slot)


class _BuilderSuccess:
    """Builder stand-in reporting three episodes and yielding a ``_FakeEnv``."""

    # Keyword arguments of the most recent construction, stored on the
    # concrete class that was instantiated.
    last_init_kwargs = None

    def __init__(self, **kwargs):
        self.__class__.last_init_kwargs = kwargs

    def get_episode_num(self):
        """Return the fixed episode count."""
        return 3

    def resolve_episode(self, episode_idx):
        """Return a fixed ``(seed, difficulty)`` pair for any index."""
        return 123, "hard"

    def make_env_for_episode(self, episode_idx):
        """Return a fresh dict-demo fake environment."""
        return _FakeEnv()


class _BuilderTupleDemo(_BuilderSuccess):
    """Variant of ``_BuilderSuccess`` that yields a tuple-demo environment."""

    def make_env_for_episode(self, episode_idx):
        """Return a fresh tuple-demo fake environment."""
        return _FakeEnvTupleDemo()


class _BuilderNoMetadata:
    """Builder stand-in that reports zero episodes."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def get_episode_num(self):
        """Return zero episodes."""
        return 0


class _BuilderRaiseOnMake:
    """Builder stand-in whose environment construction always raises."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def get_episode_num(self):
        """Return a single available episode."""
        return 1

    def resolve_episode(self, episode_idx):
        """Return an empty ``(seed, difficulty)`` pair."""
        return None, None

    def make_env_for_episode(self, episode_idx):
        """Always raise ``RuntimeError("boom")``."""
        raise RuntimeError("boom")


def _install_session_stubs(monkeypatch, oracle_logic, builder_cls):
    """Patch the builder, both planner classes, and ``update_observation``.

    The patch order is: builder, arm planner, stick planner, then the
    ``OracleSession.update_observation`` replacement.
    """

    def _fake_update_observation(self):
        return ("IMG", "Ready")

    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", builder_cls)
    for planner_name in (
        "FailAwarePandaArmMotionPlanningSolver",
        "FailAwarePandaStickMotionPlanningSolver",
    ):
        monkeypatch.setattr(oracle_logic, planner_name, _DummyPlanner)
    monkeypatch.setattr(
        oracle_logic.OracleSession,
        "update_observation",
        _fake_update_observation,
    )


def test_load_episode_uses_benchmark_builder(monkeypatch, reload_module):
    """A successful load populates the session and forwards builder kwargs."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setenv("ROBOMME_METADATA_ROOT", "/tmp/meta-root")
    _install_session_stubs(monkeypatch, oracle_logic, _BuilderSuccess)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    image, message = session.load_episode("BinFill", 1)

    assert image == "IMG"
    assert message == "Ready"
    assert session.env_id == "BinFill"
    assert session.episode_idx == 1
    assert session.seed == 123
    assert session.difficulty == "hard"
    assert session.language_goal == "test goal"
    assert session.demonstration_frames == ["f1", "f2"]

    captured = _BuilderSuccess.last_init_kwargs
    assert captured["dataset"] == "train"
    assert captured["action_space"] == "joint_angle"
    assert captured["gui_render"] is False
    assert captured["max_steps"] == 3000
    assert captured["override_metadata_path"] == Path("/tmp/meta-root")


def test_load_episode_metadata_missing_returns_stable_error(monkeypatch, reload_module):
    """A builder with zero episodes produces the metadata-missing message."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setenv("ROBOMME_METADATA_ROOT", "/tmp/custom-metadata")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderNoMetadata)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    image, message = session.load_episode("RouteStick", 0)

    assert image is None
    assert "Dataset metadata not found or empty" in message
    assert "record_dataset_RouteStick_metadata.json" in message


def test_load_episode_out_of_range_returns_stable_error(monkeypatch, reload_module):
    """An index beyond the episode count produces the out-of-range message."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderSuccess)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    image, message = session.load_episode("BinFill", 99)

    assert image is None
    assert "Episode index out of range" in message
    assert "valid 0-2" in message


def test_load_episode_init_failure_is_caught(monkeypatch, reload_module):
    """An exception while building the env is reported, not propagated."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderRaiseOnMake)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    image, message = session.load_episode("BinFill", 0)

    assert image is None
    assert message.startswith("Error initializing episode:")


def test_load_episode_supports_tuple_demonstration_data(monkeypatch, reload_module):
    """Tuple-shaped demonstration data yields the expected goal and frames."""
    oracle_logic = reload_module("oracle_logic")
    _install_session_stubs(monkeypatch, oracle_logic, _BuilderTupleDemo)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    image, message = session.load_episode("BinFill", 0)

    assert image == "IMG"
    assert message == "Ready"
    assert session.language_goal == "backup goal"
    assert session.demonstration_frames == ["tuple_f1", "tuple_f2"]