routestick patternlock text update
Browse files- gradio-web/config.py +43 -1
- gradio-web/gradio_callbacks.py +30 -11
- gradio-web/image_utils.py +6 -5
- gradio-web/main.py +14 -0
- gradio-web/test/test_reference_action_callbacks.py +31 -1
- gradio-web/test/test_ui_text_config.py +127 -2
gradio-web/config.py
CHANGED
|
@@ -37,7 +37,7 @@ DEMO_VIDEO_ENV_IDS = [
|
|
| 37 |
|
| 38 |
UI_TEXT = {
|
| 39 |
"log": {
|
| 40 |
-
"action_selection_prompt": "please select the action
|
| 41 |
"demo_video_prompt": 'press "Watch Video Input🎬" to watch a video\nNote: you can only watch the video once',
|
| 42 |
"session_error": "Session Error",
|
| 43 |
"reference_action_error": "Ground Truth Action Error: {error}",
|
|
@@ -53,6 +53,9 @@ UI_TEXT = {
|
|
| 53 |
"select_keypoint": "please click the keypoint selection image",
|
| 54 |
"select_keypoint_before_execute": "please click the keypoint selection image before execute!",
|
| 55 |
},
|
|
|
|
|
|
|
|
|
|
| 56 |
"errors": {
|
| 57 |
"load_missing_task": "Error loading task: missing current_task",
|
| 58 |
"load_invalid_task": "Error loading task: invalid task payload",
|
|
@@ -66,6 +69,45 @@ UI_TEXT = {
|
|
| 66 |
},
|
| 67 |
}
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def should_show_demo_video(env_id):
|
| 70 |
"""
|
| 71 |
判断指定的环境ID是否应该显示demonstration video
|
|
|
|
| 37 |
|
| 38 |
UI_TEXT = {
|
| 39 |
"log": {
|
| 40 |
+
"action_selection_prompt": "please select the action in the right 👈,\nsome actions also need to select keypoint",
|
| 41 |
"demo_video_prompt": 'press "Watch Video Input🎬" to watch a video\nNote: you can only watch the video once',
|
| 42 |
"session_error": "Session Error",
|
| 43 |
"reference_action_error": "Ground Truth Action Error: {error}",
|
|
|
|
| 53 |
"select_keypoint": "please click the keypoint selection image",
|
| 54 |
"select_keypoint_before_execute": "please click the keypoint selection image before execute!",
|
| 55 |
},
|
| 56 |
+
"actions": {
|
| 57 |
+
"keypoint_required_suffix": " (click mouse 🖱️ to select 🎯)",
|
| 58 |
+
},
|
| 59 |
"errors": {
|
| 60 |
"load_missing_task": "Error loading task: missing current_task",
|
| 61 |
"load_invalid_task": "Error loading task: invalid task payload",
|
|
|
|
| 69 |
},
|
| 70 |
}
|
| 71 |
|
| 72 |
+
UI_ACTION_TEXT_OVERRIDES = {
|
| 73 |
+
"PatternLock": {
|
| 74 |
+
"move forward": "move forward↓",
|
| 75 |
+
"move backward": "move backward↑",
|
| 76 |
+
"move left": "move left→",
|
| 77 |
+
"move right": "move right←",
|
| 78 |
+
"move forward-left": "move forward-left↘︎",
|
| 79 |
+
"move forward-right": "move forward-right↙︎",
|
| 80 |
+
"move backward-left": "move backward-left↗︎",
|
| 81 |
+
"move backward-right": "move backward-right↖︎",
|
| 82 |
+
},
|
| 83 |
+
"RouteStick": {
|
| 84 |
+
"move to the nearest left target by circling around the stick clockwise": "move left clockwise↘︎→↗︎ ◟→◞",
|
| 85 |
+
"move to the nearest right target by circling around the stick clockwise": "move right clockwise↖︎←↙︎ ◟←◞",
|
| 86 |
+
"move to the nearest left target by circling around the stick counterclockwise": "move left counterclockwise↗︎→↘︎ ◜→◝",
|
| 87 |
+
"move to the nearest right target by circling around the stick counterclockwise": "move right counterclockwise↙︎←↖︎ ◜←◝",
|
| 88 |
+
},
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
ROUTESTICK_OVERLAY_ACTION_TEXTS = [
|
| 92 |
+
"move to the nearest left target by circling around the stick clockwise",
|
| 93 |
+
"move to the nearest left target by circling around the stick counterclockwise",
|
| 94 |
+
"move to the nearest right target by circling around the stick clockwise",
|
| 95 |
+
"move to the nearest right target by circling around the stick counterclockwise",
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_ui_action_text(env_id, action_text):
|
| 100 |
+
"""
|
| 101 |
+
Return display-only action text overrides for a specific env/action pair.
|
| 102 |
+
Falls back to the original action text when no override is configured.
|
| 103 |
+
"""
|
| 104 |
+
if not isinstance(action_text, str):
|
| 105 |
+
return action_text
|
| 106 |
+
if not isinstance(env_id, str) or not env_id:
|
| 107 |
+
return action_text
|
| 108 |
+
env_overrides = UI_ACTION_TEXT_OVERRIDES.get(env_id, {})
|
| 109 |
+
return env_overrides.get(action_text, action_text)
|
| 110 |
+
|
| 111 |
def should_show_demo_video(env_id):
|
| 112 |
"""
|
| 113 |
判断指定的环境ID是否应该显示demonstration video
|
gradio-web/gradio_callbacks.py
CHANGED
|
@@ -38,6 +38,7 @@ from config import (
|
|
| 38 |
SESSION_TIMEOUT,
|
| 39 |
UI_TEXT,
|
| 40 |
USE_SEGMENTED_VIEW,
|
|
|
|
| 41 |
should_show_demo_video,
|
| 42 |
)
|
| 43 |
from process_session import ScrewPlanFailureError, ProcessSessionProxy
|
|
@@ -105,17 +106,34 @@ def get_videoplacebutton_goal(original_goal: str) -> str:
|
|
| 105 |
def _ui_option_label(session, opt_label: str, opt_idx: int) -> str:
|
| 106 |
"""
|
| 107 |
仅在 Gradio UI 层对选项显示文案做覆盖(不改底层 env/options 生成逻辑)。
|
| 108 |
-
|
|
|
|
| 109 |
"""
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return opt_label
|
| 120 |
|
| 121 |
|
|
@@ -508,7 +526,7 @@ def _load_status_task(uid, status):
|
|
| 508 |
if 0 <= opt_idx < len(session.raw_solve_options):
|
| 509 |
opt = session.raw_solve_options[opt_idx]
|
| 510 |
if opt.get("available"):
|
| 511 |
-
opt_label_with_hint = f"{opt_label}
|
| 512 |
else:
|
| 513 |
opt_label_with_hint = opt_label
|
| 514 |
else:
|
|
@@ -860,6 +878,7 @@ def on_reference_action(uid):
|
|
| 860 |
option_idx = reference.get("option_idx")
|
| 861 |
option_label = str(reference.get("option_label", "")).strip()
|
| 862 |
option_action = str(reference.get("option_action", "")).strip()
|
|
|
|
| 863 |
need_coords = bool(reference.get("need_coords", False))
|
| 864 |
coords_xy = reference.get("coords_xy")
|
| 865 |
|
|
|
|
| 38 |
SESSION_TIMEOUT,
|
| 39 |
UI_TEXT,
|
| 40 |
USE_SEGMENTED_VIEW,
|
| 41 |
+
get_ui_action_text,
|
| 42 |
should_show_demo_video,
|
| 43 |
)
|
| 44 |
from process_session import ScrewPlanFailureError, ProcessSessionProxy
|
|
|
|
| 106 |
def _ui_option_label(session, opt_label: str, opt_idx: int) -> str:
|
| 107 |
"""
|
| 108 |
仅在 Gradio UI 层对选项显示文案做覆盖(不改底层 env/options 生成逻辑)。
|
| 109 |
+
优先使用 raw_solve_options 中的原始 label/action 重新组装显示文本,
|
| 110 |
+
并按 env_id 做 display-only action 文案映射。
|
| 111 |
"""
|
| 112 |
+
try:
|
| 113 |
+
option_index = int(opt_idx)
|
| 114 |
+
except (TypeError, ValueError):
|
| 115 |
+
return opt_label
|
| 116 |
+
|
| 117 |
+
raw_solve_options = getattr(session, "raw_solve_options", None)
|
| 118 |
+
if not isinstance(raw_solve_options, list):
|
| 119 |
+
return opt_label
|
| 120 |
+
if not (0 <= option_index < len(raw_solve_options)):
|
| 121 |
+
return opt_label
|
| 122 |
+
|
| 123 |
+
raw_option = raw_solve_options[option_index]
|
| 124 |
+
if not isinstance(raw_option, dict):
|
| 125 |
+
return opt_label
|
| 126 |
+
|
| 127 |
+
raw_label = str(raw_option.get("label", "")).strip()
|
| 128 |
+
raw_action = str(raw_option.get("action", "")).strip()
|
| 129 |
+
mapped_action = get_ui_action_text(getattr(session, "env_id", None), raw_action)
|
| 130 |
+
|
| 131 |
+
if raw_label and mapped_action:
|
| 132 |
+
return f"{raw_label}. {mapped_action}"
|
| 133 |
+
if mapped_action:
|
| 134 |
+
return mapped_action
|
| 135 |
+
if raw_label:
|
| 136 |
+
return raw_label
|
| 137 |
return opt_label
|
| 138 |
|
| 139 |
|
|
|
|
| 526 |
if 0 <= opt_idx < len(session.raw_solve_options):
|
| 527 |
opt = session.raw_solve_options[opt_idx]
|
| 528 |
if opt.get("available"):
|
| 529 |
+
opt_label_with_hint = f"{opt_label}{_ui_text('actions', 'keypoint_required_suffix')}"
|
| 530 |
else:
|
| 531 |
opt_label_with_hint = opt_label
|
| 532 |
else:
|
|
|
|
| 878 |
option_idx = reference.get("option_idx")
|
| 879 |
option_label = str(reference.get("option_label", "")).strip()
|
| 880 |
option_action = str(reference.get("option_action", "")).strip()
|
| 881 |
+
option_action = get_ui_action_text(getattr(session, "env_id", None), option_action)
|
| 882 |
need_coords = bool(reference.get("need_coords", False))
|
| 883 |
coords_xy = reference.get("coords_xy")
|
| 884 |
|
gradio-web/image_utils.py
CHANGED
|
@@ -10,7 +10,7 @@ import math
|
|
| 10 |
from pathlib import Path
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
import cv2
|
| 13 |
-
from config import VIDEO_PLAYBACK_FPS
|
| 14 |
|
| 15 |
# DEPRECATED: 历史任务特化图像叠加配置,保留仅为兼容旧代码路径。
|
| 16 |
# 当前已统一关闭任务特化渲染。
|
|
@@ -354,6 +354,11 @@ def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
|
|
| 354 |
|
| 355 |
# 如果是 RouteStick 任务,绘制旋转方向示意图(左侧或右侧)
|
| 356 |
if env_id == "RouteStick" and (position == "right" or position == "left"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
# 绘制四个半圆箭头示意图(垂直排列)
|
| 358 |
# 示意图位置:在图像的左侧或右侧,从上到下垂直排列
|
| 359 |
illustration_width = 220 # 示意图区域宽度(已弃用,保留以保持兼容性)
|
|
@@ -401,7 +406,6 @@ def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
|
|
| 401 |
draw_semicircle(draw, (lcw_center_x , lcw_center_y+15), semicircle_radius, line_color, line_width, half="upper", start_pos="left", end_pos="right", arrow_position="end", arrow_size=arrow_size)
|
| 402 |
|
| 403 |
# 添加标签 "L CW"
|
| 404 |
-
lcw_text = "Left Clockwise"
|
| 405 |
lcw_bbox = draw.textbbox((0, 0), lcw_text, font=small_font)
|
| 406 |
lcw_text_width = lcw_bbox[2] - lcw_bbox[0]
|
| 407 |
lcw_text_height = lcw_bbox[3] - lcw_bbox[1]
|
|
@@ -418,7 +422,6 @@ def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
|
|
| 418 |
draw_semicircle(draw, (lccw_center_x, lccw_center_y), semicircle_radius, line_color, line_width, half="lower", start_pos="left", end_pos="right", arrow_position="end", arrow_size=arrow_size)
|
| 419 |
|
| 420 |
# 添加标签 "L CCW"
|
| 421 |
-
lccw_text = "Left Counterclockwise"
|
| 422 |
lccw_bbox = draw.textbbox((0, 0), lccw_text, font=small_font)
|
| 423 |
lccw_text_width = lccw_bbox[2] - lccw_bbox[0]
|
| 424 |
lccw_text_height = lccw_bbox[3] - lccw_bbox[1]
|
|
@@ -435,7 +438,6 @@ def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
|
|
| 435 |
draw_semicircle(draw, (rcw_center_x , rcw_center_y), semicircle_radius, line_color, line_width, half="lower", start_pos="right", end_pos="left", arrow_position="end", arrow_size=arrow_size)
|
| 436 |
|
| 437 |
# 添加标签 "R CW"
|
| 438 |
-
rcw_text = "Right Clockwise"
|
| 439 |
rcw_bbox = draw.textbbox((0, 0), rcw_text, font=small_font)
|
| 440 |
rcw_text_width = rcw_bbox[2] - rcw_bbox[0]
|
| 441 |
rcw_text_height = rcw_bbox[3] - rcw_bbox[1]
|
|
@@ -452,7 +454,6 @@ def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
|
|
| 452 |
draw_semicircle(draw, (rccw_center_x , rccw_center_y+15), semicircle_radius, line_color, line_width, half="upper",start_pos="right", end_pos="left", arrow_position="end", arrow_size=arrow_size)
|
| 453 |
|
| 454 |
# 添加标签 "R CCW"
|
| 455 |
-
rccw_text = "Right Counterclockwise"
|
| 456 |
rccw_bbox = draw.textbbox((0, 0), rccw_text, font=small_font)
|
| 457 |
rccw_text_width = rccw_bbox[2] - rccw_bbox[0]
|
| 458 |
rccw_text_height = rccw_bbox[3] - rccw_bbox[1]
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
import cv2
|
| 13 |
+
from config import VIDEO_PLAYBACK_FPS, ROUTESTICK_OVERLAY_ACTION_TEXTS, get_ui_action_text
|
| 14 |
|
| 15 |
# DEPRECATED: 历史任务特化图像叠加配置,保留仅为兼容旧代码路径。
|
| 16 |
# 当前已统一关闭任务特化渲染。
|
|
|
|
| 354 |
|
| 355 |
# 如果是 RouteStick 任务,绘制旋转方向示意图(左侧或右侧)
|
| 356 |
if env_id == "RouteStick" and (position == "right" or position == "left"):
|
| 357 |
+
lcw_text, lccw_text, rcw_text, rccw_text = [
|
| 358 |
+
get_ui_action_text("RouteStick", action_text)
|
| 359 |
+
for action_text in ROUTESTICK_OVERLAY_ACTION_TEXTS
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
# 绘制四个半圆箭头示意图(垂直排列)
|
| 363 |
# 示意图位置:在图像的左侧或右侧,从上到下垂直排列
|
| 364 |
illustration_width = 220 # 示意图区域宽度(已弃用,保留以保持兼容性)
|
|
|
|
| 406 |
draw_semicircle(draw, (lcw_center_x , lcw_center_y+15), semicircle_radius, line_color, line_width, half="upper", start_pos="left", end_pos="right", arrow_position="end", arrow_size=arrow_size)
|
| 407 |
|
| 408 |
# 添加标签 "L CW"
|
|
|
|
| 409 |
lcw_bbox = draw.textbbox((0, 0), lcw_text, font=small_font)
|
| 410 |
lcw_text_width = lcw_bbox[2] - lcw_bbox[0]
|
| 411 |
lcw_text_height = lcw_bbox[3] - lcw_bbox[1]
|
|
|
|
| 422 |
draw_semicircle(draw, (lccw_center_x, lccw_center_y), semicircle_radius, line_color, line_width, half="lower", start_pos="left", end_pos="right", arrow_position="end", arrow_size=arrow_size)
|
| 423 |
|
| 424 |
# 添加标签 "L CCW"
|
|
|
|
| 425 |
lccw_bbox = draw.textbbox((0, 0), lccw_text, font=small_font)
|
| 426 |
lccw_text_width = lccw_bbox[2] - lccw_bbox[0]
|
| 427 |
lccw_text_height = lccw_bbox[3] - lccw_bbox[1]
|
|
|
|
| 438 |
draw_semicircle(draw, (rcw_center_x , rcw_center_y), semicircle_radius, line_color, line_width, half="lower", start_pos="right", end_pos="left", arrow_position="end", arrow_size=arrow_size)
|
| 439 |
|
| 440 |
# 添加标签 "R CW"
|
|
|
|
| 441 |
rcw_bbox = draw.textbbox((0, 0), rcw_text, font=small_font)
|
| 442 |
rcw_text_width = rcw_bbox[2] - rcw_bbox[0]
|
| 443 |
rcw_text_height = rcw_bbox[3] - rcw_bbox[1]
|
|
|
|
| 454 |
draw_semicircle(draw, (rccw_center_x , rccw_center_y+15), semicircle_radius, line_color, line_width, half="upper",start_pos="right", end_pos="left", arrow_position="end", arrow_size=arrow_size)
|
| 455 |
|
| 456 |
# 添加标签 "R CCW"
|
|
|
|
| 457 |
rccw_bbox = draw.textbbox((0, 0), rccw_text, font=small_font)
|
| 458 |
rccw_text_width = rccw_bbox[2] - rccw_bbox[0]
|
| 459 |
rccw_text_height = rccw_bbox[3] - rccw_bbox[1]
|
gradio-web/main.py
CHANGED
|
@@ -6,6 +6,9 @@ import sys
|
|
| 6 |
import tempfile
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
APP_DIR = Path(__file__).resolve().parent
|
| 10 |
PROJECT_ROOT = APP_DIR.parent
|
| 11 |
SRC_DIR = PROJECT_ROOT / "src"
|
|
@@ -13,6 +16,17 @@ VIDEOS_DIR = APP_DIR / "videos"
|
|
| 13 |
TEMP_DEMOS_DIR = PROJECT_ROOT / "temp_demos"
|
| 14 |
CWD_TEMP_DEMOS_DIR = Path.cwd() / "temp_demos"
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
if str(PROJECT_ROOT) not in sys.path:
|
| 17 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 18 |
if str(SRC_DIR) not in sys.path:
|
|
|
|
| 6 |
import tempfile
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
APP_DIR = Path(__file__).resolve().parent
|
| 13 |
PROJECT_ROOT = APP_DIR.parent
|
| 14 |
SRC_DIR = PROJECT_ROOT / "src"
|
|
|
|
| 16 |
TEMP_DEMOS_DIR = PROJECT_ROOT / "temp_demos"
|
| 17 |
CWD_TEMP_DEMOS_DIR = Path.cwd() / "temp_demos"
|
| 18 |
|
| 19 |
+
|
| 20 |
+
def configure_runtime_devices():
|
| 21 |
+
"""Restrict the app to physical GPU 1 and map rendering to the visible device."""
|
| 22 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
| 23 |
+
os.environ.setdefault("NVIDIA_VISIBLE_DEVICES", "1")
|
| 24 |
+
# After masking to physical GPU 1, libraries should use logical cuda:0.
|
| 25 |
+
os.environ["SAPIEN_RENDER_DEVICE"] = "cuda:0"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
configure_runtime_devices()
|
| 29 |
+
|
| 30 |
if str(PROJECT_ROOT) not in sys.path:
|
| 31 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 32 |
if str(SRC_DIR) not in sys.path:
|
gradio-web/test/test_reference_action_callbacks.py
CHANGED
|
@@ -4,8 +4,9 @@ from PIL import Image
|
|
| 4 |
|
| 5 |
|
| 6 |
class _FakeSession:
|
| 7 |
-
def __init__(self, reference_payload):
|
| 8 |
self._reference_payload = reference_payload
|
|
|
|
| 9 |
|
| 10 |
def get_reference_action(self):
|
| 11 |
return self._reference_payload
|
|
@@ -92,3 +93,32 @@ def test_on_option_select_keeps_valid_coords_when_option_needs_coords(monkeypatc
|
|
| 92 |
|
| 93 |
assert coords_text == "12, 34"
|
| 94 |
assert img_update.get("interactive") is True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class _FakeSession:
|
| 7 |
+
def __init__(self, reference_payload, env_id="BinFill"):
|
| 8 |
self._reference_payload = reference_payload
|
| 9 |
+
self.env_id = env_id
|
| 10 |
|
| 11 |
def get_reference_action(self):
|
| 12 |
return self._reference_payload
|
|
|
|
| 93 |
|
| 94 |
assert coords_text == "12, 34"
|
| 95 |
assert img_update.get("interactive") is True
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_on_reference_action_uses_configured_action_text_override(monkeypatch, reload_module):
|
| 99 |
+
config = reload_module("config")
|
| 100 |
+
callbacks = reload_module("gradio_callbacks")
|
| 101 |
+
|
| 102 |
+
session = _FakeSession(
|
| 103 |
+
{
|
| 104 |
+
"ok": True,
|
| 105 |
+
"option_idx": 0,
|
| 106 |
+
"option_label": "a",
|
| 107 |
+
"option_action": "move forward",
|
| 108 |
+
"need_coords": False,
|
| 109 |
+
"coords_xy": None,
|
| 110 |
+
"message": "ok",
|
| 111 |
+
},
|
| 112 |
+
env_id="PatternLock",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
|
| 116 |
+
monkeypatch.setattr(callbacks, "get_session", lambda uid: session)
|
| 117 |
+
|
| 118 |
+
_img, _option_update, coords_text, log_html = callbacks.on_reference_action("uid-1")
|
| 119 |
+
|
| 120 |
+
assert coords_text == config.UI_TEXT["coords"]["not_needed"]
|
| 121 |
+
assert log_html == config.UI_TEXT["log"]["reference_action_message"].format(
|
| 122 |
+
option_label="a",
|
| 123 |
+
option_action="move forward↑",
|
| 124 |
+
)
|
gradio-web/test/test_ui_text_config.py
CHANGED
|
@@ -4,8 +4,25 @@ import pytest
|
|
| 4 |
|
| 5 |
|
| 6 |
class _FakeOptionSession:
|
| 7 |
-
def __init__(self):
|
| 8 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def test_on_option_select_uses_configured_select_keypoint_message(monkeypatch, reload_module):
|
|
@@ -67,3 +84,111 @@ def test_missing_session_paths_use_configured_session_error(monkeypatch, reload_
|
|
| 67 |
assert log_text == "Session Error From Config"
|
| 68 |
assert map_img is None
|
| 69 |
assert map_text == "Session Error From Config"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class _FakeOptionSession:
|
| 7 |
+
def __init__(self, env_id="BinFill", raw_solve_options=None):
|
| 8 |
+
self.env_id = env_id
|
| 9 |
+
self.raw_solve_options = raw_solve_options or [{"available": True}]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class _FakeLoadSession:
|
| 13 |
+
def __init__(self, env_id, available_options, raw_solve_options):
|
| 14 |
+
self.env_id = env_id
|
| 15 |
+
self.available_options = available_options
|
| 16 |
+
self.raw_solve_options = raw_solve_options
|
| 17 |
+
self.language_goal = ""
|
| 18 |
+
self.demonstration_frames = []
|
| 19 |
+
|
| 20 |
+
def load_episode(self, env_id, episode_idx):
|
| 21 |
+
self.env_id = env_id
|
| 22 |
+
return "IMG", f"loaded {env_id} {episode_idx}"
|
| 23 |
+
|
| 24 |
+
def get_pil_image(self, use_segmented=False):
|
| 25 |
+
return "IMG"
|
| 26 |
|
| 27 |
|
| 28 |
def test_on_option_select_uses_configured_select_keypoint_message(monkeypatch, reload_module):
|
|
|
|
| 84 |
assert log_text == "Session Error From Config"
|
| 85 |
assert map_img is None
|
| 86 |
assert map_text == "Session Error From Config"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_get_ui_action_text_uses_configured_overrides_and_fallback(reload_module):
|
| 90 |
+
config = reload_module("config")
|
| 91 |
+
|
| 92 |
+
patternlock_expected = {
|
| 93 |
+
"move forward": "move forward↑",
|
| 94 |
+
"move backward": "move backward↓",
|
| 95 |
+
"move left": "move left→",
|
| 96 |
+
"move right": "move right←",
|
| 97 |
+
"move forward-left": "move forward-left↘︎",
|
| 98 |
+
"move forward-right": "move forward-right↙︎",
|
| 99 |
+
"move backward-left": "move backward-left↗︎",
|
| 100 |
+
"move backward-right": "move backward-right↖︎",
|
| 101 |
+
}
|
| 102 |
+
routestick_expected = {
|
| 103 |
+
"move to the nearest left target by circling around the stick clockwise": "move left clockwise↘︎→↗︎",
|
| 104 |
+
"move to the nearest right target by circling around the stick clockwise": "move right clockwise↖︎←↙︎",
|
| 105 |
+
"move to the nearest left target by circling around the stick counterclockwise": "move left counterclockwise↗︎→↘︎",
|
| 106 |
+
"move to the nearest right target by circling around the stick counterclockwise": "move right counterclockwise↖︎←↙︎",
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
for raw_action, expected in patternlock_expected.items():
|
| 110 |
+
assert config.get_ui_action_text("PatternLock", raw_action) == expected
|
| 111 |
+
for raw_action, expected in routestick_expected.items():
|
| 112 |
+
assert config.get_ui_action_text("RouteStick", raw_action) == expected
|
| 113 |
+
assert config.get_ui_action_text("BinFill", "pick up the cube") == "pick up the cube"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_ui_option_label_uses_patternlock_configured_action_text(reload_module):
|
| 117 |
+
reload_module("config")
|
| 118 |
+
callbacks = reload_module("gradio_callbacks")
|
| 119 |
+
session = _FakeOptionSession(
|
| 120 |
+
env_id="PatternLock",
|
| 121 |
+
raw_solve_options=[{"label": "a", "action": "move forward", "available": False}],
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
assert callbacks._ui_option_label(session, "fallback", 0) == "a. move forward↑"
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_ui_option_label_uses_routestick_configured_action_text(reload_module):
|
| 128 |
+
reload_module("config")
|
| 129 |
+
callbacks = reload_module("gradio_callbacks")
|
| 130 |
+
session = _FakeOptionSession(
|
| 131 |
+
env_id="RouteStick",
|
| 132 |
+
raw_solve_options=[
|
| 133 |
+
{
|
| 134 |
+
"label": "d",
|
| 135 |
+
"action": "move to the nearest right target by circling around the stick counterclockwise",
|
| 136 |
+
"available": False,
|
| 137 |
+
}
|
| 138 |
+
],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
assert callbacks._ui_option_label(session, "fallback", 0) == "d. move right counterclockwise↖︎←↙︎"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_load_status_task_appends_configured_keypoint_suffix_after_mapped_label(monkeypatch, reload_module):
|
| 145 |
+
config = reload_module("config")
|
| 146 |
+
callbacks = reload_module("gradio_callbacks")
|
| 147 |
+
session = _FakeLoadSession(
|
| 148 |
+
env_id="PatternLock",
|
| 149 |
+
available_options=[("a. move forward", 0)],
|
| 150 |
+
raw_solve_options=[{"label": "a", "action": "move forward", "available": [object()]}],
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
monkeypatch.setattr(callbacks, "get_session", lambda uid: session)
|
| 154 |
+
monkeypatch.setattr(callbacks, "reset_play_button_clicked", lambda uid: None)
|
| 155 |
+
monkeypatch.setattr(callbacks, "reset_execute_count", lambda uid, env_id, episode_idx: None)
|
| 156 |
+
monkeypatch.setattr(callbacks, "set_task_start_time", lambda uid, env_id, episode_idx, start_time: None)
|
| 157 |
+
monkeypatch.setattr(callbacks, "set_ui_phase", lambda uid, phase: None)
|
| 158 |
+
monkeypatch.setattr(callbacks, "get_task_hint", lambda env_id: "")
|
| 159 |
+
monkeypatch.setattr(callbacks, "should_show_demo_video", lambda env_id: False)
|
| 160 |
+
|
| 161 |
+
result = callbacks._load_status_task(
|
| 162 |
+
"uid-1",
|
| 163 |
+
{"current_task": {"env_id": "PatternLock", "episode_idx": 1}, "completed_count": 3},
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
assert result[4]["choices"] == [
|
| 167 |
+
(
|
| 168 |
+
f"a. move forward↑{config.UI_TEXT['actions']['keypoint_required_suffix']}",
|
| 169 |
+
0,
|
| 170 |
+
)
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def test_draw_coordinate_axes_uses_configured_routestick_overlay_labels(monkeypatch, reload_module):
|
| 175 |
+
config = reload_module("config")
|
| 176 |
+
image_utils = reload_module("image_utils")
|
| 177 |
+
recorded_texts = []
|
| 178 |
+
original_text = image_utils.ImageDraw.ImageDraw.text
|
| 179 |
+
|
| 180 |
+
def _record_text(self, xy, text, *args, **kwargs):
|
| 181 |
+
recorded_texts.append(text)
|
| 182 |
+
return original_text(self, xy, text, *args, **kwargs)
|
| 183 |
+
|
| 184 |
+
monkeypatch.setattr(image_utils.ImageDraw.ImageDraw, "text", _record_text)
|
| 185 |
+
|
| 186 |
+
img = image_utils.Image.new("RGB", (220, 260), color=(0, 0, 0))
|
| 187 |
+
image_utils.draw_coordinate_axes(img, position="left", env_id="RouteStick")
|
| 188 |
+
|
| 189 |
+
expected_labels = [
|
| 190 |
+
config.get_ui_action_text("RouteStick", action_text)
|
| 191 |
+
for action_text in config.ROUTESTICK_OVERLAY_ACTION_TEXTS
|
| 192 |
+
]
|
| 193 |
+
for label in expected_labels:
|
| 194 |
+
assert label in recorded_texts
|