"""Tests for the oracle action matcher and the episode dataset resolver.

Both modules under test are loaded directly from their source files under
``src/robomme`` via ``importlib`` rather than imported as a package
(presumably so these tests do not require the package to be installed --
confirm).
"""

import importlib.util
import json
import sys
from pathlib import Path

import h5py
import pytest

from tests._shared.repo_paths import find_repo_root

pytestmark = [pytest.mark.lightweight, pytest.mark.gpu]


def _load_module(module_name: str, relative_path: str):
    """Import the source file at *relative_path* (relative to the repo root) as *module_name*.

    Follows the importlib "importing a source file directly" recipe: the new
    module is registered in ``sys.modules`` before it is executed, so code
    that looks its own module up there (e.g. ``dataclasses`` resolving
    string annotations) behaves as it would under a normal import.

    Args:
        module_name: Name to give the loaded module (also its ``sys.modules`` key).
        relative_path: Path of the ``.py`` file relative to the repository root.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec or loader can be created for the path.
    """
    repo_root = find_repo_root(__file__)
    module_path = repo_root / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate *before* module_from_spec, which dereferences the spec and
    # would otherwise fail with an opaque AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module


matcher_mod = _load_module(
    "oracle_action_matcher_under_test",
    "src/robomme/robomme_env/utils/oracle_action_matcher.py",
)
resolver_mod = _load_module(
    "episode_dataset_resolver_under_test",
    "src/robomme/env_record_wrapper/episode_dataset_resolver.py",
)


def test_find_exact_label_option_index_matches_label_only():
    """Lookup succeeds only on an exact ``label``; action text and non-strings give -1."""
    options = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]

    assert matcher_mod.find_exact_label_option_index("a", options) == 0
    assert matcher_mod.find_exact_label_option_index("b", options) == 1
    assert matcher_mod.find_exact_label_option_index("pick up the cube", options) == -1
    assert matcher_mod.find_exact_label_option_index(1, options) == -1


def test_map_action_text_to_option_label_strict_exact():
    """Exact action text maps to its label; unknown text and ``None`` map to ``None``."""
    options = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]

    assert (
        matcher_mod.map_action_text_to_option_label("pick up the cube", options) == "a"
    )
    assert matcher_mod.map_action_text_to_option_label("unknown action", options) is None
    assert matcher_mod.map_action_text_to_option_label(None, options) is None


def _write_choice_timestep(episode_group, index: int, choice: str, point: list) -> None:
    """Write a ``timestep_<index>`` group with a JSON ``choice_action`` and boundary flags.

    The action is stored as a UTF-8 string dataset holding
    ``{"choice": choice, "point": point}``; ``info`` marks the step as a
    non-video-demo subgoal boundary.
    """
    timestep = episode_group.create_group(f"timestep_{index}")
    timestep.create_group("action").create_dataset(
        "choice_action",
        data=json.dumps({"choice": choice, "point": point}),
        dtype=h5py.string_dtype(encoding="utf-8"),
    )
    info = timestep.create_group("info")
    info.create_dataset("is_video_demo", data=False)
    info.create_dataset("is_subgoal_boundary", data=True)


def test_episode_dataset_resolver_extracts_choice_command_and_ignores_empty_choice(tmp_path):
    """A non-empty choice yields its command dict; an empty choice yields ``None``."""
    h5_path = tmp_path / "choice_oracle_commands.h5"
    with h5py.File(h5_path, "w") as h5:
        episode_group = h5.create_group("episode_0")
        _write_choice_timestep(episode_group, 0, "B", [34, 12])
        # Empty choice string: the resolver is expected to skip this step.
        _write_choice_timestep(episode_group, 1, "", [30, 20])

    resolver = resolver_mod.EpisodeDatasetResolver(
        env_id="DummyEnv",
        episode=0,
        dataset_directory=h5_path,
    )
    try:
        command0 = resolver.get_step("multi_choice", 0)
        assert command0 == {"choice": "B", "point": [34, 12]}
        assert "position_3d" not in command0

        command1 = resolver.get_step("multi_choice", 1)
        assert command1 is None
    finally:
        resolver.close()