RoboMME / tests /lightweight /test_ChoiceLabel.py
HongzeFu's picture
HF Space: code-only (no binary assets)
06c11b0
import importlib.util
import json
from pathlib import Path
import h5py
import pytest
from tests._shared.repo_paths import find_repo_root
# Module-level pytest markers: pytest applies every entry to each test
# collected from this file.
pytestmark = [
    pytest.mark.lightweight,
    pytest.mark.gpu,
]
def _load_module(module_name: str, relative_path: str):
    """Load a repository source file as a standalone module object.

    The file is located relative to the repository root and executed
    directly from its path. The resulting module is returned to the caller
    and is not registered in ``sys.modules``.

    Args:
        module_name: Name given to the loaded module.
        relative_path: Path of the source file, relative to the repository
            root returned by ``find_repo_root(__file__)``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec (or no loader) can be created for
            the resolved path.
    """
    repo_root = find_repo_root(__file__)
    module_path = repo_root / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # spec_from_file_location() may return None, and a spec may lack a
    # loader. Validate both *before* module_from_spec(), which would
    # otherwise fail with an AttributeError that hides the real cause.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {module_name!r} from {module_path}"
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
# Load the two modules under test straight from their source files, each
# under a dedicated "*_under_test" module name.
matcher_mod = _load_module(
    module_name="oracle_action_matcher_under_test",
    relative_path="src/robomme/robomme_env/utils/oracle_action_matcher.py",
)
resolver_mod = _load_module(
    module_name="episode_dataset_resolver_under_test",
    relative_path="src/robomme/env_record_wrapper/episode_dataset_resolver.py",
)
def test_find_exact_label_option_index_matches_label_only():
    """Check that only an option's ``label`` value resolves to its index.

    Each label must map to its position in the option list, while an
    option's action text and a non-string query must both yield ``-1``.
    """
    choices = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]
    find_index = matcher_mod.find_exact_label_option_index

    # (query, expected index) pairs, checked in the listed order.
    expectations = (
        ("a", 0),
        ("b", 1),
        ("pick up the cube", -1),  # action text is not a label
        (1, -1),  # integer query, not a label string
    )
    for query, expected_index in expectations:
        assert find_index(query, choices) == expected_index
def test_map_action_text_to_option_label_strict_exact():
    """Check that action text maps to a label only when it matches an option.

    The action text of an option must map to that option's label; text
    matching no option, and ``None``, must both map to ``None``.
    """
    choices = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]
    to_label = matcher_mod.map_action_text_to_option_label

    matched_label = to_label("pick up the cube", choices)
    assert matched_label == "a"

    # Inputs without a matching option must not produce any label.
    for unmatched_text in ("unknown action", None):
        assert to_label(unmatched_text, choices) is None
def test_episode_dataset_resolver_extracts_choice_command_and_ignores_empty_choice(tmp_path):
    """Check how the resolver reads ``choice_action`` entries from an HDF5 file.

    A temporary file is written with one episode holding two timesteps.
    Timestep 0 stores a JSON choice command with choice ``"B"``; timestep 1
    stores one whose choice is the empty string. The resolver must return
    the decoded command for step 0 and ``None`` for step 1.
    """

    def write_timestep(parent, group_name, choice_payload):
        # One timestep = an "action" group holding the JSON-encoded choice
        # command, plus an "info" group with the two boolean flags.
        step = parent.create_group(group_name)
        actions = step.create_group("action")
        actions.create_dataset(
            "choice_action",
            data=json.dumps(choice_payload),
            dtype=h5py.string_dtype(encoding="utf-8"),
        )
        flags = step.create_group("info")
        flags.create_dataset("is_video_demo", data=False)
        flags.create_dataset("is_subgoal_boundary", data=True)

    fixture_file = tmp_path / "choice_oracle_commands.h5"
    with h5py.File(fixture_file, "w") as handle:
        episode = handle.create_group("episode_0")
        write_timestep(episode, "timestep_0", {"choice": "B", "point": [34, 12]})
        write_timestep(episode, "timestep_1", {"choice": "", "point": [30, 20]})

    # The fixture file is closed at this point; the resolver opens it itself.
    resolver = resolver_mod.EpisodeDatasetResolver(
        env_id="DummyEnv",
        episode=0,
        dataset_directory=fixture_file,
    )
    try:
        first_command = resolver.get_step("multi_choice", 0)
        assert first_command == {"choice": "B", "point": [34, 12]}
        assert "position_3d" not in first_command

        # An empty choice string must not produce a command.
        second_command = resolver.get_step("multi_choice", 1)
        assert second_command is None
    finally:
        resolver.close()