stopcube remain static vqa override
Browse files
gradio-web/oracle_logic.py
CHANGED
|
@@ -31,7 +31,7 @@ except Exception as e:
|
|
| 31 |
# --- Project Imports ---
|
| 32 |
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
| 33 |
from robomme.robomme_env import * # noqa: F401,F403; ensure gym envs are registered
|
| 34 |
-
from
|
| 35 |
from robomme.robomme_env.utils.oracle_action_matcher import (
|
| 36 |
find_exact_label_option_index,
|
| 37 |
map_action_text_to_option_label,
|
|
|
|
| 31 |
# --- Project Imports ---
|
| 32 |
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
| 33 |
from robomme.robomme_env import * # noqa: F401,F403; ensure gym envs are registered
|
| 34 |
+
from vqa_options_override import get_vqa_options
|
| 35 |
from robomme.robomme_env.utils.oracle_action_matcher import (
|
| 36 |
find_exact_label_option_index,
|
| 37 |
map_action_text_to_option_label,
|
gradio-web/test/test_stopcube_vqa_override.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from robomme.robomme_env.utils import vqa_options as upstream_vqa_options
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class _DummyBase:
|
| 7 |
+
def __init__(self, steps_press, interval=30):
|
| 8 |
+
self.steps_press = steps_press
|
| 9 |
+
self.interval = interval
|
| 10 |
+
self.button = object()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _DummyEnv:
|
| 14 |
+
def __init__(self, base, elapsed_steps=0):
|
| 15 |
+
self.unwrapped = base
|
| 16 |
+
self.elapsed_steps = elapsed_steps
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _get_stopcube_options(module, env):
|
| 20 |
+
return module.get_vqa_options(env, planner=None, selected_target={"obj": None}, env_id="StopCube")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _get_remain_static_solver(options):
|
| 24 |
+
for option in options:
|
| 25 |
+
if option.get("action") == "remain static":
|
| 26 |
+
return option["solve"]
|
| 27 |
+
raise AssertionError("Missing 'remain static' option")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_stopcube_remain_static_merges_short_tail(monkeypatch, reload_module):
|
| 31 |
+
override = reload_module("vqa_options_override")
|
| 32 |
+
|
| 33 |
+
hold_calls = []
|
| 34 |
+
|
| 35 |
+
def _hold_spy(env, planner, absTimestep):
|
| 36 |
+
_ = planner
|
| 37 |
+
hold_calls.append(int(absTimestep))
|
| 38 |
+
env.elapsed_steps = int(absTimestep)
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
monkeypatch.setattr(override, "solve_hold_obj_absTimestep", _hold_spy)
|
| 42 |
+
|
| 43 |
+
base = _DummyBase(steps_press=270, interval=30)
|
| 44 |
+
env = _DummyEnv(base, elapsed_steps=0)
|
| 45 |
+
options = _get_stopcube_options(override, env)
|
| 46 |
+
|
| 47 |
+
actions = [option.get("action") for option in options]
|
| 48 |
+
assert actions == [
|
| 49 |
+
"move to the top of the button to prepare",
|
| 50 |
+
"remain static",
|
| 51 |
+
"press button to stop the cube",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
solve_remain_static = _get_remain_static_solver(options)
|
| 55 |
+
for _ in range(3):
|
| 56 |
+
solve_remain_static()
|
| 57 |
+
|
| 58 |
+
assert hold_calls == [100, 240, 240]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_stopcube_remain_static_keeps_boundary_tail(monkeypatch, reload_module):
|
| 62 |
+
override = reload_module("vqa_options_override")
|
| 63 |
+
|
| 64 |
+
hold_calls = []
|
| 65 |
+
|
| 66 |
+
def _hold_spy(env, planner, absTimestep):
|
| 67 |
+
_ = planner
|
| 68 |
+
hold_calls.append(int(absTimestep))
|
| 69 |
+
env.elapsed_steps = int(absTimestep)
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
monkeypatch.setattr(override, "solve_hold_obj_absTimestep", _hold_spy)
|
| 73 |
+
|
| 74 |
+
base = _DummyBase(steps_press=280, interval=30)
|
| 75 |
+
env = _DummyEnv(base, elapsed_steps=0)
|
| 76 |
+
solve_remain_static = _get_remain_static_solver(_get_stopcube_options(override, env))
|
| 77 |
+
|
| 78 |
+
for _ in range(4):
|
| 79 |
+
solve_remain_static()
|
| 80 |
+
|
| 81 |
+
assert hold_calls == [100, 200, 250, 250]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_stopcube_remain_static_resets_after_elapsed_steps_go_back(monkeypatch, reload_module):
|
| 85 |
+
override = reload_module("vqa_options_override")
|
| 86 |
+
|
| 87 |
+
hold_calls = []
|
| 88 |
+
|
| 89 |
+
def _hold_spy(env, planner, absTimestep):
|
| 90 |
+
_ = planner
|
| 91 |
+
hold_calls.append(int(absTimestep))
|
| 92 |
+
env.elapsed_steps = int(absTimestep)
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
monkeypatch.setattr(override, "solve_hold_obj_absTimestep", _hold_spy)
|
| 96 |
+
|
| 97 |
+
base = _DummyBase(steps_press=270, interval=30)
|
| 98 |
+
env = _DummyEnv(base, elapsed_steps=0)
|
| 99 |
+
solve_remain_static = _get_remain_static_solver(_get_stopcube_options(override, env))
|
| 100 |
+
|
| 101 |
+
solve_remain_static()
|
| 102 |
+
solve_remain_static()
|
| 103 |
+
env.elapsed_steps = 0
|
| 104 |
+
solve_remain_static()
|
| 105 |
+
|
| 106 |
+
assert hold_calls == [100, 240, 100]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def test_non_stopcube_builders_passthrough_to_upstream(reload_module):
|
| 110 |
+
override = reload_module("vqa_options_override")
|
| 111 |
+
|
| 112 |
+
assert override.OPTION_BUILDERS["StopCube"] is override._options_stopcube_override
|
| 113 |
+
assert override.OPTION_BUILDERS["StopCube"] is not upstream_vqa_options.OPTION_BUILDERS["StopCube"]
|
| 114 |
+
assert override.OPTION_BUILDERS["BinFill"] is upstream_vqa_options.OPTION_BUILDERS["BinFill"]
|
gradio-web/test/test_ui_phase_machine_runtime_e2e.py
CHANGED
|
@@ -2940,3 +2940,297 @@ def test_phase_machine_runtime_local_video_path_end_transition_terminal_failed()
|
|
| 2940 |
expect_terminal_buttons_disabled=True,
|
| 2941 |
expected_terminal_log="episode failed",
|
| 2942 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2940 |
expect_terminal_buttons_disabled=True,
|
| 2941 |
expected_terminal_log="episode failed",
|
| 2942 |
)
|
| 2943 |
+
|
| 2944 |
+
|
| 2945 |
+
def test_phase_machine_runtime_stopcube_remain_static_merges_short_tail(monkeypatch):
|
| 2946 |
+
import gradio_callbacks as cb
|
| 2947 |
+
import config as config_module
|
| 2948 |
+
import vqa_options_override as override
|
| 2949 |
+
|
| 2950 |
+
demo_video_path = gr.get_video("world.mp4")
|
| 2951 |
+
fake_obs = np.zeros((24, 24, 3), dtype=np.uint8)
|
| 2952 |
+
hold_calls = []
|
| 2953 |
+
|
| 2954 |
+
def _hold_spy(env, planner, absTimestep):
|
| 2955 |
+
_ = planner
|
| 2956 |
+
hold_calls.append(int(absTimestep))
|
| 2957 |
+
env.elapsed_steps = int(absTimestep)
|
| 2958 |
+
return None
|
| 2959 |
+
|
| 2960 |
+
monkeypatch.setattr(override, "solve_hold_obj_absTimestep", _hold_spy)
|
| 2961 |
+
|
| 2962 |
+
class FakeBase:
|
| 2963 |
+
def __init__(self):
|
| 2964 |
+
self.steps_press = 270
|
| 2965 |
+
self.interval = 30
|
| 2966 |
+
self.button = object()
|
| 2967 |
+
|
| 2968 |
+
class FakeEnv:
|
| 2969 |
+
def __init__(self):
|
| 2970 |
+
self.unwrapped = FakeBase()
|
| 2971 |
+
self.elapsed_steps = 0
|
| 2972 |
+
|
| 2973 |
+
class FakeSession:
|
| 2974 |
+
def __init__(self):
|
| 2975 |
+
self.env_id = "StopCube"
|
| 2976 |
+
self.episode_idx = 1
|
| 2977 |
+
self.language_goal = "stop the moving cube"
|
| 2978 |
+
self.difficulty = "easy"
|
| 2979 |
+
self.seed = 123
|
| 2980 |
+
self.non_demonstration_task_length = None
|
| 2981 |
+
self.demonstration_frames = []
|
| 2982 |
+
self.last_execution_frames = []
|
| 2983 |
+
self.base_frames = [fake_obs.copy()]
|
| 2984 |
+
self.env = FakeEnv()
|
| 2985 |
+
self.planner = object()
|
| 2986 |
+
self.raw_solve_options = override.get_vqa_options(
|
| 2987 |
+
self.env,
|
| 2988 |
+
self.planner,
|
| 2989 |
+
{"obj": None},
|
| 2990 |
+
self.env_id,
|
| 2991 |
+
)
|
| 2992 |
+
self.available_options = [
|
| 2993 |
+
(f"{opt['label']}. {opt['action']}", idx)
|
| 2994 |
+
for idx, opt in enumerate(self.raw_solve_options)
|
| 2995 |
+
]
|
| 2996 |
+
|
| 2997 |
+
def get_pil_image(self, use_segmented=False):
|
| 2998 |
+
_ = use_segmented
|
| 2999 |
+
return fake_obs.copy()
|
| 3000 |
+
|
| 3001 |
+
def update_observation(self, use_segmentation=False):
|
| 3002 |
+
_ = use_segmentation
|
| 3003 |
+
return None
|
| 3004 |
+
|
| 3005 |
+
def execute_action(self, option_idx, click_coords):
|
| 3006 |
+
_ = click_coords
|
| 3007 |
+
current_options = override.get_vqa_options(
|
| 3008 |
+
self.env,
|
| 3009 |
+
self.planner,
|
| 3010 |
+
{"obj": None},
|
| 3011 |
+
self.env_id,
|
| 3012 |
+
)
|
| 3013 |
+
current_options[option_idx]["solve"]()
|
| 3014 |
+
|
| 3015 |
+
frame_value = hold_calls[-1] if hold_calls else 0
|
| 3016 |
+
frame = np.full((24, 24, 3), frame_value, dtype=np.uint8)
|
| 3017 |
+
self.last_execution_frames = [frame.copy(), frame.copy()]
|
| 3018 |
+
self.base_frames.extend(self.last_execution_frames)
|
| 3019 |
+
return frame.copy(), f"Executing: {current_options[option_idx]['label']}", False
|
| 3020 |
+
|
| 3021 |
+
fake_session = FakeSession()
|
| 3022 |
+
monkeypatch.setattr(cb, "get_session", lambda uid: fake_session)
|
| 3023 |
+
monkeypatch.setattr(cb, "increment_execute_count", lambda uid, env_id, ep_num: 1)
|
| 3024 |
+
monkeypatch.setattr(cb, "save_video", lambda frames, suffix="": demo_video_path)
|
| 3025 |
+
monkeypatch.setattr(cb, "concatenate_frames_horizontally", lambda frames, env_id=None: list(frames))
|
| 3026 |
+
monkeypatch.setattr(cb.os.path, "exists", lambda path: True)
|
| 3027 |
+
monkeypatch.setattr(cb.os.path, "getsize", lambda path: 10)
|
| 3028 |
+
|
| 3029 |
+
with gr.Blocks(title="Native StopCube merge test") as demo:
|
| 3030 |
+
uid_state = gr.State(value="uid-stopcube-merge")
|
| 3031 |
+
phase_state = gr.State(value="action_point")
|
| 3032 |
+
post_execute_controls_state = gr.State(
|
| 3033 |
+
value={
|
| 3034 |
+
"exec_btn_interactive": True,
|
| 3035 |
+
"reference_action_interactive": True,
|
| 3036 |
+
}
|
| 3037 |
+
)
|
| 3038 |
+
post_execute_log_state = gr.State(
|
| 3039 |
+
value={
|
| 3040 |
+
"preserve_terminal_log": False,
|
| 3041 |
+
"terminal_log_value": None,
|
| 3042 |
+
}
|
| 3043 |
+
)
|
| 3044 |
+
suppress_state = gr.State(value=False)
|
| 3045 |
+
with gr.Column(visible=True, elem_id="main_interface") as main_interface:
|
| 3046 |
+
with gr.Column(visible=False, elem_id="video_phase_group") as video_phase_group:
|
| 3047 |
+
video_display = gr.Video(value=None, elem_id="demo_video", autoplay=False)
|
| 3048 |
+
watch_demo_video_btn = gr.Button(
|
| 3049 |
+
"Watch Video Input🎬",
|
| 3050 |
+
elem_id="watch_demo_video_btn",
|
| 3051 |
+
interactive=False,
|
| 3052 |
+
visible=False,
|
| 3053 |
+
)
|
| 3054 |
+
|
| 3055 |
+
with gr.Column(visible=False, elem_id="execution_video_group") as execution_video_group:
|
| 3056 |
+
execute_video_display = gr.Video(value=None, elem_id="execute_video", autoplay=True)
|
| 3057 |
+
|
| 3058 |
+
with gr.Column(visible=True, elem_id="action_phase_group") as action_phase_group:
|
| 3059 |
+
img_display = gr.Image(value=fake_obs.copy(), elem_id="live_obs")
|
| 3060 |
+
|
| 3061 |
+
with gr.Column(visible=True, elem_id="control_panel_group") as control_panel_group:
|
| 3062 |
+
options_radio = gr.Radio(
|
| 3063 |
+
choices=fake_session.available_options,
|
| 3064 |
+
value=None,
|
| 3065 |
+
elem_id="action_radio",
|
| 3066 |
+
)
|
| 3067 |
+
coords_box = gr.Textbox(config_module.UI_TEXT["coords"]["not_needed"], elem_id="coords_box")
|
| 3068 |
+
exec_btn = gr.Button("execute", interactive=True, elem_id="exec_btn")
|
| 3069 |
+
reference_action_btn = gr.Button("reference", interactive=True, elem_id="reference_action_btn")
|
| 3070 |
+
restart_episode_btn = gr.Button("restart", interactive=True, elem_id="restart_episode_btn")
|
| 3071 |
+
next_task_btn = gr.Button("next", interactive=True, elem_id="next_task_btn")
|
| 3072 |
+
task_hint_display = gr.Textbox("hint", interactive=True, elem_id="task_hint_display")
|
| 3073 |
+
|
| 3074 |
+
log_output = gr.Markdown("", elem_id="log_output")
|
| 3075 |
+
task_info_box = gr.Textbox("")
|
| 3076 |
+
progress_info_box = gr.Textbox("")
|
| 3077 |
+
|
| 3078 |
+
exec_btn.click(
|
| 3079 |
+
fn=cb.precheck_execute_inputs,
|
| 3080 |
+
inputs=[uid_state, options_radio, coords_box],
|
| 3081 |
+
outputs=[],
|
| 3082 |
+
queue=False,
|
| 3083 |
+
).then(
|
| 3084 |
+
fn=cb.switch_to_execute_phase,
|
| 3085 |
+
inputs=[uid_state],
|
| 3086 |
+
outputs=[
|
| 3087 |
+
options_radio,
|
| 3088 |
+
exec_btn,
|
| 3089 |
+
restart_episode_btn,
|
| 3090 |
+
next_task_btn,
|
| 3091 |
+
img_display,
|
| 3092 |
+
reference_action_btn,
|
| 3093 |
+
task_hint_display,
|
| 3094 |
+
],
|
| 3095 |
+
queue=False,
|
| 3096 |
+
).then(
|
| 3097 |
+
fn=cb.execute_step,
|
| 3098 |
+
inputs=[uid_state, options_radio, coords_box],
|
| 3099 |
+
outputs=[
|
| 3100 |
+
img_display,
|
| 3101 |
+
log_output,
|
| 3102 |
+
task_info_box,
|
| 3103 |
+
progress_info_box,
|
| 3104 |
+
restart_episode_btn,
|
| 3105 |
+
next_task_btn,
|
| 3106 |
+
exec_btn,
|
| 3107 |
+
execute_video_display,
|
| 3108 |
+
action_phase_group,
|
| 3109 |
+
control_panel_group,
|
| 3110 |
+
execution_video_group,
|
| 3111 |
+
options_radio,
|
| 3112 |
+
coords_box,
|
| 3113 |
+
reference_action_btn,
|
| 3114 |
+
task_hint_display,
|
| 3115 |
+
post_execute_controls_state,
|
| 3116 |
+
post_execute_log_state,
|
| 3117 |
+
phase_state,
|
| 3118 |
+
],
|
| 3119 |
+
queue=False,
|
| 3120 |
+
)
|
| 3121 |
+
options_radio.change(
|
| 3122 |
+
fn=cb.on_option_select,
|
| 3123 |
+
inputs=[uid_state, options_radio, coords_box, suppress_state, post_execute_log_state],
|
| 3124 |
+
outputs=[coords_box, img_display, log_output, suppress_state, post_execute_log_state],
|
| 3125 |
+
queue=False,
|
| 3126 |
+
)
|
| 3127 |
+
|
| 3128 |
+
execute_video_display.end(
|
| 3129 |
+
fn=cb.on_execute_video_end_transition,
|
| 3130 |
+
inputs=[uid_state, post_execute_controls_state, post_execute_log_state],
|
| 3131 |
+
outputs=[
|
| 3132 |
+
execution_video_group,
|
| 3133 |
+
action_phase_group,
|
| 3134 |
+
control_panel_group,
|
| 3135 |
+
options_radio,
|
| 3136 |
+
exec_btn,
|
| 3137 |
+
restart_episode_btn,
|
| 3138 |
+
next_task_btn,
|
| 3139 |
+
img_display,
|
| 3140 |
+
log_output,
|
| 3141 |
+
reference_action_btn,
|
| 3142 |
+
task_hint_display,
|
| 3143 |
+
phase_state,
|
| 3144 |
+
],
|
| 3145 |
+
queue=False,
|
| 3146 |
+
)
|
| 3147 |
+
execute_video_display.stop(
|
| 3148 |
+
fn=cb.on_execute_video_end_transition,
|
| 3149 |
+
inputs=[uid_state, post_execute_controls_state, post_execute_log_state],
|
| 3150 |
+
outputs=[
|
| 3151 |
+
execution_video_group,
|
| 3152 |
+
action_phase_group,
|
| 3153 |
+
control_panel_group,
|
| 3154 |
+
options_radio,
|
| 3155 |
+
exec_btn,
|
| 3156 |
+
restart_episode_btn,
|
| 3157 |
+
next_task_btn,
|
| 3158 |
+
img_display,
|
| 3159 |
+
log_output,
|
| 3160 |
+
reference_action_btn,
|
| 3161 |
+
task_hint_display,
|
| 3162 |
+
phase_state,
|
| 3163 |
+
],
|
| 3164 |
+
queue=False,
|
| 3165 |
+
)
|
| 3166 |
+
|
| 3167 |
+
port = _free_port()
|
| 3168 |
+
host = "127.0.0.1"
|
| 3169 |
+
root_url = f"http://{host}:{port}/"
|
| 3170 |
+
|
| 3171 |
+
app = FastAPI(title="native-stopcube-merge-test")
|
| 3172 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
| 3173 |
+
|
| 3174 |
+
config = uvicorn.Config(app, host=host, port=port, log_level="error")
|
| 3175 |
+
server = uvicorn.Server(config)
|
| 3176 |
+
thread = threading.Thread(target=server.run, daemon=True)
|
| 3177 |
+
thread.start()
|
| 3178 |
+
_wait_http_ready(root_url)
|
| 3179 |
+
|
| 3180 |
+
try:
|
| 3181 |
+
with sync_playwright() as p:
|
| 3182 |
+
browser = p.chromium.launch(headless=True)
|
| 3183 |
+
page = browser.new_page(viewport={"width": 1280, "height": 900})
|
| 3184 |
+
page.goto(root_url, wait_until="domcontentloaded")
|
| 3185 |
+
page.wait_for_selector("#main_interface", state="visible", timeout=20000)
|
| 3186 |
+
|
| 3187 |
+
for _ in range(2):
|
| 3188 |
+
page.locator("#action_radio input[type='radio']").nth(1).check(force=True)
|
| 3189 |
+
page.locator("#exec_btn button, button#exec_btn").first.click()
|
| 3190 |
+
page.wait_for_selector("#execute_video video", timeout=5000)
|
| 3191 |
+
page.wait_for_function(
|
| 3192 |
+
"""() => {
|
| 3193 |
+
const visible = (id) => {
|
| 3194 |
+
const el = document.getElementById(id);
|
| 3195 |
+
if (!el) return false;
|
| 3196 |
+
const st = getComputedStyle(el);
|
| 3197 |
+
return st.display !== 'none' && st.visibility !== 'hidden' && el.getClientRects().length > 0;
|
| 3198 |
+
};
|
| 3199 |
+
const videoEl = document.querySelector('#execute_video video');
|
| 3200 |
+
return (
|
| 3201 |
+
visible('execution_video_group') &&
|
| 3202 |
+
visible('execute_video') &&
|
| 3203 |
+
!visible('action_phase_group') &&
|
| 3204 |
+
!!videoEl &&
|
| 3205 |
+
videoEl.autoplay === true
|
| 3206 |
+
);
|
| 3207 |
+
}""",
|
| 3208 |
+
timeout=10000,
|
| 3209 |
+
)
|
| 3210 |
+
assert _dispatch_video_event(page, "ended", elem_id="execute_video")
|
| 3211 |
+
page.wait_for_function(
|
| 3212 |
+
"""() => {
|
| 3213 |
+
const visible = (id) => {
|
| 3214 |
+
const el = document.getElementById(id);
|
| 3215 |
+
if (!el) return false;
|
| 3216 |
+
const st = getComputedStyle(el);
|
| 3217 |
+
return st.display !== 'none' && st.visibility !== 'hidden' && el.getClientRects().length > 0;
|
| 3218 |
+
};
|
| 3219 |
+
const execBtn = document.querySelector('#exec_btn button') || document.querySelector('button#exec_btn');
|
| 3220 |
+
return (
|
| 3221 |
+
visible('action_phase_group') &&
|
| 3222 |
+
visible('control_panel_group') &&
|
| 3223 |
+
!visible('execute_video') &&
|
| 3224 |
+
!!execBtn &&
|
| 3225 |
+
execBtn.disabled === false
|
| 3226 |
+
);
|
| 3227 |
+
}""",
|
| 3228 |
+
timeout=5000,
|
| 3229 |
+
)
|
| 3230 |
+
|
| 3231 |
+
assert hold_calls == [100, 240]
|
| 3232 |
+
browser.close()
|
| 3233 |
+
finally:
|
| 3234 |
+
server.should_exit = True
|
| 3235 |
+
thread.join(timeout=10)
|
| 3236 |
+
demo.close()
|
gradio-web/vqa_options_override.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Callable, Dict, List
|
| 4 |
+
|
| 5 |
+
from robomme.robomme_env.utils import vqa_options as upstream_vqa_options
|
| 6 |
+
|
| 7 |
+
solve_button = upstream_vqa_options.solve_button
|
| 8 |
+
solve_button_ready = upstream_vqa_options.solve_button_ready
|
| 9 |
+
solve_hold_obj_absTimestep = upstream_vqa_options.solve_hold_obj_absTimestep
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _build_stopcube_static_checkpoints(final_target: int) -> List[int]:
|
| 13 |
+
checkpoints = list(range(100, final_target, 100))
|
| 14 |
+
if not checkpoints or checkpoints[-1] != final_target:
|
| 15 |
+
checkpoints.append(final_target)
|
| 16 |
+
|
| 17 |
+
if len(checkpoints) >= 2 and checkpoints[-1] - checkpoints[-2] < 50:
|
| 18 |
+
del checkpoints[-2]
|
| 19 |
+
|
| 20 |
+
return checkpoints
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _options_stopcube_override(env, planner, require_target, base) -> List[dict]:
|
| 24 |
+
_ = require_target
|
| 25 |
+
options: List[dict] = []
|
| 26 |
+
button_obj = getattr(base, "button", None)
|
| 27 |
+
|
| 28 |
+
if button_obj is not None:
|
| 29 |
+
options.append(
|
| 30 |
+
{
|
| 31 |
+
"label": "a",
|
| 32 |
+
"action": "move to the top of the button to prepare",
|
| 33 |
+
"solve": lambda button_obj=button_obj: solve_button_ready(
|
| 34 |
+
env, planner, obj=button_obj
|
| 35 |
+
),
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
steps_press = getattr(base, "steps_press", None)
|
| 40 |
+
if steps_press is not None:
|
| 41 |
+
|
| 42 |
+
def solve_with_incremental_steps():
|
| 43 |
+
steps_press_value = getattr(base, "steps_press", None)
|
| 44 |
+
if steps_press_value is None:
|
| 45 |
+
return None
|
| 46 |
+
|
| 47 |
+
interval = getattr(base, "interval", 30)
|
| 48 |
+
final_target = max(0, int(steps_press_value - interval))
|
| 49 |
+
current_step = int(getattr(env, "elapsed_steps", 0))
|
| 50 |
+
|
| 51 |
+
checkpoints_key = "_stopcube_static_checkpoints"
|
| 52 |
+
index_key = "_stopcube_static_index"
|
| 53 |
+
cached_final_target_key = "_stopcube_static_final_target"
|
| 54 |
+
last_elapsed_step_key = "_stopcube_static_last_elapsed_step"
|
| 55 |
+
|
| 56 |
+
checkpoints = getattr(base, checkpoints_key, None)
|
| 57 |
+
index = getattr(base, index_key, None)
|
| 58 |
+
cached_final_target = getattr(base, cached_final_target_key, None)
|
| 59 |
+
last_elapsed_step = getattr(base, last_elapsed_step_key, None)
|
| 60 |
+
|
| 61 |
+
needs_rebuild = (
|
| 62 |
+
not isinstance(checkpoints, list)
|
| 63 |
+
or len(checkpoints) == 0
|
| 64 |
+
or index is None
|
| 65 |
+
or cached_final_target is None
|
| 66 |
+
or int(cached_final_target) != final_target
|
| 67 |
+
or (
|
| 68 |
+
last_elapsed_step is not None
|
| 69 |
+
and current_step < int(last_elapsed_step)
|
| 70 |
+
)
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
if needs_rebuild:
|
| 74 |
+
checkpoints = _build_stopcube_static_checkpoints(final_target)
|
| 75 |
+
index = 0
|
| 76 |
+
else:
|
| 77 |
+
index = int(index)
|
| 78 |
+
if index < 0:
|
| 79 |
+
index = 0
|
| 80 |
+
if index >= len(checkpoints):
|
| 81 |
+
index = len(checkpoints) - 1
|
| 82 |
+
|
| 83 |
+
target = checkpoints[index]
|
| 84 |
+
solve_hold_obj_absTimestep(env, planner, absTimestep=target)
|
| 85 |
+
|
| 86 |
+
index += 1
|
| 87 |
+
|
| 88 |
+
setattr(base, checkpoints_key, checkpoints)
|
| 89 |
+
setattr(base, index_key, index)
|
| 90 |
+
setattr(base, cached_final_target_key, final_target)
|
| 91 |
+
setattr(base, last_elapsed_step_key, current_step)
|
| 92 |
+
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
options.append(
|
| 96 |
+
{
|
| 97 |
+
"label": "b",
|
| 98 |
+
"action": "remain static",
|
| 99 |
+
"solve": solve_with_incremental_steps,
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if button_obj is not None:
|
| 104 |
+
options.append(
|
| 105 |
+
{
|
| 106 |
+
"label": "c",
|
| 107 |
+
"action": "press button to stop the cube",
|
| 108 |
+
"solve": lambda button_obj=button_obj: solve_button(
|
| 109 |
+
env, planner, obj=button_obj, without_hold=True
|
| 110 |
+
),
|
| 111 |
+
}
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
return options
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
OPTION_BUILDERS: Dict[str, Callable] = dict(upstream_vqa_options.OPTION_BUILDERS)
|
| 118 |
+
OPTION_BUILDERS["StopCube"] = _options_stopcube_override
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_vqa_options(env, planner, selected_target, env_id: str) -> List[dict]:
|
| 122 |
+
"""Return Gradio-specific solve options without mutating the upstream src module."""
|
| 123 |
+
|
| 124 |
+
def _require_target():
|
| 125 |
+
obj = selected_target.get("obj")
|
| 126 |
+
if obj is None:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
"No available target cube found, please click target in segmentation map first."
|
| 129 |
+
)
|
| 130 |
+
return obj
|
| 131 |
+
|
| 132 |
+
base = env.unwrapped
|
| 133 |
+
builder = OPTION_BUILDERS.get(
|
| 134 |
+
env_id, getattr(upstream_vqa_options, "_options_default")
|
| 135 |
+
)
|
| 136 |
+
return builder(env, planner, _require_target, base)
|
| 137 |
+
|