File size: 3,578 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import importlib.util
import json
import sys
from pathlib import Path

import h5py
import pytest

from tests._shared.repo_paths import find_repo_root

# Module-wide markers: pytest applies every marker in `pytestmark` to all tests
# collected from this file.
# NOTE(review): none of the tests below use a GPU directly; presumably `gpu` is
# here because the modules loaded from source pull in GPU-dependent imports —
# confirm, or drop the marker so these can run in CPU-only CI.
pytestmark = [pytest.mark.lightweight, pytest.mark.gpu]


def _load_module(module_name: str, relative_path: str):
    """Import a repository source file directly by path under ``module_name``.

    The file at ``<repo root>/<relative_path>`` is executed as a fresh module
    object, so tests can exercise it without going through the normal package
    import machinery.

    Args:
        module_name: Name given to the loaded module; it is also the key under
            which the module is registered in ``sys.modules``.
        relative_path: Location of the source file relative to the repository
            root returned by ``find_repo_root``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
        Any exception raised while executing the module body is propagated
        unchanged after the partial module is removed from ``sys.modules``.
    """
    module_path = find_repo_root(__file__) / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before module_from_spec: with a None spec that call fails with
    # an opaque AttributeError instead of naming the offending path.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"cannot build an import spec for {module_name!r} from {module_path}"
        )
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib "import a source file
    # directly" recipe does: code that looks its own module up in sys.modules
    # (e.g. dataclasses with string annotations, pickle) can fail otherwise.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module


# Load both modules under test straight from their source files, giving each a
# distinct "*_under_test" module name. Loads happen in the order listed.
matcher_mod, resolver_mod = (
    _load_module(alias, source_path)
    for alias, source_path in (
        (
            "oracle_action_matcher_under_test",
            "src/robomme/robomme_env/utils/oracle_action_matcher.py",
        ),
        (
            "episode_dataset_resolver_under_test",
            "src/robomme/env_record_wrapper/episode_dataset_resolver.py",
        ),
    )
)


def test_find_exact_label_option_index_matches_label_only():
    """Only an option's ``label`` is matched; its action text never is."""
    options = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]
    # (query, expected index); -1 means "no option matched".
    cases = (
        ("a", 0),
        ("b", 1),
        # Action text is not a label, so it must not match.
        ("pick up the cube", -1),
        # A non-string query matches nothing.
        (1, -1),
    )

    for query, expected_index in cases:
        found = matcher_mod.find_exact_label_option_index(query, options)
        assert found == expected_index


def test_map_action_text_to_option_label_strict_exact():
    """Exact action text maps to its option's label; anything else gives None."""
    options = [
        {"label": "a", "action": "pick up the cube"},
        {"label": "b", "action": "put it down"},
    ]
    to_label = matcher_mod.map_action_text_to_option_label

    assert to_label("pick up the cube", options) == "a"

    # Neither unmatched text nor a missing text yields a label.
    for unmatched_text in ("unknown action", None):
        assert to_label(unmatched_text, options) is None


def _write_choice_timestep(episode_group, index, choice_action):
    """Add ``timestep_<index>`` with a JSON ``choice_action`` and info flags.

    The info group always records ``is_video_demo=False`` and
    ``is_subgoal_boundary=True``, matching what the resolver test needs.
    """
    timestep = episode_group.create_group(f"timestep_{index}")
    timestep.create_group("action").create_dataset(
        "choice_action",
        data=json.dumps(choice_action),
        dtype=h5py.string_dtype(encoding="utf-8"),
    )
    info = timestep.create_group("info")
    info.create_dataset("is_video_demo", data=False)
    info.create_dataset("is_subgoal_boundary", data=True)


def test_episode_dataset_resolver_extracts_choice_command_and_ignores_empty_choice(tmp_path):
    """The resolver returns a stored choice command and None for an empty choice."""
    h5_path = tmp_path / "choice_oracle_commands.h5"

    with h5py.File(h5_path, "w") as h5:
        episode = h5.create_group("episode_0")
        _write_choice_timestep(episode, 0, {"choice": "B", "point": [34, 12]})
        _write_choice_timestep(episode, 1, {"choice": "", "point": [30, 20]})

    # NOTE(review): a single .h5 file path is passed as `dataset_directory`;
    # presumably the resolver accepts a file as well as a directory — confirm.
    resolver = resolver_mod.EpisodeDatasetResolver(
        env_id="DummyEnv",
        episode=0,
        dataset_directory=h5_path,
    )
    try:
        first_command = resolver.get_step("multi_choice", 0)
        assert first_command == {"choice": "B", "point": [34, 12]}
        assert "position_3d" not in first_command

        # An empty choice string produces no command at all.
        assert resolver.get_step("multi_choice", 1) is None
    finally:
        resolver.close()