File size: 7,042 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a64124d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from __future__ import annotations

import numpy as np


class _FakeUnwrapped:
    """Minimal stand-in for the object exposed as ``env.unwrapped``.

    The only attribute it carries is an empty ``segmentation_id_map`` dict.
    """

    def __init__(self):
        # A new dict per instance, so separate fakes never share state.
        self.segmentation_id_map = dict()


class _FakeEnv:
    """Fake environment exposing ``unwrapped``, ``frames`` and ``wrist_frames``.

    ``frames`` starts with a single all-zero ``(8, 8, 3)`` uint8 array and
    ``wrist_frames`` starts empty.
    """

    def __init__(self):
        blank_frame = np.zeros(shape=(8, 8, 3), dtype=np.uint8)
        self.unwrapped = _FakeUnwrapped()
        self.frames = [blank_frame]
        self.wrist_frames = list()


class _FakeObsWrapperEnv:
    """Fake environment that carries its frames in a ``_last_obs`` dict.

    Args:
        front_rgb_list: Value stored under ``_last_obs["front_rgb_list"]``.
        wrist_rgb_list: Value stored under ``_last_obs["wrist_rgb_list"]``.
    """

    def __init__(self, front_rgb_list, wrist_rgb_list):
        self.unwrapped = _FakeUnwrapped()
        # The arguments are stored as given (no copy is made).
        observation = {}
        observation["front_rgb_list"] = front_rgb_list
        observation["wrist_rgb_list"] = wrist_rgb_list
        self._last_obs = observation



def test_available_options_use_label_plus_action(monkeypatch, reload_module):
    """Options are listed as ``("<label>. <action>", index)`` pairs.

    Segmentation fetching and solve-option building are stubbed out, so the
    test only checks how ``update_observation`` turns the raw options into
    ``available_options`` and keeps them in ``raw_solve_options``.
    """
    oracle_logic = reload_module("oracle_logic")

    def fake_fetch_segmentation(env):
        return np.zeros((1, 8, 8), dtype=np.int64)

    def fake_build_solve_options(env, planner, selected_target, env_id):
        pick_option = {"label": "a", "action": "pick up the cube", "available": [1]}
        put_option = {"label": "b", "action": "put it down", "available": []}
        return [pick_option, put_option]

    monkeypatch.setattr(oracle_logic, "_fetch_segmentation", fake_fetch_segmentation)
    monkeypatch.setattr(oracle_logic, "_build_solve_options", fake_build_solve_options)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = _FakeEnv()
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, status = session.update_observation()

    expected_options = [
        ("a. pick up the cube", 0),
        ("b. put it down", 1),
    ]
    assert status == "Ready"
    assert session.available_options == expected_options
    first_raw_option = session.raw_solve_options[0]
    assert first_raw_option["label"] == "a"


def test_build_solve_options_filters_press_button_for_video_place_envs(monkeypatch, reload_module):
    """The "press the button" option is dropped for the VideoPlace env ids.

    ``get_vqa_options`` is stubbed to return three options; the test expects
    ``_build_solve_options`` to keep only the first two for
    ``"VideoPlaceButton"`` and ``"VideoPlaceOrder"`` and all three for
    ``"ButtonUnmask"``.
    """
    oracle_logic = reload_module("oracle_logic")

    base_options = [
        {"label": "a", "action": "pick up the cube", "available": [1]},
        {"label": "b", "action": "drop onto", "available": [2]},
        {"label": "c", "action": "press the button"},
    ]

    def fake_get_vqa_options(env, planner, selected_target, env_id):
        # Return a new list on every call so the calls stay independent.
        return list(base_options)

    monkeypatch.setattr(oracle_logic, "get_vqa_options", fake_get_vqa_options)

    def solve_options_for(env_name):
        return oracle_logic._build_solve_options(None, None, {}, env_name)

    def field_values(options, key):
        return [opt[key] for opt in options]

    filtered_button = solve_options_for("VideoPlaceButton")
    filtered_order = solve_options_for("VideoPlaceOrder")
    unfiltered_other = solve_options_for("ButtonUnmask")

    kept_labels = ["a", "b"]
    assert field_values(filtered_button, "label") == kept_labels
    assert field_values(filtered_button, "action") == ["pick up the cube", "drop onto"]
    assert field_values(filtered_order, "label") == kept_labels
    assert field_values(unfiltered_other, "action") == [
        "pick up the cube",
        "drop onto",
        "press the button",
    ]


def test_update_observation_no_seg_vis_base_fallback(monkeypatch, reload_module):
    """The segmentation visual is not used as a substitute base frame.

    The env supplies no frames at all while the stubbed
    ``_prepare_segmentation_visual`` returns a coloured image. After
    ``update_observation(use_segmentation=False)`` the session must still
    have no base frames, and ``get_pil_image(use_segmented=False)`` is
    expected to be 255x255.
    """
    oracle_logic = reload_module("oracle_logic")

    # Constant-colour 6x6 image; channels 0/1/2 hold 10/20/30 (B, G, R).
    seg_vis = np.zeros((6, 6, 3), dtype=np.uint8)
    seg_vis[...] = (10, 20, 30)

    def fake_fetch_segmentation(env):
        return np.zeros((1, 6, 6), dtype=np.int64)

    def fake_prepare_segmentation_visual(seg, color_map, hw):
        id_map = np.zeros((6, 6), dtype=np.int64)
        return (seg_vis, id_map)

    def fake_build_solve_options(env, planner, selected_target, env_id):
        return []

    monkeypatch.setattr(oracle_logic, "_fetch_segmentation", fake_fetch_segmentation)
    monkeypatch.setattr(
        oracle_logic,
        "_prepare_segmentation_visual",
        fake_prepare_segmentation_visual,
    )
    monkeypatch.setattr(oracle_logic, "_build_solve_options", fake_build_solve_options)

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    # Throwaway env class whose frame lists are both empty.
    env_attributes = {"unwrapped": _FakeUnwrapped(), "frames": [], "wrist_frames": []}
    no_frame_env_cls = type("_NoFrameEnv", (), env_attributes)
    session.env = no_frame_env_cls()
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, status = session.update_observation(use_segmentation=False)

    assert status == "Ready"
    assert len(session.base_frames) == 0

    rendered = session.get_pil_image(use_segmented=False)
    assert rendered.size == (255, 255)


def test_update_observation_uses_only_front_rgb_list(monkeypatch, reload_module):
    """Base frames come from ``_last_obs["front_rgb_list"]``.

    With two front frames and an empty wrist list, the session should end
    up with two base frames (the last one being the second front frame) and
    no wrist frames.
    """
    oracle_logic = reload_module("oracle_logic")

    def fake_fetch_segmentation(env):
        return np.zeros((1, 8, 8), dtype=np.int64)

    def fake_build_solve_options(env, planner, selected_target, env_id):
        return []

    monkeypatch.setattr(oracle_logic, "_fetch_segmentation", fake_fetch_segmentation)
    monkeypatch.setattr(oracle_logic, "_build_solve_options", fake_build_solve_options)

    # Two frames that can be told apart by their constant fill value.
    first_frame, second_frame = (
        np.full((8, 8, 3), fill, dtype=np.uint8) for fill in (11, 22)
    )

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = _FakeObsWrapperEnv(
        front_rgb_list=[first_frame, second_frame],
        wrist_rgb_list=[],
    )
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, status = session.update_observation(use_segmentation=False)

    assert status == "Ready"
    assert len(session.base_frames) == 2
    assert len(session.wrist_frames) == 0
    newest_frame = session.base_frames[-1]
    assert newest_frame[0, 0, 0] == 22


def test_update_observation_does_not_duplicate_same_last_obs(monkeypatch, reload_module):
    """Re-reading an unchanged ``_last_obs`` must not append frames again.

    Two consecutive updates against the same observation leave two base
    frames; once ``_last_obs`` is replaced with a new single-frame
    observation, the next update brings the total to three.
    """
    oracle_logic = reload_module("oracle_logic")

    def fake_fetch_segmentation(env):
        return np.zeros((1, 8, 8), dtype=np.int64)

    def fake_build_solve_options(env, planner, selected_target, env_id):
        return []

    monkeypatch.setattr(oracle_logic, "_fetch_segmentation", fake_fetch_segmentation)
    monkeypatch.setattr(oracle_logic, "_build_solve_options", fake_build_solve_options)

    initial_frames = [np.full((8, 8, 3), fill, dtype=np.uint8) for fill in (10, 20)]
    env = _FakeObsWrapperEnv(front_rgb_list=initial_frames, wrist_rgb_list=[])

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = env
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    # Same observation seen twice: the frame count must stay at two.
    for _ in range(2):
        session.update_observation(use_segmentation=False)
    assert len(session.base_frames) == 2

    # A new observation object with one new frame should add exactly one.
    new_frame = np.full((8, 8, 3), 30, dtype=np.uint8)
    env._last_obs = {"front_rgb_list": [new_frame], "wrist_rgb_list": []}
    session.update_observation(use_segmentation=False)
    assert len(session.base_frames) == 3
    newest_frame = session.base_frames[-1]
    assert newest_frame[0, 0, 0] == 30


def test_update_observation_does_not_fallback_to_env_frames(monkeypatch, reload_module):
    """``env.frames`` is not used as a source of base frames.

    The fake env has a frame in ``frames`` but no ``_last_obs``; after the
    update the session's ``base_frames`` must still be an empty list.
    """
    oracle_logic = reload_module("oracle_logic")

    def fake_fetch_segmentation(env):
        return np.zeros((1, 8, 8), dtype=np.int64)

    def fake_build_solve_options(env, planner, selected_target, env_id):
        return []

    monkeypatch.setattr(oracle_logic, "_fetch_segmentation", fake_fetch_segmentation)
    monkeypatch.setattr(oracle_logic, "_build_solve_options", fake_build_solve_options)

    # Replace the default frame with one that has a recognisable fill value.
    env = _FakeEnv()
    marker_frame = np.full((8, 8, 3), 99, dtype=np.uint8)
    env.frames = [marker_frame]

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = env
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, status = session.update_observation(use_segmentation=False)

    assert status == "Ready"
    assert session.base_frames == []