HongzeFu commited on
Commit
06c11b0
·
0 Parent(s):

HF Space: code-only (no binary assets)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +217 -0
  2. .python-version +1 -0
  3. README.md +10 -0
  4. app.py +72 -0
  5. doc/env_format.md +117 -0
  6. doc/h5_data_format.md +71 -0
  7. doc/submission/model_example.md +69 -0
  8. gradio-web/AAAuser_generator.py +147 -0
  9. gradio-web/config.py +41 -0
  10. gradio-web/gradio_callbacks.py +997 -0
  11. gradio-web/image_utils.py +716 -0
  12. gradio-web/main.py +61 -0
  13. gradio-web/note_content.py +181 -0
  14. gradio-web/oracle_logic.py +975 -0
  15. gradio-web/process_session.py +448 -0
  16. gradio-web/scripts/run_background.sh +287 -0
  17. gradio-web/scripts/后台运行说明.md +288 -0
  18. gradio-web/state_manager.py +473 -0
  19. gradio-web/test/conftest.py +39 -0
  20. gradio-web/test/test_episode98_removed_behavior.py +107 -0
  21. gradio-web/test/test_execute_stream_frames.py +59 -0
  22. gradio-web/test/test_live_obs_refresh.py +70 -0
  23. gradio-web/test/test_option_label_format.py +196 -0
  24. gradio-web/test/test_oracle_builder_integration.py +184 -0
  25. gradio-web/test/test_oracle_imports.py +18 -0
  26. gradio-web/test/test_precheck_execute_inputs.py +53 -0
  27. gradio-web/test/test_process_session_sanitize.py +39 -0
  28. gradio-web/test/test_reference_action_callbacks.py +84 -0
  29. gradio-web/test/test_reference_action_oracle.py +117 -0
  30. gradio-web/test/test_ui_native_layout_contract.py +88 -0
  31. gradio-web/test/test_ui_phase_machine_runtime_e2e.py +782 -0
  32. gradio-web/test/test_user_manager_random_flow.py +96 -0
  33. gradio-web/ui_layout.py +547 -0
  34. gradio-web/user_manager.py +178 -0
  35. gradio-web/verify_video_names.py +128 -0
  36. pyproject.toml +34 -0
  37. readme.md +135 -0
  38. requirements.txt +11 -0
  39. scripts/dataset_replay.py +268 -0
  40. scripts/dev/compare_multi_choice_readers.py +334 -0
  41. scripts/dev/dataset_replay_printType.py +254 -0
  42. scripts/dev/deprecated/dataset_replay-FK-parallel.py +335 -0
  43. scripts/dev/deprecated/dataset_replay-FK.py +264 -0
  44. scripts/dev/deprecated/dataset_replay-ee-parallel.py +214 -0
  45. scripts/dev/deprecated/dataset_replay-ee.py +163 -0
  46. scripts/dev/eval-dataset-offline-rpy.py +195 -0
  47. scripts/dev/eval_dataset_replay.py +476 -0
  48. scripts/dev/evaluate_dataset_replay-parallelv3.py +669 -0
  49. scripts/dev/evaluate_dataset_replay-parallelv4-noresolver.py +676 -0
  50. scripts/dev/generate-dataset-control-seed-readJson-advanceV3.py +878 -0
.gitignore ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # Agent
210
+ .agent/
211
+ .cursor/
212
+
213
+ # Local temp demo files
214
+ temp_demos/
215
+
216
+ # Gradio user action logs
217
+ gradio/data/user_action_logs/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RoboMME Oracle Planner
3
+ sdk: gradio
4
+ app_file: gradio/main.py
5
+ python_version: "3.11"
6
+ ---
7
+
8
+ This Space runs the RoboMME Gradio interface in single-instance session mode.
9
+
10
+ Project docs are in `readme.md`.
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Spaces entrypoint for RoboMME Gradio app."""
2
+
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
# Resolve key project directories relative to this file.
APP_DIR = Path(__file__).resolve().parent
GRADIO_WEB_DIR = APP_DIR / "gradio-web"
SRC_DIR = APP_DIR / "src"
VIDEOS_DIR = GRADIO_WEB_DIR / "videos"
TEMP_DEMOS_DIR = APP_DIR / "temp_demos"
CWD_TEMP_DEMOS_DIR = Path.cwd() / "temp_demos"

# Ensure local modules are importable when running from repository root (HF Spaces).
for _candidate in (GRADIO_WEB_DIR, SRC_DIR, APP_DIR):
    _abs_path = str(_candidate.resolve())
    if _abs_path not in sys.path:
        sys.path.insert(0, _abs_path)
20
+
21
+ from state_manager import start_timeout_monitor
22
+ from ui_layout import create_ui_blocks
23
+
24
+
25
def ensure_media_dirs() -> None:
    """Create the temp media directories (app-relative and CWD-relative) before first write."""
    for directory in (TEMP_DEMOS_DIR, CWD_TEMP_DEMOS_DIR):
        directory.mkdir(parents=True, exist_ok=True)
29
+
30
+
31
def build_allowed_paths() -> list[str]:
    """Build the Gradio file-access allowlist: absolute paths, first-occurrence order, deduplicated."""
    candidates = (
        Path.cwd(),
        APP_DIR,
        GRADIO_WEB_DIR,
        SRC_DIR,
        VIDEOS_DIR,
        TEMP_DEMOS_DIR,
        CWD_TEMP_DEMOS_DIR,
        Path(tempfile.gettempdir()),
    )
    # dict preserves insertion order, so this dedupes while keeping priority order.
    return list(dict.fromkeys(str(path.resolve()) for path in candidates))
52
+
53
+
54
def main() -> None:
    """Prepare media dirs, start session monitoring, and launch the Gradio app.

    Binds to 0.0.0.0 on $PORT (default 7860) as required by HF Spaces.
    """
    ensure_media_dirs()
    start_timeout_monitor()

    # Let downstream modules find the temp demo dir without importing this module.
    os.environ.setdefault("ROBOMME_TEMP_DEMOS_DIR", str(TEMP_DEMOS_DIR))

    demo = create_ui_blocks()
    demo.queue(default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        allowed_paths=build_allowed_paths(),
    )
69
+
70
+
71
if __name__ == "__main__":
    # Script entrypoint (HF Spaces runs this file directly).
    main()
doc/env_format.md ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment Input/Output
2
+
3
+ On RoboMME, a key difference from traditional Gym-like envs is that every observation value is a **list** rather than a single item. This is because some RoboMME tasks use conditioning video input, and for discrete action types (e.g. waypoint or multi_choice) we also return intermediate observations for potential use with video-based policy models.
4
+
5
+
6
+ ## Env Input Format
7
+
8
+ We support four `ACTION_SPACE` types:
9
+
10
+ - `joint_angle`: 7 joint angles + gripper open/close
11
+ - `ee_pose`: 3 position (xyz) + 3 rotation (rpy) + gripper open/close
12
+ - `waypoint`: Same format as ee_pose, but executed in discrete keyframe steps
13
+ - `multi_choice`: Command dict, e.g. `{"choice": "A", "point": [y, x]}`; the total choices can be found in `info["available_multi_choices"]`, where the `point` is the pixel location on the front image. This action is designed for Video-QA research.
14
+
15
+ Note: Gripper closed is -1, gripper open is 1.
16
+
17
+
18
+ ## Env Output Format
19
+
20
+ When calling the `step` function:
21
+
22
+ ```python
23
+ obs, reward, terminated, truncated, info = env.step(action)
24
+ ```
25
+
26
+ | Return | Description | Typical type |
27
+ |--------|-------------|--------------|
28
+ | `obs` | Observation dict | `dict[str, list]` |
29
+ | `info` | Info dict | `dict[str, Any]` |
30
+ | `reward` | Reward value (not used) | scalar tensor |
31
+ | `terminated` | Termination flag | scalar boolean tensor |
32
+ | `truncated` | Truncation flag | scalar boolean tensor |
33
+
34
+ ### `obs` dict
35
+
36
+ | Key | Meaning | Typical content |
37
+ |-----|---------|-----------------|
38
+ | `maniskill_obs` | The original raw env observation from ManiSkill | Raw observation dict |
39
+ | `front_rgb_list` | Front camera RGB List | Image frames, e.g. `(H, W, 3)` |
40
+ | `wrist_rgb_list` | Wrist camera RGB List | Image frames, e.g. `(H, W, 3)` |
41
+ | `front_depth_list` | Front camera depth List | Depth map, e.g. `(H, W, 1)` |
42
+ | `wrist_depth_list` | Wrist camera depth List | Depth map, e.g. `(H, W, 1)` |
43
+ | `eef_state_list` | End-effector state List | `[x, y, z, roll, pitch, yaw]` |
44
+ | `joint_state_list` | Robot joint state List | Joint vector, often 7-D |
45
+ | `gripper_state_list` | Robot gripper state List | 2-D |
46
+ | `front_camera_extrinsic_list` | Front camera extrinsic List | Camera extrinsic matrix |
47
+ | `wrist_camera_extrinsic_list` | Wrist camera extrinsic List | Camera extrinsic matrix |
48
+
49
+
50
+ To use only the current (latest) observation, use `obs[key][-1]`.
51
+
52
+ ### Optional field switches (`include_*`)
53
+
54
+ `BenchmarkEnvBuilder.make_env_for_episode(...)` controls optional observation/info fields through `include_*` flags.
55
+
56
+ Default behavior:
57
+ - All `include_*` flags default to `False`.
58
+ - Without extra flags, env returns RGB + state related fields only.
59
+
60
+ Mapping:
61
+
62
+ | Flag | Added key |
63
+ |------|-----------|
64
+ | `include_maniskill_obs` | `obs["maniskill_obs"]` |
65
+ | `include_front_depth` | `obs["front_depth_list"]` |
66
+ | `include_wrist_depth` | `obs["wrist_depth_list"]` |
67
+ | `include_front_camera_extrinsic` | `obs["front_camera_extrinsic_list"]` |
68
+ | `include_wrist_camera_extrinsic` | `obs["wrist_camera_extrinsic_list"]` |
69
+ | `include_available_multi_choices` | `info["available_multi_choices"]` |
70
+ | `include_front_camera_intrinsic` | `info["front_camera_intrinsic"]` |
71
+ | `include_wrist_camera_intrinsic` | `info["wrist_camera_intrinsic"]` |
72
+
73
+ Special case:
74
+ - If `action_space="multi_choice"`, front camera parameters are forced on internally:
75
+ - `front_camera_extrinsic_list`
76
+ - `front_camera_intrinsic`
77
+ Even if the corresponding `include_front_camera_*` flags are `False`.
78
+
79
+ Example:
80
+
81
+ ```python
82
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
83
+
84
+ builder = BenchmarkEnvBuilder(
85
+ env_id="VideoUnmaskSwap",
86
+ dataset="test",
87
+ action_space="joint_angle",
88
+ gui_render=False,
89
+ )
90
+
91
+ env = builder.make_env_for_episode(
92
+ episode_idx=0,
93
+ max_steps=1000,
94
+ include_maniskill_obs=False,
95
+ include_front_depth=True,
96
+ include_wrist_depth=False,
97
+ include_front_camera_extrinsic=True,
98
+ include_wrist_camera_extrinsic=False,
99
+ include_available_multi_choices=False,
100
+ include_front_camera_intrinsic=True,
101
+ include_wrist_camera_intrinsic=False,
102
+ )
103
+
104
+ obs, info = env.reset()
105
+ ```
106
+
107
+ ### `info` dict
108
+
109
+ | Key | Meaning | Typical content |
110
+ |-----|---------|-----------------|
111
+ | `task_goal` | Task goal list | `list[str]` |
112
+ | `simple_subgoal_online` | Oracle online simple subgoal | Description of the current simple subgoal |
113
+ | `grounded_subgoal_online` | Oracle online grounded subgoal | Description of the current grounded subgoal |
114
+ | `available_multi_choices` | Current available options for multi-choice action | List of e.g. `{"label": "a/b/...", "action": str, "need_parameter": bool}`, where `need_parameter` means this action needs grounding info like `[y, x]` |
115
+ | `front_camera_intrinsic` | Front camera intrinsic | Camera intrinsic matrix |
116
+ | `wrist_camera_intrinsic` | Wrist camera intrinsic | Camera intrinsic matrix |
117
+ | `status` | Status flag | One of `success`, `fail`, `timeout`, `ongoing`, `error` |
doc/h5_data_format.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HDF5 Training Data Format
2
+
3
+ Structure inside each `record_dataset_<EnvID>.h5` file:
4
+
5
+ ```text
6
+ episode_1/
7
+ setup/
8
+ timestep_1/
9
+ obs/
10
+ action/
11
+ info/
12
+ timestep_2/
13
+ obs/
14
+ action/
15
+ info/
16
+ ...
17
+ ...
18
+ ```
19
+
20
+ Each episode contains:
21
+ - `setup/`: episode-level configuration.
22
+ - `timestep_<K>/`: per-timestep data.
23
+
24
+ ## `setup/` fields (episode configuration)
25
+
26
+ | Field | Type | Description |
27
+ |-------|------|-------------|
28
+ | `seed` | `int` | Environment seed (fixed for benchmarking) |
29
+ | `difficulty` | `str` | Difficulty level (fixed for benchmarking) |
30
+ | `task_goal` | `list[str]` | Possible language goals for the task |
31
+ | `front_camera_intrinsic` | `float32 (3, 3)` | Front camera intrinsic matrix |
32
+ | `wrist_camera_intrinsic` | `float32 (3, 3)` | Wrist camera intrinsic matrix |
33
+ | `available_multi_choices` | `str` | Available options for the multi-choice Video-QA problem |
34
+
35
+ ## `obs/` fields (observations)
36
+
37
+ | Field | Type / shape | Description |
38
+ |-------|---------------|-------------|
39
+ | `front_rgb` | `uint8 (512, 512, 3)` | Front camera RGB |
40
+ | `wrist_rgb` | `uint8 (256, 256, 3)` | Wrist camera RGB |
41
+ | `front_depth` | `int16 (512, 512, 1)` | Front camera depth (mm) |
42
+ | `wrist_depth` | `int16 (256, 256, 1)` | Wrist camera depth (mm) |
43
+ | `joint_state` | `float32 (7,)` | Joint positions (7 joints) |
44
+ | `eef_state` | `float32 (6,)` | End-effector pose `[x, y, z, roll, pitch, yaw]` |
45
+ | `gripper_state` | `float32 (2,)` | Gripper opening width in [0, 0.04] |
46
+ | `is_gripper_close` | `bool` | Whether gripper is closed |
47
+ | `front_camera_extrinsic` | `float32 (3, 4)` | Front camera extrinsic matrix |
48
+ | `wrist_camera_extrinsic` | `float32 (3, 4)` | Wrist camera extrinsic matrix |
49
+
50
+ ## `action/` fields
51
+
52
+ | Field | Type / shape | Description |
53
+ |-------|---------------|-------------|
54
+ | `joint_action` | `float32 (8,)` | Joint-space action: 7 joint angles + gripper |
55
+ | `eef_action` | `float32 (7,)` | End-effector action `[x, y, z, roll, pitch, yaw, gripper]` |
56
+ | `waypoint_action` | `float32 (7,)` | End-effector action at discrete time steps; a subtask may contain multiple waypoint actions. Used for data generation. |
57
+ | `choice_action` | `str` | JSON string for multi-choice selection with an optional grounded pixel location on the front image, e.g., `{"choice": "A", "point": [y, x]}` |
58
+
59
+ In RoboMME, a gripper action of -1 means close and 1 means open.
60
+
61
+ ## `info/` fields (metadata)
62
+
63
+ | Field | Type | Description |
64
+ |-------|------|-------------|
65
+ | `simple_subgoal` | `bytes (UTF-8)` | Simple subgoal text (built-in planner view) |
66
+ | `simple_subgoal_online` | `bytes (UTF-8)` | Simple subgoal text (online view; may advance to the next subgoal earlier than planner view) |
67
+ | `grounded_subgoal` | `bytes (UTF-8)` | Grounded subgoal text (built-in planner view) |
68
+ | `grounded_subgoal_online` | `bytes (UTF-8)` | Grounded subgoal text (online view; may advance to the next subgoal earlier than planner view) |
69
+ | `is_video_demo` | `bool` | Whether this frame is from the conditioning video shown before execution |
70
+ | `is_subgoal_boundary` | `bool` | Whether this is a keyframe (i.e., a boundary between subtasks) |
71
+ | `is_completed` | `bool` | Whether the task is finished |
doc/submission/model_example.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Your Cool Model Name
2
+
3
+ ### [Website]() | [Paper]() | [Code]()
4
+
5
+ ## Introduction
6
+ My cool model leverages a novel representation for history keyframes and maintains a memory cache to integrate with diffusion policy.
7
+
8
+ ## Results
9
+
10
+ > We ask for **at least three runs** with different model seeds to decrease the performance fluctuations.
11
+ > The benchmark seed is fixed internally.
12
+
13
+ ### Table
14
+
15
+ <table>
16
+ <tr>
17
+ <th rowspan="2">Suite</th>
18
+ <th rowspan="2">Task</th>
19
+ </tr>
20
+ <tr>
21
+ <th>Seed 7</th><th>Seed 42</th><th>Seed 0</th><th><b>Avg</b></th>
22
+ </tr>
23
+ <tr>
24
+ <td rowspan="4">Counting</td>
25
+ <td>BinFill</td><td></td><td></td><td></td><td></td>
26
+ </tr>
27
+ <tr><td>PickXtimes</td><td></td><td></td><td></td><td></td></tr>
28
+ <tr><td>SwingXtimes</td><td></td><td></td><td></td><td></td></tr>
29
+ <tr><td>StopCube</td><td></td><td></td><td></td><td></td></tr>
30
+ <tr>
31
+ <td rowspan="4">Permanence</td>
32
+ <td>VideoUnmask</td><td></td><td></td><td></td><td></td>
33
+ </tr>
34
+ <tr><td>VideoUnmaskSwap</td><td></td><td></td><td></td><td></td></tr>
35
+ <tr><td>ButtonUnmask</td><td></td><td></td><td></td><td></td></tr>
36
+ <tr><td>ButtonUnmaskSwap</td><td></td><td></td><td></td><td></td></tr>
37
+ <tr>
38
+ <td rowspan="4">Reference</td>
39
+ <td>PickHighlight</td><td></td><td></td><td></td><td></td>
40
+ </tr>
41
+ <tr><td>VideoRepick</td><td></td><td></td><td></td><td></td></tr>
42
+ <tr><td>VideoPlaceButton</td><td></td><td></td><td></td><td></td></tr>
43
+ <tr><td>VideoPlaceOrder</td><td></td><td></td><td></td><td></td></tr>
44
+ <tr>
45
+ <td rowspan="4">Imitation</td>
46
+ <td>MoveCube</td><td></td><td></td><td></td><td></td>
47
+ </tr>
48
+ <tr><td>InsertPeg</td><td></td><td></td><td></td><td></td></tr>
49
+ <tr><td>PatternLock</td><td></td><td></td><td></td><td></td></tr>
50
+ <tr><td>RouteStick</td><td></td><td></td><td></td><td></td></tr>
51
+ <tr>
52
+ <td colspan="2"><b>Overall</b></td><td></td><td></td><td></td><td></td>
53
+ </tr>
54
+ </table>
55
+
56
+
57
+ ### Training Details
58
+
59
+ Any hyperparameters you would like to share
60
+
61
+ ### Released Checkpoints
62
+
63
+ Any fine-tuned checkpoints you would like to release
64
+
65
+ > We highly encourage authors to fully release their training/eval code and checkpoints to help the community accelerate memory-augmented manipulation.
66
+
67
+ ### Citations
68
+ ```
69
+ ```
gradio-web/AAAuser_generator.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+
4
# Deprecated runtime path:
# This script is only for offline generation experiments and is not used by
# the current Gradio runtime task assignment flow.

# Task environments, grouped by benchmark suite.
ENVS = [
    # Counting
    "BinFill",
    "PickXtimes",
    "SwingXtimes",
    "StopCube",
    # Persistence
    "VideoUnmask",
    "ButtonUnmask",
    "VideoUnmaskSwap",
    "ButtonUnmaskSwap",
    # Reference
    "PickHighlight",
    "VideoRepick",
    "VideoPlaceButton",
    "VideoPlaceOrder",
    # Behavior
    "MoveCube",
    "InsertPeg",
    "PatternLock",
    "RouteStick",
]

# Named annotators; slots beyond this list get generic "userN" keys.
# NOTE: order matters — user keys are assigned by index.
REAL_USERS = [
    "Hongyu_Zhou",
    "Wanling_Cai",
    "Xinyi_Wang",
    "Yinpei_Dai",
    "Hongze_Fu",
    "Run_Peng",
    "Haoran_Zhang",
    "Yunqi_Zhao",
    "Yue_Hu",
    "Yiwei_Lyu",
    "Josue_Torres-Fonseca",
    "Jung-Chun_Liu",
    "Jacob_Sansom",
    "Long-Jing_Hsu",
]

NUM_USERS = 20         # total users to generate assignments for
EPISODES_PER_ENV = 50  # episodes available per environment
TEST_EPISODE_IDX = 98  # shared test/warm-up episode index
55
+
56
+
57
def generate_json(seed: int = 0):
    """Build a reproducible per-user task assignment.

    Phase 1 gives every user one not-yet-assigned episode from every env;
    phase 2 shuffles the leftover episodes and deals them out evenly, with
    any remainder going one-per-user to the first users. Each user's list
    is prefixed with the shared test episodes.

    Args:
        seed: RNG seed; the same seed always yields the same assignment.

    Returns:
        dict mapping user key -> list of {"env_id", "episode_idx"} tasks.
    """
    rng = random.Random(seed)

    # All candidate tasks, per environment.
    env_tasks = {
        env: [{"env_id": env, "episode_idx": ep} for ep in range(EPISODES_PER_ENV)]
        for env in ENVS
    }

    # User keys: real annotator names first, generic fillers after.
    user_keys = [
        REAL_USERS[i] if i < len(REAL_USERS) else f"user{i + 1}"
        for i in range(NUM_USERS)
    ]

    users = {key: [] for key in user_keys}

    # Phase 1: every user receives one unused episode from each environment.
    # NOTE: if EPISODES_PER_ENV < NUM_USERS, later users silently get none
    # for exhausted envs (the `if available` guard) — same as before.
    used_tasks = {env: set() for env in ENVS}
    for user_key in user_keys:
        for env in ENVS:
            available = [
                task for task in env_tasks[env]
                if task["episode_idx"] not in used_tasks[env]
            ]
            if available:
                picked = rng.choice(available)
                users[user_key].append(picked)
                used_tasks[env].add(picked["episode_idx"])

    # Phase 2: collect the untouched tasks, shuffle once, deal evenly.
    remaining_tasks = [
        task
        for env in ENVS
        for task in env_tasks[env]
        if task["episode_idx"] not in used_tasks[env]
    ]
    rng.shuffle(remaining_tasks)

    remaining_per_user = len(remaining_tasks) // NUM_USERS
    for i, user_key in enumerate(user_keys):
        users[user_key].extend(
            remaining_tasks[i * remaining_per_user:(i + 1) * remaining_per_user]
        )

    # Remainder: one extra task each for the first few users.
    remainder = len(remaining_tasks) % NUM_USERS
    start_idx = remaining_per_user * NUM_USERS
    for i in range(remainder):
        users[user_keys[i]].append(remaining_tasks[start_idx + i])

    # Shared test episodes go in front of each user's training tasks.
    test_template = [
        {"env_id": env, "episode_idx": TEST_EPISODE_IDX}
        for env in ENVS
    ]
    return {key: test_template + users[key] for key in user_keys}
136
+
137
+
138
+ if __name__ == "__main__":
139
+ data = generate_json(seed=42)
140
+
141
+ with open("user_tasks.json", "w", encoding="utf-8") as f:
142
+ json.dump(data, f, indent=2, ensure_ascii=False)
143
+
144
+ counts = {k: len(v) for k, v in data.items() if not k.endswith("_test")}
145
+ print("Train counts:", counts)
146
+ print("Min/Max:", min(counts.values()), max(counts.values()))
147
+ print("✅ 已生成并保存到 user_tasks.json")
gradio-web/config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置常量模块
3
+ """
4
+ # --- Configuration ---
5
+ VIDEO_PLAYBACK_FPS = 30.0 # Frame rate for demonstration video playback
6
+ USE_SEGMENTED_VIEW = False # Set to True to use segmented view, False to use original image
7
+
8
+ # 主界面两列宽度比例 (Keypoint Selection : Right Panel)
9
+ KEYPOINT_SELECTION_SCALE = 1
10
+ CONTROL_PANEL_SCALE = 2
11
+
12
+ # 右侧顶部并排比例 (Action Selection : System Log)
13
+ RIGHT_TOP_ACTION_SCALE = 2
14
+ RIGHT_TOP_LOG_SCALE = 1
15
+
16
+ # Session超时配置
17
+ SESSION_TIMEOUT = 300 # Session超时时间(秒),如果30秒内没有execute_step操作,将自动回收session
18
+
19
+ # 兜底执行次数配置
20
+ EXECUTE_LIMIT_OFFSET = 4 # 兜底执行次数 = non_demonstration_task_length + EXECUTE_LIMIT_OFFSET
21
+
22
+
23
+ # 应该显示demonstration videos的环境ID列表
24
+ DEMO_VIDEO_ENV_IDS = [
25
+ "VideoPlaceOrder",
26
+ "VideoUnmaskSwap",
27
+ "VideoUnmask",
28
+ "VideoRepick",
29
+ "VideoPlaceButton",
30
+ "InsertPeg",
31
+ "MoveCube",
32
+ "PatternLock",
33
+ "RouteStick"
34
+ ]
35
+
36
def should_show_demo_video(env_id):
    """Return True when *env_id* should display demonstration videos.

    Only environments listed in DEMO_VIDEO_ENV_IDS show demo videos.
    """
    return env_id in DEMO_VIDEO_ENV_IDS
gradio-web/gradio_callbacks.py ADDED
@@ -0,0 +1,997 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio回调函数模块
3
+ 响应UI事件,调用业务逻辑,返回UI更新
4
+ """
5
+ import gradio as gr
6
+ import numpy as np
7
+ import time
8
+ import threading
9
+ import queue
10
+ import os
11
+ import re
12
+ from datetime import datetime
13
+ from PIL import Image
14
+ from state_manager import (
15
+ get_session,
16
+ create_session,
17
+ set_ui_phase,
18
+ reset_ui_phase,
19
+ get_execute_count,
20
+ increment_execute_count,
21
+ reset_execute_count,
22
+ set_task_start_time,
23
+ update_session_activity,
24
+ get_session_activity,
25
+ cleanup_session,
26
+ reset_play_button_clicked,
27
+ GLOBAL_SESSIONS,
28
+ SESSION_LAST_ACTIVITY,
29
+ _state_lock,
30
+ )
31
+ from image_utils import draw_marker, save_video, concatenate_frames_horizontally
32
+ from user_manager import user_manager
33
+ from config import USE_SEGMENTED_VIEW, should_show_demo_video, SESSION_TIMEOUT, EXECUTE_LIMIT_OFFSET
34
+ from process_session import ScrewPlanFailureError, ProcessSessionProxy
35
+ from note_content import get_task_hint
36
+
37
+
38
+ # --- live_obs refresh queue state ---
39
+ # Each uid keeps its own FIFO queue and sampling cursor.
40
+ _LIVE_OBS_REFRESH = {}
41
+ _LIVE_OBS_REFRESH_LOCK = threading.Lock()
42
+
43
+
44
def capitalize_first_letter(text: str) -> str:
    """Return *text* with its first character upper-cased, the rest untouched."""
    if not text:
        return text
    head, tail = text[0], text[1:]
    return head.upper() + tail
51
+
52
+
53
def get_videoplacebutton_goal(original_goal: str) -> str:
    """
    Build the display goal for the VideoPlaceButton task.

    Replaces the phrase "cube on the target" (case-insensitively) with
    "cube on the target that it was previously placed on", then capitalizes
    the first letter. When the phrase is absent, the goal is returned
    unchanged apart from capitalization.

    Args:
        original_goal: Raw language goal from the episode; may be empty/None.

    Returns:
        str: Adjusted, capitalized goal ("" for falsy input).
    """
    if not original_goal:
        return ""

    # re.sub returns the input unchanged when the pattern does not match,
    # so the previous lower()-copy + explicit containment pre-check was
    # redundant work; a single case-insensitive substitution suffices.
    pattern = re.compile(r"cube on the target", re.IGNORECASE)
    new_goal = pattern.sub(
        "cube on the target that it was previously placed on", original_goal
    )
    return capitalize_first_letter(new_goal)
72
+
73
+
74
+ def _ui_option_label(session, opt_label: str, opt_idx: int) -> str:
75
+ """
76
+ 仅在 Gradio UI 层对选项显示文案做覆盖(不改底层 env/options 生成逻辑)。
77
+ 目前只对 RouteStick 任务把 4 个长句 label 显示为短 label。
78
+ """
79
+ env_id = getattr(session, "env_id", None)
80
+ if env_id == "RouteStick":
81
+ routestick_map = {
82
+ 0: "move left clockwise",
83
+ 1: "move right clockwise",
84
+ 2: "move left counterclockwise",
85
+ 3: "move right counterclockwise",
86
+ }
87
+ return routestick_map.get(int(opt_idx), opt_label)
88
+ return opt_label
89
+
90
+
91
def format_log_markdown(log_message):
    """
    Normalize a log message into plain text for a Textbox.

    Args:
        log_message: Plain-text message (possibly multi-line); any type is
            stringified, None yields "".

    Returns:
        str: Message with Windows/Mac line endings converted to "\n".
    """
    if log_message is None:
        return ""
    text = str(log_message)
    text = text.replace("\r\n", "\n")
    return text.replace("\r", "\n")
104
+
105
+
106
def show_task_hint(uid, current_hint=""):
    """
    Toggle the task hint panel; content is loaded on demand when the user
    clicks the "Task Hint" button.

    When a hint is currently displayed, "" is returned to hide it.
    Otherwise the hint text for the active session's env_id is resolved
    via get_task_hint().

    Args:
        uid: Session identifier used to look up the active session.
        current_hint: Text currently shown; drives the show/hide toggle.

    Returns:
        str: Hint markdown, "" (hidden / no session), or a fallback
        message when no environment has been loaded yet.
    """
    # A non-empty hint is visible -> toggle it off.
    if current_hint and current_hint.strip():
        return ""

    session = get_session(uid)
    if not session:
        # No session: the frontend simply shows nothing.
        return ""

    env_id = getattr(session, "env_id", None)
    if not env_id:
        return "No environment loaded."
    return get_task_hint(env_id)
146
+
147
+
148
def show_loading_info():
    """
    Show the full-screen loading overlay.

    Behavior:
    - Called when the user clicks login / load-task style buttons.
    - Returns a gr.update that makes the overlay group visible; the overlay
      covers the page so the user cannot interact while loading runs.
    - The actual loading callback (chained via .then()) hides the overlay
      again by returning gr.update(visible=False) for the same component.

    Flow:
    1. User clicks a button (Login, Next Task, ...).
    2. The click event first calls this function to show the overlay.
    3. The real loader (e.g. login_and_load_task) runs via .then().
    4. On completion the loader hides the overlay.

    Returns:
        gr.update: Makes the loading overlay group visible.
    """
    return gr.update(visible=True)
168
+
169
+
170
def on_video_end(uid):
    """
    Called when the demonstration video finishes playing.
    Updates the system log to prompt for action selection.

    Args:
        uid: Session identifier (unused here; kept for event-signature parity).

    Returns:
        str: Normalized plain-text prompt for the System Log textbox.
    """
    return format_log_markdown("please select the action below 👇🏻,\nsome actions also need to select keypoint")
176
+
177
+
178
def switch_to_execute_phase(uid):
    """
    Enter execution playback: disable the control panel and keypoint
    clicking, and (re)initialize the per-uid live_obs refresh state.

    Returns a 6-tuple of gr.update(interactive=False) for, in order:
    options_radio, exec_btn, restart_episode_btn, next_task_btn,
    img_display, reference_action_btn.
    """
    if uid:
        session = get_session(uid)
        if session:
            cached_frames = getattr(session, "base_frames", []) or []
        else:
            cached_frames = []
        with _LIVE_OBS_REFRESH_LOCK:
            _LIVE_OBS_REFRESH[uid] = {
                "frame_queue": queue.Queue(),
                "last_base_count": len(cached_frames),
                "take_next": True,  # downsample x2 by enqueueing every other frame
            }
    # All six wired components are simply disabled during playback.
    return tuple(gr.update(interactive=False) for _ in range(6))
197
+
198
+
199
def switch_to_action_phase(uid=None):
    """Switch display to action phase and restore control panel interactions.

    Drops the uid's live_obs refresh state so stale frames are not replayed.
    The three buttons are returned as bare gr.update() so they keep whatever
    interactivity execute_step decided for them.
    """
    if uid:
        with _LIVE_OBS_REFRESH_LOCK:
            _LIVE_OBS_REFRESH.pop(uid, None)
    return (
        gr.update(interactive=True),  # options_radio
        gr.update(),  # exec_btn (keep execute_step result)
        gr.update(),  # restart_episode_btn (keep execute_step result)
        gr.update(),  # next_task_btn (keep execute_step result)
        gr.update(interactive=True),  # img_display
        gr.update(interactive=True),  # reference_action_btn
    )
212
+
213
+
214
def _get_live_obs_refresh_state(uid, base_count=0):
    """
    Return the per-uid live_obs refresh state, creating it under the lock
    on first access.

    The state dict holds the FIFO frame queue, the sampling cursor into
    base_frames ("last_base_count"), and the x2-downsampling toggle.
    """
    fresh_state = {
        "frame_queue": queue.Queue(),
        "last_base_count": int(base_count),
        "take_next": True,  # downsample x2 by enqueueing every other frame
    }
    with _LIVE_OBS_REFRESH_LOCK:
        return _LIVE_OBS_REFRESH.setdefault(uid, fresh_state)
223
+
224
+
225
def _enqueue_live_obs_frames(uid, base_frames):
    """
    Push newly appended base_frames into per-uid FIFO queue with x2 downsampling.

    Compares the current frame count against the cursor stored in the uid's
    refresh state; only frames appended since the last call are enqueued,
    and every other one is skipped ("take_next" toggles per frame so the
    toggle carries across calls).

    Args:
        uid: Session identifier; falsy uid is a no-op.
        base_frames: Full frame history for the session (list-like or None).

    Returns:
        int: Approximate number of frames currently queued for playback
        (0 after a detected reset).
    """
    if not uid:
        return 0
    frames = base_frames or []
    state = _get_live_obs_refresh_state(uid, base_count=len(frames))
    frame_queue = state["frame_queue"]
    current_count = len(frames)
    last_count = int(state.get("last_base_count", 0))

    # Session/task reset: history shrank.
    if current_count < last_count:
        with _LIVE_OBS_REFRESH_LOCK:
            state["frame_queue"] = queue.Queue()
            state["last_base_count"] = current_count
            state["take_next"] = True
        return 0

    # Nothing new since the last call.
    if current_count <= last_count:
        return frame_queue.qsize()

    new_frames = frames[last_count:current_count]
    take_next = bool(state.get("take_next", True))
    for frame in new_frames:
        # Enqueue every other non-None frame (x2 downsampling).
        if take_next and frame is not None:
            frame_queue.put(frame)
        take_next = not take_next

    # Persist the cursor and toggle under the lock.
    with _LIVE_OBS_REFRESH_LOCK:
        state["last_base_count"] = current_count
        state["take_next"] = take_next
    return frame_queue.qsize()
259
+
260
+
261
def _wait_for_live_obs_queue_drain(uid, max_wait_sec=None, empty_grace_sec=0.2, poll_sec=0.05):
    """
    Wait for timer-driven live_obs refresh to consume queued frames before phase switch.

    Polls the per-uid frame queue until it has stayed empty for
    ``empty_grace_sec`` seconds, the refresh state disappears, or
    ``max_wait_sec`` elapses.

    Args:
        uid: Session identifier; falsy uid is a no-op.
        max_wait_sec: Hard cap on waiting. When None it is derived from the
            initial queue size (~0.12s per queued frame + 1s buffer),
            clamped to [2, 30] seconds.
        empty_grace_sec: How long the queue must stay empty before returning.
        poll_sec: Sleep between polls.
    """
    if not uid:
        return
    with _LIVE_OBS_REFRESH_LOCK:
        state0 = _LIVE_OBS_REFRESH.get(uid)
        queue0 = state0.get("frame_queue") if state0 else None
    initial_qsize = int(queue0.qsize()) if queue0 is not None else 0
    if max_wait_sec is None:
        # 0.1s tick playback + small buffer, capped to keep UI responsive.
        max_wait_sec = min(30.0, max(2.0, initial_qsize * 0.12 + 1.0))

    start = time.time()
    empty_since = None
    while True:
        if (time.time() - start) >= max_wait_sec:
            break
        with _LIVE_OBS_REFRESH_LOCK:
            state = _LIVE_OBS_REFRESH.get(uid)
            frame_queue = state.get("frame_queue") if state else None
        if frame_queue is None:
            # State removed (e.g. phase switched elsewhere): nothing to drain.
            break
        if frame_queue.qsize() > 0:
            empty_since = None
        else:
            if empty_since is None:
                empty_since = time.time()
            elif (time.time() - empty_since) >= empty_grace_sec:
                break
        time.sleep(poll_sec)
293
+
294
+
295
+ def _prepare_refresh_frame(frame):
296
+ """Normalize cached frame to an RGB uint8 PIL image for gr.Image."""
297
+ if frame is None:
298
+ return None
299
+ frame_arr = np.asarray(frame)
300
+ if frame_arr.dtype != np.uint8:
301
+ max_val = float(np.max(frame_arr)) if frame_arr.size else 0.0
302
+ if max_val <= 1.0:
303
+ frame_arr = (frame_arr * 255.0).clip(0, 255).astype(np.uint8)
304
+ else:
305
+ frame_arr = frame_arr.clip(0, 255).astype(np.uint8)
306
+ if frame_arr.ndim == 2:
307
+ frame_arr = np.stack([frame_arr] * 3, axis=-1)
308
+ elif frame_arr.ndim == 3 and frame_arr.shape[2] == 4:
309
+ frame_arr = frame_arr[:, :, :3]
310
+ return Image.fromarray(frame_arr)
311
+
312
+
313
def refresh_live_obs(uid, ui_phase):
    """
    Poll latest cached frame during execute phase.
    Updates live_obs every 0.1s via gr.Timer.

    Args:
        uid: Session identifier.
        ui_phase: Current UI phase; anything other than
            "execution_playback" is a no-op.

    Returns:
        gr.update: A no-op update when there is nothing to show, otherwise
        the next queued frame (stitched per env layout, normalized to a
        PIL image, rendered non-interactive).
    """
    if ui_phase != "execution_playback":
        return gr.update()
    session = get_session(uid)
    if not session:
        return gr.update()

    base_frames = getattr(session, "base_frames", None) or []
    if not base_frames:
        return gr.update()

    # Pull any newly appended frames into the FIFO queue, then pop one.
    _enqueue_live_obs_frames(uid, base_frames)
    state = _get_live_obs_refresh_state(uid, base_count=len(base_frames))
    frame_queue = state["frame_queue"]

    if frame_queue.empty():
        return gr.update()

    latest = frame_queue.get()
    env_id = getattr(session, "env_id", None)
    # Stitch the single frame according to the env's camera layout.
    stitched = concatenate_frames_horizontally([latest], env_id=env_id)
    if stitched:
        latest = stitched[-1]

    img = _prepare_refresh_frame(latest)
    if img is None:
        return gr.update()
    return gr.update(value=img, interactive=False)
345
+
346
+
347
def on_video_end_transition(uid):
    """Called when demo video finishes. Transition from video to action phase.

    Args:
        uid: Session identifier (unused; kept for event-signature parity).

    Returns:
        Updates for (video_phase_group, action_phase_group,
        control_panel_group, log_output).
    """
    return (
        gr.update(visible=False),  # video_phase_group
        gr.update(visible=True),  # action_phase_group
        gr.update(visible=True),  # control_panel_group
        format_log_markdown("please select the action below 👇🏻,\nsome actions also need to select keypoint")
    )
355
+
356
+
357
def _task_load_failed_response(uid, message):
    """
    Build the standard "task failed to load" UI tuple.

    Shows *message* in the log, clears/hides the task widgets, and disables
    every action control. The element order mirrors the callback outputs
    wired in the UI layout — keep the two in sync.
    """
    return (
        uid,
        gr.update(visible=True),  # main_interface
        gr.update(value=None, interactive=False),  # img_display
        format_log_markdown(message),  # log_output
        gr.update(choices=[], value=None),  # options_radio
        "",  # goal_box
        "No need for coordinates",  # coords_box
        gr.update(value=None, visible=False),  # video_display
        "",  # task_info_box
        "",  # progress_info_box
        gr.update(interactive=False),  # restart_episode_btn
        gr.update(interactive=False),  # next_task_btn
        gr.update(interactive=False),  # exec_btn
        gr.update(visible=False),  # video_phase_group
        gr.update(visible=False),  # action_phase_group
        gr.update(visible=False),  # control_panel_group
        gr.update(value=""),  # task_hint_display
        gr.update(visible=False),  # loading_overlay
        gr.update(interactive=False),  # reference_action_btn
    )
379
+
380
+
381
def _load_status_task(uid, status):
    """Load status.current_task to session and build the standard UI update tuple.

    Three outcomes share the same 19-element output shape (wired to the UI
    layout): episode load failure, a task with a demonstration video, and a
    task without one.

    Args:
        uid: Session identifier (a worker session is created if missing).
        status: user_manager status dict carrying "current_task"
            ({"env_id", "episode_idx"}) and "completed_count".

    Returns:
        The standard UI update tuple (see _task_load_failed_response for
        the element order).
    """
    current_task = status.get("current_task") if isinstance(status, dict) else None
    if not current_task:
        return _task_load_failed_response(uid, "Error loading task: missing current_task")

    env_id = current_task.get("env_id")
    ep_num = current_task.get("episode_idx")
    if env_id is None or ep_num is None:
        return _task_load_failed_response(uid, "Error loading task: invalid task payload")

    try:
        completed_count = int(status.get("completed_count", 0))
    except (TypeError, ValueError):
        completed_count = 0
    progress_text = f"Completed: {completed_count}"

    session = get_session(uid)
    if session is None:
        print(f"Session {uid} not found, creating new session")
        session = ProcessSessionProxy()
        with _state_lock:
            GLOBAL_SESSIONS[uid] = session
            SESSION_LAST_ACTIVITY[uid] = time.time()
        print(f"New session created for {uid}")

    print(f"Loading {env_id} Ep {ep_num} for {uid}")

    # Reset per-episode playback/click/execute bookkeeping before loading.
    with _LIVE_OBS_REFRESH_LOCK:
        _LIVE_OBS_REFRESH.pop(uid, None)
    reset_play_button_clicked(uid)
    reset_execute_count(uid, env_id, int(ep_num))

    img, load_msg = session.load_episode(env_id, int(ep_num))
    actual_env_id = getattr(session, "env_id", None) or env_id

    # Record the task start time only when the episode actually loaded.
    if img is not None:
        start_time = datetime.now().isoformat()
        set_task_start_time(uid, env_id, int(ep_num), start_time)

    if img is None:
        # Load failed: surface load_msg but keep restart/next-task usable.
        set_ui_phase(uid, "executing_task")
        return (
            uid,
            gr.update(visible=True),  # main_interface
            gr.update(value=None, interactive=False),  # img_display
            format_log_markdown(f"Error: {load_msg}"),  # log_output
            gr.update(choices=[], value=None),  # options_radio
            "",  # goal_box
            "No need for coordinates",  # coords_box
            gr.update(value=None, visible=False),  # video_display
            f"{actual_env_id} (Episode {ep_num})",  # task_info_box
            progress_text,  # progress_info_box
            gr.update(interactive=True),  # restart_episode_btn
            gr.update(interactive=True),  # next_task_btn
            gr.update(interactive=False),  # exec_btn
            gr.update(visible=False),  # video_phase_group
            gr.update(visible=True),  # action_phase_group
            gr.update(visible=True),  # control_panel_group
            gr.update(value=get_task_hint(env_id) if env_id else ""),  # task_hint_display
            gr.update(visible=False),  # loading_overlay
            gr.update(interactive=False),  # reference_action_btn
        )

    # VideoPlaceButton gets a rewritten goal; other envs just capitalize.
    if session.env_id == "VideoPlaceButton" and session.language_goal:
        goal_text = get_videoplacebutton_goal(session.language_goal)
    else:
        goal_text = capitalize_first_letter(session.language_goal) if session.language_goal else ""

    # Build the radio choices; options that need a keypoint get a hint suffix.
    options = session.available_options
    radio_choices = []
    for opt_label, opt_idx in options:
        opt_label = _ui_option_label(session, opt_label, opt_idx)
        if 0 <= opt_idx < len(session.raw_solve_options):
            opt = session.raw_solve_options[opt_idx]
            if opt.get("available"):
                opt_label_with_hint = f"{opt_label} (click mouse 🖱️ to select 🎯)"
            else:
                opt_label_with_hint = opt_label
        else:
            opt_label_with_hint = opt_label
        radio_choices.append((opt_label_with_hint, opt_idx))

    demo_video_path = None
    has_demo_video = False
    should_show = should_show_demo_video(actual_env_id) if actual_env_id else False
    initial_log_msg = format_log_markdown("please select the action below 👇🏻,\nsome actions also need to select keypoint")

    if should_show:
        has_demo_video = True
        initial_log_msg = format_log_markdown('press "Watch Video Input🎬" to watch a video\nNote: you can only watch the video once')
        if session.demonstration_frames:
            try:
                demo_video_path = save_video(session.demonstration_frames, "demo")
                if demo_video_path:
                    # Guard against a zero-byte / missing encode result.
                    file_exists = os.path.exists(demo_video_path)
                    file_size = os.path.getsize(demo_video_path) if file_exists else 0
                    if not (file_exists and file_size > 0):
                        demo_video_path = None
            except Exception:
                demo_video_path = None

    img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)

    if has_demo_video:
        set_ui_phase(uid, "executing_task")

        # Demo-video flow: start in the video phase, hide the control panel.
        return (
            uid,
            gr.update(visible=True),  # main_interface
            gr.update(value=img, interactive=False),  # img_display
            initial_log_msg,  # log_output
            gr.update(choices=radio_choices, value=None),  # options_radio
            goal_text,  # goal_box
            "No need for coordinates",  # coords_box
            gr.update(value=demo_video_path, visible=True),  # video_display
            f"{actual_env_id} (Episode {ep_num})",  # task_info_box
            progress_text,  # progress_info_box
            gr.update(interactive=True),  # restart_episode_btn
            gr.update(interactive=True),  # next_task_btn
            gr.update(interactive=True),  # exec_btn
            gr.update(visible=True),  # video_phase_group
            gr.update(visible=False),  # action_phase_group
            gr.update(visible=False),  # control_panel_group
            gr.update(value=get_task_hint(actual_env_id)),  # task_hint_display
            gr.update(visible=False),  # loading_overlay
            gr.update(interactive=True),  # reference_action_btn
        )

    set_ui_phase(uid, "executing_task")

    # No demo video: go straight to the action phase.
    return (
        uid,
        gr.update(visible=True),  # main_interface
        gr.update(value=img, interactive=False),  # img_display
        initial_log_msg,  # log_output
        gr.update(choices=radio_choices, value=None),  # options_radio
        goal_text,  # goal_box
        "No need for coordinates",  # coords_box
        gr.update(value=None, visible=False),  # video_display (no video)
        f"{actual_env_id} (Episode {ep_num})",  # task_info_box
        progress_text,  # progress_info_box
        gr.update(interactive=True),  # restart_episode_btn
        gr.update(interactive=True),  # next_task_btn
        gr.update(interactive=True),  # exec_btn
        gr.update(visible=False),  # video_phase_group
        gr.update(visible=True),  # action_phase_group
        gr.update(visible=True),  # control_panel_group
        gr.update(value=get_task_hint(actual_env_id)),  # task_hint_display
        gr.update(visible=False),  # loading_overlay
        gr.update(interactive=True),  # reference_action_btn
    )
533
+
534
+
535
def init_session_and_load_task(uid):
    """
    Initialize the Gradio session (creating one when uid is falsy) and
    load the user's current task into the UI.

    Returns:
        The standard UI update tuple from _load_status_task, or the
        failure tuple when user_manager cannot initialize the session.
    """
    uid = uid or create_session()

    success, msg, status = user_manager.init_session(uid)

    if uid:
        update_session_activity(uid)

    if success:
        return _load_status_task(uid, status)
    return _task_load_failed_response(uid, msg)
548
+
549
+
550
def load_next_task_wrapper(uid):
    """
    Jump to a random episode within the same env and reload the task UI.

    Returns:
        The standard UI update tuple, or the failure tuple when
        user_manager cannot provide a next episode.
    """
    uid = uid or create_session()

    if uid:
        update_session_activity(uid)

    status = user_manager.next_episode_same_env(uid)
    if status:
        return _load_status_task(uid, status)
    return _task_load_failed_response(uid, "Failed to load next task")
563
+
564
+
565
def restart_episode_wrapper(uid):
    """
    Reload the current env + episode from the user's session status.

    Validates that the status carries a usable current_task (env_id and
    episode_idx both present) before delegating to _load_status_task.
    """
    uid = uid or create_session()

    if uid:
        update_session_activity(uid)

    status = user_manager.get_session_status(uid)
    task = status.get("current_task") if isinstance(status, dict) else None
    if not task:
        return _task_load_failed_response(uid, "Failed to restart episode: missing current task")

    if task.get("env_id") is None or task.get("episode_idx") is None:
        return _task_load_failed_response(uid, "Failed to restart episode: invalid task payload")

    return _load_status_task(uid, status)
584
+
585
+
586
def switch_env_wrapper(uid, selected_env):
    """
    Switch env via the Current Task dropdown; a random episode of the new
    env is assigned. An empty selection just reloads the current status.
    """
    uid = uid or create_session()

    if uid:
        update_session_activity(uid)

    if selected_env:
        status = user_manager.switch_env_and_random_episode(uid, selected_env)
    else:
        status = user_manager.get_session_status(uid)

    if status:
        return _load_status_task(uid, status)
    return _task_load_failed_response(uid, f"Failed to switch environment to '{selected_env}'")
603
+
604
+
605
def on_map_click(uid, option_value, evt: gr.SelectData):
    """
    Handle a click on the keypoint selection image.

    When the currently selected option requires coordinates, a marker is
    drawn at the clicked pixel and "x, y" is returned; otherwise the clean
    observation image plus the default message is re-asserted.

    Args:
        uid: Session identifier.
        option_value: Selected option — either (label, idx) or a bare idx.
        evt: Gradio SelectData carrying the click position in evt.index.

    Returns:
        (PIL image or None, coords/status string)
    """
    # Clicking the image counts as session activity.
    if uid:
        update_session_activity(uid)

    session = get_session(uid)
    if not session:
        return None, "Session Error"

    # Check if current option actually needs coordinates
    needs_coords = False
    if option_value is not None:
        # Parse option index similar to on_option_select
        option_idx = None
        if isinstance(option_value, tuple):
            _, option_idx = option_value
        else:
            option_idx = option_value

        if option_idx is not None and 0 <= option_idx < len(session.raw_solve_options):
            opt = session.raw_solve_options[option_idx]
            if opt.get("available"):
                needs_coords = True

    if not needs_coords:
        # Return current state without changes (or reset to default message if needed, but it should already be there)
        # We return the clean image and the "No need" message to enforce state
        base_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)
        return base_img, "No need for coordinates"

    x, y = evt.index[0], evt.index[1]

    # Get clean image from session
    base_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)

    # Draw marker
    marked_img = draw_marker(base_img, x, y)

    coords_str = f"{x}, {y}"

    return marked_img, coords_str
649
+
650
+
651
+ def _is_valid_coords_text(coords_text: str) -> bool:
652
+ if not isinstance(coords_text, str):
653
+ return False
654
+ text = coords_text.strip()
655
+ if text in {"", "please click the keypoint selection image", "No need for coordinates"}:
656
+ return False
657
+ if "," not in text:
658
+ return False
659
+ try:
660
+ x_raw, y_raw = text.split(",", 1)
661
+ int(x_raw.strip())
662
+ int(y_raw.strip())
663
+ except Exception:
664
+ return False
665
+ return True
666
+
667
+
668
def on_option_select(uid, option_value, coords_str=None):
    """
    Handle action option selection.

    Decides the coords-box message and whether the execute button becomes
    interactive, based on whether the chosen option requires a keypoint.

    Returns:
        (coords_message, gr.update for the execute button)
    """
    default_msg = "No need for coordinates"

    if option_value is None:
        return default_msg, gr.update(interactive=False)

    # Selecting an option counts as session activity.
    if uid:
        update_session_activity(uid)

    session = get_session(uid)
    if not session:
        return default_msg, gr.update(interactive=False)

    # option_value is either a (label, idx) tuple or a bare index.
    option_idx = option_value[1] if isinstance(option_value, tuple) else option_value

    if 0 <= option_idx < len(session.raw_solve_options):
        if session.raw_solve_options[option_idx].get("available"):
            # Keypoint required: keep already-valid coords, otherwise
            # prompt the user to click the image.
            if _is_valid_coords_text(coords_str):
                return coords_str, gr.update(interactive=True)
            return "please click the keypoint selection image", gr.update(interactive=True)

    return default_msg, gr.update(interactive=False)
700
+
701
+
702
def on_reference_action(uid):
    """
    Fetch the oracle's reference action for the current step and pre-fill
    the UI with it (option selection + marked keypoint) without executing.

    Returns:
        (image, options_radio update, coords text / update, log text)
    """
    if uid:
        update_session_activity(uid)

    session = get_session(uid)
    if not session:
        return (
            None,
            gr.update(),
            "No need for coordinates",
            format_log_markdown("Session Error"),
        )

    current_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)

    try:
        reference = session.get_reference_action()
    except Exception as exc:
        # Resolution failed: keep the current UI, report in the log only.
        return (
            current_img,
            gr.update(),
            gr.update(),
            format_log_markdown(f"Ground Truth Action Error: {exc}"),
        )

    if not isinstance(reference, dict) or not reference.get("ok", False):
        message = "Failed to resolve ground truth action."
        if isinstance(reference, dict) and reference.get("message"):
            message = str(reference.get("message"))
        return (
            current_img,
            gr.update(),
            gr.update(),
            format_log_markdown(f"Ground Truth Action: {message}"),
        )

    option_idx = reference.get("option_idx")
    option_label = str(reference.get("option_label", "")).strip()
    option_action = str(reference.get("option_action", "")).strip()
    need_coords = bool(reference.get("need_coords", False))
    coords_xy = reference.get("coords_xy")

    updated_img = current_img
    coords_text = "No need for coordinates"
    log_text = f"Ground Truth Action: {option_label}. {option_action}".strip()

    # When the reference action carries a keypoint, draw it on the image.
    if need_coords and isinstance(coords_xy, (list, tuple)) and len(coords_xy) >= 2:
        x = int(coords_xy[0])
        y = int(coords_xy[1])
        updated_img = draw_marker(current_img, x, y)
        coords_text = f"{x}, {y}"
        log_text = f"Ground Truth Action: {option_label}. {option_action} | coords: {coords_text}"

    return (
        updated_img,
        gr.update(value=option_idx),
        coords_text,
        format_log_markdown(log_text),
    )
764
+
765
+
766
def init_app(request: gr.Request):
    """
    Handle the initial page load: create a fresh session and load the
    first task immediately.

    Args:
        request: Gradio Request object containing query parameters; they
            are intentionally ignored in session-based mode.

    Returns:
        The standard UI initialization tuple.
    """
    _ = request  # Query params are intentionally ignored in session-based mode.
    new_uid = create_session()
    return init_session_and_load_task(new_uid)
779
+
780
+
781
def precheck_execute_inputs(uid, option_idx, coords_str):
    """
    Native precheck for the execute action: server-side validation that
    replaces the old frontend JS interception, run before the phase switch.

    Raises:
        gr.Error: When the session is missing, no action is selected, or a
            keypoint-requiring option has no valid coordinates.
    """
    if uid:
        update_session_activity(uid)

    session = get_session(uid)
    if not session:
        raise gr.Error("Session Error")

    parsed_idx = option_idx[1] if isinstance(option_idx, tuple) else option_idx
    if parsed_idx is None:
        raise gr.Error("Error: No action selected")

    needs_coords = (
        isinstance(parsed_idx, int)
        and 0 <= parsed_idx < len(session.raw_solve_options)
        and bool(session.raw_solve_options[parsed_idx].get("available"))
    )

    if needs_coords and not _is_valid_coords_text(coords_str):
        raise gr.Error("please click the keypoint selection image before execute!")
810
+
811
+
812
def execute_step(uid, option_idx, coords_str):
    """
    Execute the selected action for the current session and build the UI
    update tuple.

    Flow: session-timeout check -> execute-count limit check -> coordinate
    validation -> action execution (or a simulated failure when the limit
    is reached) -> live_obs playback drain -> episode-completion
    bookkeeping.

    Returns:
        (image, log_text, task_info update, progress update,
         restart_btn update, next_task_btn update, exec_btn update)
    """
    # Check session timeout before refreshing the activity timestamp.
    last_activity = get_session_activity(uid)
    if last_activity is not None:
        elapsed = time.time() - last_activity
        if elapsed > SESSION_TIMEOUT:
            raise gr.Error(f"Session已超时:超过 {SESSION_TIMEOUT} 秒未活动。请刷新页面重新登录。")

    # Refresh the session's last-activity timestamp.
    update_session_activity(uid)

    session = get_session(uid)
    if not session:
        return None, format_log_markdown("Session Error"), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=False)

    # Check the execute-count limit up front; when reached, a failure is
    # simulated further below instead of executing.
    execute_limit_reached = False
    if uid and session.env_id is not None and session.episode_idx is not None:
        # Read non_demonstration_task_length from the session; when present
        # the limit is that length plus the configured offset, otherwise
        # no limit is applied.
        max_execute = None
        if hasattr(session, 'non_demonstration_task_length') and session.non_demonstration_task_length is not None:
            max_execute = session.non_demonstration_task_length + EXECUTE_LIMIT_OFFSET

        if max_execute is not None:
            current_count = get_execute_count(uid, session.env_id, session.episode_idx)
            if current_count >= max_execute:
                execute_limit_reached = True

    # Ensure at least one cached frame exists for timer-based refresh.
    if not session.base_frames:
        session.update_observation(use_segmentation=USE_SEGMENTED_VIEW)

    if option_idx is None:
        return session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW), format_log_markdown("Error: No action selected"), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)

    # Does the selected option require coordinates?
    needs_coords = False
    if option_idx is not None and 0 <= option_idx < len(session.raw_solve_options):
        opt = session.raw_solve_options[option_idx]
        if opt.get("available"):
            needs_coords = True

    # When coordinates are required, verify the user actually clicked the image.
    if needs_coords:
        # coords_str must be a real coordinate pair, not a placeholder message.
        is_valid_coords = False
        if coords_str and "," in coords_str:
            try:
                parts = coords_str.split(",")
                x = int(parts[0].strip())
                y = int(parts[1].strip())
                # Parsed as numbers and not a placeholder -> valid coordinates.
                if coords_str.strip() not in ["please click the keypoint selection image", "No need for coordinates"]:
                    is_valid_coords = True
            except:  # NOTE(review): bare except swallows everything; narrowing to ValueError would be safer.
                pass

        # Required but missing/invalid coordinates: report instead of executing.
        if not is_valid_coords:
            current_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)
            error_msg = "please click the keypoint selection image before execute!"
            return current_img, format_log_markdown(error_msg), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)

    # Parse coords
    click_coords = None
    if coords_str and "," in coords_str:
        try:
            parts = coords_str.split(",")
            click_coords = (int(parts[0].strip()), int(parts[1].strip()))
        except:  # NOTE(review): bare except — invalid text silently yields click_coords=None.
            pass

    # Execute
    # When the execute-count limit is reached, simulate a failed episode
    # using the same mechanism as a genuine task failure.
    if execute_limit_reached:
        # Resolve the option label for the status message.
        option_label = None
        if session.available_options:
            for label, idx in session.available_options:
                if idx == option_idx:
                    option_label = _ui_option_label(session, label, idx)
                    break

        # Mimic the failure format produced by oracle_logic.py.
        status = f"Executing: {option_label or 'Action'}"
        status += " | FAILED"  # same format as a real task failure
        done = True  # mark as done so the episode-completion flow runs

        # Current observation image.
        img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)

        # Still counts as an attempt, so bump the execute counter.
        if uid and session.env_id is not None and session.episode_idx is not None:
            new_count = increment_execute_count(uid, session.env_id, session.episode_idx)
            print(f"Execute limit reached for {uid}:{session.env_id}:{session.episode_idx} (count: {new_count})")
    else:
        # Normal execution path.
        # All exceptions (ScrewPlanFailure and other execution errors)
        # surface as popup notifications.
        print(f"Executing step: Opt {option_idx}, Coords {click_coords}")
        try:
            img, status, done = session.execute_action(option_idx, click_coords)
        except ScrewPlanFailureError as e:
            # Screw-plan failure: notify via popup, keep the episode running.
            error_message = str(e)
            gr.Info(f"Robot cannot reach position, Refresh the page and try again.")
            # Keep the current observation and put the error in the status line.
            current_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)
            status = f"Screw plan failed: {error_message}"
            done = False
            # Fall through to the normal return flow.
            img = current_img
        except RuntimeError as e:
            # Any other execution error: notify via popup, keep running.
            error_message = str(e)
            gr.Info(f"Cannot find suitable target, Refresh the page and try again.")
            # Keep the current observation and put the error in the status line.
            current_img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)
            status = f"Error: {error_message}"
            done = False
            # Fall through to the normal return flow.
            img = current_img

        # Count the attempt whether it succeeded or failed — the user has
        # performed an operation either way.
        if uid and session.env_id is not None and session.episode_idx is not None:
            new_count = increment_execute_count(uid, session.env_id, session.episode_idx)
            print(f"Execute count for {uid}:{session.env_id}:{session.episode_idx} = {new_count}")

    # Execute frames are produced in batch when execute_action returns from worker process.
    # Enqueue them now, then wait briefly for the 0.1s timer to drain FIFO playback.
    _enqueue_live_obs_frames(uid, getattr(session, "base_frames", None))
    _wait_for_live_obs_queue_drain(uid)

    # NOTE: during the execute phase the display is refreshed by the
    # 0.1s live_obs polling timer, not by this return value alone.

    progress_update = gr.update()  # leave progress unchanged by default
    task_update = gr.update()

    if done:
        # Determine the final status for bookkeeping.
        final_log_status = "failed"
        if "SUCCESS" in status:
            final_log_status = "success"

        # Episode finished: format the System Log banner.
        # Fixed-width template (32-character lines), no blank lines.
        if final_log_status == "success":
            status = "********************************\n**** episode success ****\n********************************\n ---please press change episode---- "
        else:
            status = "********************************\n**** episode failed ****\n********************************\n ---please press change episode---- "

        # Update the cumulative completion count (no fixed task index anymore).
        if uid:
            seed = getattr(session, 'seed', None)
            user_status = user_manager.complete_current_task(
                uid,
                env_id=session.env_id,
                episode_idx=session.episode_idx,
                status=final_log_status,
                difficulty=session.difficulty if hasattr(session, 'difficulty') and session.difficulty is not None else None,
                language_goal=session.language_goal,
                seed=seed
            )
            if user_status:
                completed_count = user_status.get("completed_count", 0)
                task_update = f"{session.env_id} (Episode {session.episode_idx})"
                progress_update = f"Completed: {completed_count}"

        # Re-fetch the image according to the current view mode.
        img = session.get_pil_image(use_segmented=USE_SEGMENTED_VIEW)

    restart_episode_update = gr.update(interactive=True)
    next_task_update = gr.update(interactive=True)
    exec_btn_update = gr.update(interactive=False) if done else gr.update(interactive=True)

    # Normalize the status message for the log textbox.
    formatted_status = format_log_markdown(status)

    return (
        img,
        formatted_status,
        task_update,
        progress_update,
        restart_episode_update,
        next_task_update,
        exec_btn_update,
    )
gradio-web/image_utils.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 图像处理工具模块
3
+ 无状态的图像处理函数
4
+ """
5
+ import numpy as np
6
+ import tempfile
7
+ import os
8
+ import traceback
9
+ import math
10
+ from pathlib import Path
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ import cv2
13
+ from config import VIDEO_PLAYBACK_FPS
14
+
15
+ # DEPRECATED: 历史任务特化图像叠加配置,保留仅为兼容旧代码路径。
16
+ # 当前已统一关闭任务特化渲染。
17
+ DEPRECATED_COORDINATE_AXES_ENVS = ["PatternLock", "RouteStick", "InsertPeg", "SwingXtimes"]
18
+ ENABLE_DEPRECATED_COORDINATE_AXES_OVERLAY = False
19
+
20
+
21
+ def _video_output_dirs():
22
+ """视频输出目录候选(按优先级)。"""
23
+ current_dir = Path(__file__).resolve().parent
24
+ project_root = current_dir.parent
25
+ env_dir = os.environ.get("ROBOMME_TEMP_DEMOS_DIR")
26
+
27
+ candidates = [
28
+ Path(env_dir).expanduser() if env_dir else None,
29
+ project_root / "temp_demos",
30
+ current_dir / "temp_demos",
31
+ Path.cwd() / "temp_demos",
32
+ Path(tempfile.gettempdir()) / "robomme_temp_demos",
33
+ ]
34
+
35
+ result = []
36
+ seen = set()
37
+ for path in candidates:
38
+ if path is None:
39
+ continue
40
+ resolved = path.resolve()
41
+ key = str(resolved)
42
+ if key in seen:
43
+ continue
44
+ seen.add(key)
45
+ result.append(resolved)
46
+ return result
47
+
48
+
49
def _write_with_opencv(path, frames):
    """Fallback MP4 writer using OpenCV when imageio is unavailable.

    Frames are expected as uint8 RGB arrays; frames whose size differs
    from the first frame are resized to match. Returns True on success,
    False when there are no frames or the writer cannot be opened.
    """
    if not frames:
        return False

    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        VIDEO_PLAYBACK_FPS,
        (width, height),
    )
    if not writer.isOpened():
        return False

    try:
        for rgb_frame in frames:
            if rgb_frame.shape[:2] != (height, width):
                rgb_frame = cv2.resize(
                    rgb_frame, (width, height), interpolation=cv2.INTER_LINEAR
                )
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))
        return True
    finally:
        writer.release()
72
+
73
+
74
def save_video(frames, suffix=""):
    """
    Save a sequence of frames as an MP4 clip.

    Notes:
    1. Prefers imageio.mimwrite (no external FFmpeg encoder dependency),
       falling back to an OpenCV writer when imageio is unavailable.
    2. Frames are normalized to uint8 RGB before encoding.
    3. Each candidate output directory is tried in priority order; the
       path of the first successfully written, non-empty file is
       returned, or None when everything fails.
    """
    if not frames or len(frames) == 0:
        return None

    try:
        def _as_rgb_uint8(frame):
            # Normalize one frame: numpy, uint8, 3-channel RGB.
            arr = frame if isinstance(frame, np.ndarray) else np.array(frame)
            if arr.dtype != np.uint8:
                if np.max(arr) <= 1.0:
                    arr = (arr * 255).astype(np.uint8)
                else:
                    arr = arr.clip(0, 255).astype(np.uint8)
            if arr.ndim == 2:
                arr = np.stack([arr] * 3, axis=-1)
            elif arr.ndim == 3 and arr.shape[2] == 4:
                arr = arr[:, :, :3]
            return arr

        prepared = [_as_rgb_uint8(f) for f in frames]

        try:
            import imageio as _imageio
            imageio = _imageio
        except Exception:
            imageio = None

        for out_dir in _video_output_dirs():
            path = None
            try:
                out_dir.mkdir(parents=True, exist_ok=True)
                fd, path = tempfile.mkstemp(suffix=f"_{suffix}.mp4", dir=str(out_dir))
                os.close(fd)

                if imageio is not None:
                    imageio.mimwrite(
                        path,
                        prepared,
                        fps=VIDEO_PLAYBACK_FPS,
                        quality=8,
                        macro_block_size=None,
                    )
                elif not _write_with_opencv(path, prepared):
                    raise RuntimeError("OpenCV video writer failed")

                if os.path.exists(path) and os.path.getsize(path) > 0:
                    return path

                raise RuntimeError(f"generated empty video: {path}")
            except Exception as e:
                print(f"save_video failed in {out_dir}: {e}")
                traceback.print_exc()
                # Best-effort cleanup of the partial file before trying
                # the next candidate directory.
                try:
                    if path and os.path.exists(path):
                        os.remove(path)
                except Exception:
                    pass

        print("Error in save_video: all video output directories failed")
        return None
    except Exception as e:
        print(f"Error in save_video: {e}")
        traceback.print_exc()
        return None
148
+
149
+
150
def concatenate_frames_horizontally(frames1, frames2=None, env_id=None):
    """
    Post-process a sequence of base-camera frames (wrist camera removed).

    Args:
        frames1: list of base-camera frames.
        frames2: deprecated; accepted for backward compatibility, ignored.
        env_id: environment id, used to decide whether the deprecated
            coordinate-axes overlay applies (optional).

    Returns:
        List of processed frames as uint8 RGB numpy arrays.
    """
    # DEPRECATED: task-specific overlays (axes / RouteStick diagrams) are
    # disabled by default; the machinery is kept for possible re-enabling.
    overlay_enabled = (
        ENABLE_DEPRECATED_COORDINATE_AXES_OVERLAY
        and (env_id in DEPRECATED_COORDINATE_AXES_ENVS if env_id else False)
    )
    if not frames1:
        return []

    processed = []
    for frame in frames1:
        if frame is None:
            # Skip missing frames entirely, matching the original behavior.
            continue

        # Normalize to uint8 RGB numpy array.
        arr = frame if isinstance(frame, np.ndarray) else np.array(frame)
        if arr.dtype != np.uint8:
            if np.max(arr) <= 1.0:
                arr = (arr * 255).astype(np.uint8)
            else:
                arr = arr.clip(0, 255).astype(np.uint8)
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)

        frame_h = arr.shape[0]

        if overlay_enabled:
            # RouteStick gets a wider strip for its four rotation diagrams;
            # other overlay tasks get a narrower strip for the axes cross.
            border_w = 200 if env_id == "RouteStick" else 150

            black_strip = np.zeros((frame_h, border_w, 3), dtype=np.uint8)
            combined = Image.fromarray(np.concatenate([black_strip, arr], axis=1))

            panel = Image.new('RGB', (border_w, frame_h), (0, 0, 0))
            # RouteStick draws its rotation diagrams un-rotated; every other
            # overlay task draws the axes rotated 180° for the base camera.
            panel = draw_coordinate_axes(
                panel,
                position="left",
                rotate_180=(env_id != "RouteStick"),
                env_id=env_id,
            )
            combined.paste(panel, (0, 0))
            arr = np.array(combined)

        processed.append(arr)

    return processed
235
+
236
+
237
def draw_semicircle(draw, center, radius, color, width=2, half="lower", start_pos="left", end_pos="right", arrow_position="end", arrow_size=6):
    """
    DEPRECATED: only used by the legacy RouteStick rotation diagrams
    (not called on the default rendering path).

    Draw a semicircular arc with an optional arrowhead.

    Args:
        draw: PIL ImageDraw object
        center: (x, y) center of the circle
        radius: circle radius
        color: stroke color
        width: line width
        half: "upper" (top half) or "lower" (bottom half)
        start_pos: "left" or "right" (where the arc starts)
        end_pos: "left" or "right" (where the arc ends)
        arrow_position: "start" (arrowhead at the start), "end"
            (arrowhead at the end) or None/falsy for no arrowhead
        arrow_size: arrowhead size in pixels
    """
    cx, cy = center

    # Angle ranges in image coordinates (y grows downward):
    #   lower: 0-180 degrees (y > cy)
    #   upper: 180-360 degrees (y < cy)

    angle_map = {
        "lower": {"right": 0, "left": 180},
        "upper": {"right": 360, "left": 180}
    }

    start_angle = angle_map[half].get(start_pos, 0)
    end_angle = angle_map[half].get(end_pos, 180)

    # Step sign follows the sweep direction.
    step = 5
    if start_angle > end_angle:
        step = -5

    points = []
    # Sample points along the arc.
    # NOTE: range() excludes its end value, so extend it by one unit in the
    # step direction to include the final angle.
    for a in range(start_angle, end_angle + (1 if step > 0 else -1), step):
        rad = math.radians(a)
        x = cx + radius * math.cos(rad)
        y = cy + radius * math.sin(rad)
        points.append((x, y))

    if len(points) < 2:
        return

    # Stroke the arc as a polyline.
    draw.line(points, fill=color, width=width)

    # Optional arrowhead, oriented along the path direction.
    if arrow_position:
        if arrow_position == "start":
            # Arrowhead at the start point, pointing along the path.
            arrow_center = points[0]  # arrowhead centered on the arc endpoint
            next_pt = points[1]
            dx = next_pt[0] - arrow_center[0]
            dy = next_pt[1] - arrow_center[1]
        else:  # end
            # Arrowhead at the end point, pointing along the path.
            arrow_center = points[-1]  # arrowhead centered on the arc endpoint
            prev_pt = points[-2]
            dx = arrow_center[0] - prev_pt[0]
            dy = arrow_center[1] - prev_pt[1]

        angle = math.atan2(dy, dx)

        # Arrowhead geometry.
        arrow_len = arrow_size * 1.5
        arrow_wing = arrow_size

        # The arrowhead is centered on the arc endpoint and extends along
        # the path direction: tip forward of the endpoint, wings behind it.
        tip_x = arrow_center[0] + arrow_len * math.cos(angle)
        tip_y = arrow_center[1] + arrow_len * math.sin(angle)
        tip_pt = (tip_x, tip_y)

        # Tail center (backward along the direction).
        bx = arrow_center[0] - arrow_len * math.cos(angle)
        by = arrow_center[1] - arrow_len * math.sin(angle)

        # Wing points fan out sideways from the tail center.
        w1x = bx + arrow_wing * math.cos(angle + math.pi/2)
        w1y = by + arrow_wing * math.sin(angle + math.pi/2)

        w2x = bx + arrow_wing * math.cos(angle - math.pi/2)
        w2y = by + arrow_wing * math.sin(angle - math.pi/2)

        draw.polygon([tip_pt, (w1x, w1y), (w2x, w2y)], fill=color)
329
+
330
+
331
def _load_overlay_font(size):
    """Best-effort bold font loader with a cross-platform fallback chain."""
    for font_path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # Linux
        "/System/Library/Fonts/Helvetica.ttc",  # macOS
    ):
        try:
            return ImageFont.truetype(font_path, size)
        except Exception:
            continue
    return ImageFont.load_default()


def _draw_axis_label(draw, font, text, anchor_x, anchor_y, align="center_x"):
    """Draw white `text` over a black backing rectangle (2px padding).

    align:
        "center_x" - horizontally centered on anchor_x, top edge at anchor_y
        "left_of"  - right-aligned 5px left of anchor_x, vertically centered
                     on anchor_y
        "right_of" - left-aligned 5px right of anchor_x, vertically centered
                     on anchor_y
    """
    bbox = draw.textbbox((0, 0), text, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    if align == "center_x":
        x = anchor_x - text_w // 2
        y = anchor_y
    elif align == "left_of":
        x = anchor_x - text_w - 5
        y = anchor_y - text_h // 2
    else:  # "right_of"
        x = anchor_x + 5
        y = anchor_y - text_h // 2
    draw.rectangle(
        [(x - 2, y - 2), (x + text_w + 2, y + text_h + 2)],
        fill=(0, 0, 0),
    )
    draw.text((x, y), text, fill=(255, 255, 255), font=font)


def _draw_axis_arrowhead(draw, tip, direction, size, color):
    """Draw a filled triangular arrowhead whose point sits at `tip`.

    direction: "up", "down", "left" or "right" in screen coordinates
    (y grows downward).
    """
    tx, ty = tip
    if direction == "up":
        wings = [(tx - size, ty + size), (tx + size, ty + size)]
    elif direction == "down":
        wings = [(tx - size, ty - size), (tx + size, ty - size)]
    elif direction == "left":
        wings = [(tx + size, ty - size), (tx + size, ty + size)]
    else:  # "right"
        wings = [(tx - size, ty - size), (tx - size, ty + size)]
    draw.polygon([tip] + wings, fill=color)


def draw_coordinate_axes(img, position="right", rotate_180=False, env_id=None):
    """
    DEPRECATED: legacy task-specific overlay (not called on the default path).

    Draw a coordinate cross labelled forward/backward/left/right in the black
    border area beside the camera frame. For the RouteStick task, draw four
    rotation-direction diagrams (semicircles with arrowheads) instead.

    Args:
        img: PIL Image or numpy array (the border strip to draw on)
        position: "left" or "right", which side of the frame the strip is on
        rotate_180: if True, rotate the axes 180 degrees (used for the base
            camera, whose view is upside down relative to the robot)
        env_id: environment id; "RouteStick" selects the special diagram

    Returns:
        PIL Image with the overlay drawn
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    img = img.copy()
    draw = ImageDraw.Draw(img)
    width, height = img.size

    line_color = (255, 255, 255)  # all overlay strokes are white

    # RouteStick: four rotation diagrams stacked vertically, no axes.
    if env_id == "RouteStick" and position in ("right", "left"):
        small_font = _load_overlay_font(12)

        semicircle_radius = 15
        arrow_size = 3
        vertical_spacing = 5
        line_width = 2

        # Each item needs: semicircle diameter + label space.
        item_height = semicircle_radius * 2 + 20 + 10
        total_height = 4 * item_height + 3 * vertical_spacing

        # Center the stack horizontally and (roughly) vertically.
        layout_center_x = width // 2
        start_y = (height - total_height) // 2 - 20

        # Item centers, top to bottom.
        cy0 = start_y + item_height // 2
        cy1 = cy0 + item_height + vertical_spacing
        cy2 = cy1 + item_height + vertical_spacing
        cy3 = cy2 + item_height + vertical_spacing

        # (arc center y, half, start, end, label). Upper-half arcs are
        # shifted down 15px so they stay visually centered in their slot;
        # labels always anchor below the un-shifted item center.
        diagrams = [
            (cy0, cy0 + 15, "upper", "left", "right", "Left Clockwise"),
            (cy1, cy1, "lower", "left", "right", "Left Counterclockwise"),
            (cy2, cy2, "lower", "right", "left", "Right Clockwise"),
            (cy3, cy3 + 15, "upper", "right", "left", "Right Counterclockwise"),
        ]
        for item_cy, arc_cy, half, start_pos, end_pos, label in diagrams:
            draw_semicircle(
                draw,
                (layout_center_x, arc_cy),
                semicircle_radius,
                line_color,
                line_width,
                half=half,
                start_pos=start_pos,
                end_pos=end_pos,
                arrow_position="end",
                arrow_size=arrow_size,
            )
            _draw_axis_label(
                draw,
                small_font,
                label,
                layout_center_x,
                item_cy + semicircle_radius + 5,
            )

        # RouteStick draws only the rotation diagrams — no axes.
        return img

    # --- Coordinate cross, centered in the border strip ---
    axis_size = 60
    origin_x = width // 2 - axis_size // 2
    origin_y = height // 2 - axis_size // 2

    font = _load_overlay_font(14)

    axis_length = axis_size - 20
    center_x = origin_x + axis_size // 2
    center_y = origin_y + axis_size // 2

    line_width = 2
    arrow_size = 5

    # Axis endpoints.
    top_end = (center_x, center_y - axis_length // 2)
    bottom_end = (center_x, center_y + axis_length // 2)
    left_end = (center_x - axis_length // 2, center_y)
    right_end = (center_x + axis_length // 2, center_y)

    # The cross itself is identical in both orientations.
    draw.line([left_end, right_end], fill=line_color, width=line_width)
    draw.line([top_end, bottom_end], fill=line_color, width=line_width)

    top_label_y = center_y - axis_length // 2 - 20
    bottom_label_y = center_y + axis_length // 2 + 5

    if rotate_180:
        # Base-camera view is rotated 180°: forward points down and the
        # left/right labels are mirrored.
        _draw_axis_arrowhead(draw, bottom_end, "down", arrow_size, line_color)
        _draw_axis_arrowhead(draw, top_end, "up", arrow_size, line_color)
        _draw_axis_arrowhead(draw, left_end, "left", arrow_size, line_color)
        _draw_axis_arrowhead(draw, right_end, "right", arrow_size, line_color)

        _draw_axis_label(draw, font, "forward", center_x, bottom_label_y)
        _draw_axis_label(draw, font, "backward", center_x, top_label_y)
        _draw_axis_label(draw, font, "right", left_end[0], center_y, align="left_of")
        _draw_axis_label(draw, font, "left", right_end[0], center_y, align="right_of")
    else:
        # Normal orientation: forward up, right on the right.
        _draw_axis_arrowhead(draw, top_end, "up", arrow_size, line_color)
        _draw_axis_arrowhead(draw, bottom_end, "down", arrow_size, line_color)
        _draw_axis_arrowhead(draw, right_end, "right", arrow_size, line_color)
        _draw_axis_arrowhead(draw, left_end, "left", arrow_size, line_color)

        _draw_axis_label(draw, font, "forward", center_x, top_label_y)
        _draw_axis_label(draw, font, "backward", center_x, bottom_label_y)
        _draw_axis_label(draw, font, "right", right_end[0], center_y, align="right_of")
        _draw_axis_label(draw, font, "left", left_end[0], center_y, align="left_of")

    return img
701
+
702
+
703
def draw_marker(img, x, y):
    """Return a copy of `img` with a red circle-and-cross marker at (x, y)."""
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    marked = img.copy()
    pen = ImageDraw.Draw(marked)
    radius = 5
    # Circle outline with a cross through its center.
    pen.ellipse((x - radius, y - radius, x + radius, y + radius), outline="red", width=2)
    pen.line((x - radius, y, x + radius, y), fill="red", width=2)
    pen.line((x, y - radius, x, y + radius), fill="red", width=2)
    return marked
gradio-web/main.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main entry for Gradio app (single-instance mode for Hugging Face Spaces)."""
2
+
3
+ import os
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ from ui_layout import create_ui_blocks
8
+ from state_manager import start_timeout_monitor
9
+
10
# Key filesystem locations for the app and its media scratch space.
APP_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = APP_DIR.parent
VIDEOS_DIR = APP_DIR / "videos"
TEMP_DEMOS_DIR = PROJECT_ROOT / "temp_demos"
CWD_TEMP_DEMOS_DIR = Path.cwd() / "temp_demos"


def ensure_media_dirs():
    """Ensure media temp directories exist before first write."""
    for media_dir in (TEMP_DEMOS_DIR, CWD_TEMP_DEMOS_DIR):
        media_dir.mkdir(parents=True, exist_ok=True)
21
+
22
+
23
def build_allowed_paths():
    """Build Gradio file access allowlist (absolute, deduplicated).

    Returns:
        list[str]: resolved path strings, first occurrence wins.
    """
    raw_paths = (
        Path.cwd(),
        PROJECT_ROOT,
        APP_DIR,
        VIDEOS_DIR,
        TEMP_DEMOS_DIR,
        CWD_TEMP_DEMOS_DIR,
        Path(tempfile.gettempdir()),
    )

    allowlist = []
    seen = set()
    for raw in raw_paths:
        resolved = str(raw.resolve())
        if resolved in seen:
            continue
        seen.add(resolved)
        allowlist.append(resolved)
    return allowlist
42
+
43
+
44
def main():
    """Prepare media dirs, start the timeout monitor, and launch the app."""
    ensure_media_dirs()
    start_timeout_monitor()

    # Point video writers at the project-level temp dir unless overridden.
    os.environ.setdefault("ROBOMME_TEMP_DEMOS_DIR", str(TEMP_DEMOS_DIR))
    allowlist = build_allowed_paths()

    app = create_ui_blocks()
    app.queue(default_concurrency_limit=2)
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        allowed_paths=allowlist,
    )


if __name__ == "__main__":
    main()
gradio-web/note_content.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Note content management module
3
+ Manages Coordinate Information and Task Hint content
4
+ """
5
def get_coordinate_information():
    """
    Get coordinate information content (Note 1).

    Returns:
        str: Coordinate information in Markdown format, explaining how the
        base/wrist camera views map onto the robot's motion directions.
    """
    return """
The coordinate system differs based on the camera perspective.

In the base camera view, the lateral axis is inverted relative to the robot: the right side of the camera frame corresponds to the robot's left side, and vice versa.

Conversely, the wrist camera view is fully aligned with the robot's motion frame. Directional movements are consistent, meaning 'right' in the camera view corresponds to the robot's right, and 'forward' implies forward movement
"""
19
+
20
+
21
+ def get_task_hint(env_id):
22
+ """
23
+ Get task hint content based on environment ID (Note 2)
24
+
25
+ Args:
26
+ env_id (str): Environment ID, e.g., "VideoPlaceOrder", "PickXtimes", etc.
27
+
28
+ Returns:
29
+ str: Task hint in Markdown format
30
+ """
31
+ # Return different hints based on env_id
32
+ # Order follows solve_3.5_parallel_multi_loop_v4.py DEFAULT_ENVS list
33
+ hints = {
34
+ "PickXtimes": """Suppose the task goal is to pick up red cubes for two times, a typical action sequence could be:
35
+ 1. pick up the cube (use mouse click to select the cube with the correct color)
36
+ 2. place the cube onto the target.
37
+ 3. pick up the cube (use mouse click to select the cube with the correct color)
38
+ 4. place the cube onto the target.
39
+ 5. press the button to stop.
40
+ """,
41
+
42
+ "StopCube": """Suppose the task goal is to stop the cube on the target for three times, a typical action sequence could be:
43
+ 1. move to the top of the button to prepare
44
+ 2. remain static (it will execute for a fixed time duration, you need to count the times the cube has passed the target)
45
+ 3. remain static
46
+ 4. remain static
47
+ 5. remain static (Suppose you feel the cube is about to reach the target for the expected number of times, you should press the button to stop the cube directly)
48
+ 6. press the button to stop.
49
+ """,
50
+
51
+ "SwingXtimes": """Suppose the task goal is to swing the back and forth for two times, a typical action sequence could be:
52
+ 1. pick up the cube (use mouse click to select the cube with the correct color)
53
+ 2. move to the top of the target (use mouse click to select the right-side target)
54
+ 3. move to the top of the target (use mouse click to select the left-side target)
55
+ 4. move to the top of the target (use mouse click to select the right-side target)
56
+ 5. move to the top of the target (use mouse click to select the left-side target)
57
+ 6. put the cube onto the table
58
+ 7. press the button to stop.
59
+ """,
60
+
61
+ "BinFill": """Suppose the task goal is to pick two red cubes in the bin, a typical action sequence could be:
62
+ 1. pick up the cube (use mouse click to select the cube with the correct color)
63
+ 2. put it into the bin.
64
+ 3. pick up the cube (use mouse click to select the cube with the correct color)
65
+ 4. put it into the bin.
66
+ 5. press the button to stop.
67
+ """,
68
+
69
+ "VideoUnmaskSwap": """Watch the video carefully. Cubes will be hidden by containers, and you need to memorize the color of the cube inside each one.
70
+ You need to track the containers since they swap positions!
71
+ A typical action sequence could be:
72
+ 1. pick up the container (use mouse click to select the container)
73
+ 2. drop the container down.
74
+
75
+ pick up another container if the task goal is to find two containers.
76
+ """,
77
+
78
+ "VideoUnmask": """Watch the video carefully. Cubes will be hidden by containers, and you need to memorize the color of the cube inside each one.
79
+ A typical action sequence could be:
80
+ 1. pick up the container (use mouse click to select the container)
81
+ 2. drop the container down.
82
+
83
+ pick up another container if the task goal is to find two containers.
84
+ """,
85
+
86
+ "ButtonUnmaskSwap":
87
+ """Press the buttons sequentially. While pressing the buttons, the cubes will be hidden inside the containers, and you need to memorize the color of the cube inside each one.
88
+ You need to track the containers since they swap positions!
89
+ A typical action sequence could be:
90
+ 1. press the first button.
91
+ 2. press the second button.
92
+ 3. pick up the container (use mouse click to select the container)
93
+ 4. drop the container down.
94
+
95
+ pick up another container if the task goal is to find two containers.
96
+ """,
97
+
98
+ "ButtonUnmask":"""Press the buttons sequentially. While pressing the buttons, the cubes will be hidden inside the containers, and you need to memorize the color of the cube inside each one.
99
+ A typical action sequence could be:
100
+ 1. press the button.
101
+ 2. pick up the container (use mouse click to select the container)
102
+ 3. drop the container down.
103
+
104
+ pick up another container if the task goal is to find two containers.
105
+ """,
106
+
107
+ "VideoRepick": """Remember the cube that has been picked up before, and then pick it up again. The cubes might be swapped positions.
108
+ A typical action sequence could be:
109
+ 1. pick up the cube (use mouse click to select the correct cube with the correct color)
110
+ 2. put the cube down on the table.
111
+ (repeat 1 and 2 for the expected number of times)
112
+ 3. press the button to stop.
113
+ """,
114
+
115
+ "VideoPlaceButton":
116
+ """The video shows a robot placing a cube on different targets and pressing the button in a sequence. The targets may change positions.
117
+ A typical action sequence could be:
118
+ 1. pick up the cube (use mouse click to select the correct cube with the correct color)
119
+ 2. put the cube down on the target (use mouse click to select the target)
120
+ """
121
+ ,
122
+
123
+ "VideoPlaceOrder": """The video shows a robot placing a cube on different targets and pressing the button in a sequence. The targets may change positions.
124
+ A typical action sequence could be:
125
+ 1. pick up the cube (use mouse click to select the correct cube with the correct color)
126
+ 2. put the cube down on the target (use mouse click to select the target)
127
+ """,
128
+
129
+ "PickHighlight": """While the robot is pressing the button, some cubes will be highlighted with white discs on the table. Remember them.
130
+ A typical action sequence could be:
131
+ 1. press the button.
132
+ 2. pick up the cube (use mouse click to select the correct cube with the correct color)
133
+ 3. put the cube down on the table.
134
+ (Repeat 2 and 3 for with the rest of highlighted cubes)
135
+ """,
136
+
137
+ "InsertPeg": """The video shows a robot picking up and inserting a peg into a hole.
138
+ The peg consists of two parts with different colors; you need to pick up the correct part of the peg and insert it into the hole from the correct side.
139
+ A typical action sequence could be:
140
+ 1. pick up the peg (use mouse click to select the correct peg and the correct part of the peg)
141
+ 2. insert the peg into the hole on the left side
142
+ """,
143
+
144
+ "MoveCube": """The video shows a robot moving a cube to a target using different methods.
145
+ The robot might (1) pick up and place the cube, (2) push it with the gripper, or (3) hook it using a peg.
146
+ Remember the way the robot moves the cube and choose the correct action to execute.
147
+ """,
148
+
149
+ "PatternLock": """The video shows a robot tracing a pattern with a stick.
150
+ Remember the movements and reproduce them by choosing correct actions.
151
+ The correct directions (e.g., left, right, forward, backward) are as given near the base camera view.
152
+ """,
153
+
154
+ "RouteStick": """The video shows a robot navigating from one target to another by circling around a stick.
155
+ The movement can be clockwise or counter-clockwise, and the stick may be on the left or right side.
156
+ Remember the sequence of actions and choose the correct action to execute.
157
+ The correct directions (e.g., left, right, forward, backward) are as given near the base camera view.
158
+ """,
159
+
160
+ }
161
+
162
+ # Normalize env_id to handle case-insensitive matching
163
+ # First try direct lookup
164
+ if env_id in hints:
165
+ return hints[env_id]
166
+
167
+ # Create a mapping from lowercase to standard format for case-insensitive lookup
168
+ # This handles cases where env_id might be passed as lowercase (e.g., "pickxtimes", "binfill")
169
+ env_id_lower_to_standard = {
170
+ key.lower(): key for key in hints.keys()
171
+ }
172
+
173
+ # Try case-insensitive lookup
174
+ if env_id:
175
+ env_id_lower = env_id.lower()
176
+ if env_id_lower in env_id_lower_to_standard:
177
+ standard_key = env_id_lower_to_standard[env_id_lower]
178
+ return hints[standard_key]
179
+
180
+ # Return default hint if not found
181
+ return """///"""
gradio-web/oracle_logic.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import gymnasium as gym
5
+ import cv2
6
+ import colorsys
7
+ import torch
8
+ from pathlib import Path
9
+ from PIL import Image
10
+
11
# --- Setup Paths ---
# Ensure we can import local project packages from parent directory.
# This file lives in gradio-web/, so the repository root is one level up;
# prepending (not appending) makes local packages win over installed ones.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
17
+
18
# --- NLP Imports ---
# The sentence-transformers model is loaded eagerly at import time so the first
# semantic match is fast. `_NLP_MODEL is None` is the sentinel consumed by
# `_find_best_semantic_match` to disable NLP matching gracefully.
try:
    from sentence_transformers import SentenceTransformer, util as st_util
    print("Loading NLP Model (all-MiniLM-L6-v2)...")
    _NLP_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
    print("NLP Model loaded.")
except ImportError:
    # Package missing entirely: run without semantic matching.
    print("Warning: sentence-transformers not found. NLP matching will fail.")
    _NLP_MODEL = None
except Exception as e:
    # Model download/initialization failed (e.g. no network); degrade gracefully.
    print(f"Error loading NLP model: {e}")
    _NLP_MODEL = None
30
+
31
+ # --- Project Imports ---
32
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
33
+ from robomme.robomme_env import * # noqa: F401,F403; ensure gym envs are registered
34
+ from robomme.robomme_env.utils.vqa_options import get_vqa_options
35
+ from robomme.robomme_env.utils.oracle_action_matcher import (
36
+ find_exact_label_option_index,
37
+ map_action_text_to_option_label,
38
+ )
39
+ from robomme.robomme_env.utils.choice_action_mapping import (
40
+ extract_actor_position_xyz,
41
+ project_world_to_pixel,
42
+ select_target_with_position,
43
+ )
44
+ from mani_skill.examples.motionplanning.panda.motionplanner import PandaArmMotionPlanningSolver
45
+ from mani_skill.examples.motionplanning.panda.motionplanner_stick import PandaStickMotionPlanningSolver
46
+
47
# --- FailAware Planner Imports ---
# Prefer the fail-aware planner variants; if they are unavailable, alias the
# regular planners (imported above) under the same names so the rest of this
# module can use the FailAware* names unconditionally.
try:
    from robomme.robomme_env.utils.planner_fail_safe import (
        FailAwarePandaArmMotionPlanningSolver,
        FailAwarePandaStickMotionPlanningSolver,
        ScrewPlanFailure,
    )
except ImportError as e:
    print(f"Warning: Failed to import robomme fail-aware planners: {e}")
    # Fallback to regular planners; ScrewPlanFailure degrades to RuntimeError
    # so `except ScrewPlanFailure` clauses elsewhere remain valid.
    FailAwarePandaArmMotionPlanningSolver = PandaArmMotionPlanningSolver
    FailAwarePandaStickMotionPlanningSolver = PandaStickMotionPlanningSolver
    ScrewPlanFailure = RuntimeError
60
+
61
# --- Constants ---
# Name of the environment variable that overrides where metadata JSON lives.
ROBOMME_METADATA_ROOT_ENV = "ROBOMME_METADATA_ROOT"
# For backward compatibility with process_session constructor naming.
# Semantics: optional override root for metadata json files (None when unset).
DEFAULT_DATASET_ROOT = os.environ.get(ROBOMME_METADATA_ROOT_ENV)
66
+
67
+ # --- Helper Functions from Script ---
68
+
69
+ def _generate_color_map(n=10000, s_min=0.70, s_max=0.95, v_min=0.78, v_max=0.95):
70
+ phi = 0.6180339887498948
71
+ color_map = {}
72
+ for i in range(1, n + 1):
73
+ h = (i * phi) % 1.0
74
+ s = s_min + (s_max - s_min) * ((i % 7) / 6)
75
+ v = v_min + (v_max - v_min) * (((i * 3) % 5) / 4)
76
+ r, g, b = colorsys.hsv_to_rgb(h, s, v)
77
+ color_map[i] = [int(round(r * 255)), int(round(g * 255)), int(round(b * 255))]
78
+ return color_map
79
+
80
+ def _sync_table_color(env, color_map):
81
+ seg_id_map = getattr(env.unwrapped, "segmentation_id_map", None)
82
+ if not isinstance(seg_id_map, dict):
83
+ return
84
+ for obj_id, obj in seg_id_map.items():
85
+ if getattr(obj, "name", None) == "table-workspace":
86
+ color_map[obj_id] = [0, 0, 0]
87
+
88
+ def _tensor_to_bool(value):
89
+ if value is None:
90
+ return False
91
+ if isinstance(value, torch.Tensor):
92
+ return bool(value.detach().cpu().bool().item())
93
+ if isinstance(value, np.ndarray):
94
+ return bool(np.any(value))
95
+ return bool(value)
96
+
97
+ def _prepare_frame(frame):
98
+ frame = np.asarray(frame)
99
+ if frame.dtype != np.uint8:
100
+ max_val = float(np.max(frame)) if frame.size else 0.0
101
+ if max_val <= 1.0:
102
+ frame = (frame * 255.0).clip(0, 255).astype(np.uint8)
103
+ else:
104
+ frame = frame.clip(0, 255).astype(np.uint8)
105
+ if frame.ndim == 2:
106
+ frame = np.stack([frame] * 3, axis=-1)
107
+ return frame
108
+
109
def _prepare_segmentation_visual(segmentation, color_map, target_hw):
    """Render a segmentation map as a BGR image resized to *target_hw*.

    Returns (seg_bgr, seg_2d): the colorized BGR canvas and the raw 2-D int64
    id map. Returns (None, None) when *segmentation* is None. Ids <= 0 and ids
    missing from *color_map* stay black.
    """
    if segmentation is None:
        return None, None

    raw = segmentation.cpu().numpy() if hasattr(segmentation, "cpu") else segmentation
    raw = np.asarray(raw)
    if raw.ndim > 2:
        raw = raw[0]  # drop leading batch dimension
    seg_2d = raw.squeeze().astype(np.int64)

    height, width = seg_2d.shape[:2]
    rgb_canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for seg_id in np.unique(seg_2d):
        if seg_id <= 0:
            continue  # background / invalid ids remain black
        rgb = color_map.get(int(seg_id))
        if rgb is None:
            continue
        rgb_canvas[seg_2d == seg_id] = rgb
    bgr_canvas = cv2.cvtColor(rgb_canvas, cv2.COLOR_RGB2BGR)

    want_h, want_w = target_hw
    if bgr_canvas.shape[:2] != (want_h, want_w):
        # Nearest-neighbour keeps segment boundaries crisp.
        bgr_canvas = cv2.resize(bgr_canvas, (want_w, want_h), interpolation=cv2.INTER_NEAREST)

    return bgr_canvas, seg_2d
138
+
139
+ def _fetch_segmentation(env):
140
+ obs = env.unwrapped.get_obs(unflattened=True)
141
+ return obs["sensor_data"]["base_camera"]["segmentation"]
142
+
143
def _build_solve_options(env, planner, selected_target, env_id):
    """Thin wrapper around get_vqa_options for the current env/planner state."""
    options = get_vqa_options(env, planner, selected_target, env_id)
    return options
145
+
146
+ def _extract_first_text(value, default="Unknown Goal"):
147
+ if isinstance(value, str):
148
+ text = value.strip()
149
+ return text or default
150
+ if isinstance(value, (list, tuple)):
151
+ for item in value:
152
+ if item is None:
153
+ continue
154
+ text = str(item).strip()
155
+ if text:
156
+ return text
157
+ return default
158
+
159
+ def _ensure_list(value):
160
+ if value is None:
161
+ return []
162
+ if isinstance(value, list):
163
+ return value
164
+ if isinstance(value, tuple):
165
+ return list(value)
166
+ return []
167
+
168
+ def _to_frame_list(frames_like):
169
+ if frames_like is None:
170
+ return []
171
+ if isinstance(frames_like, list):
172
+ return frames_like
173
+ if isinstance(frames_like, tuple):
174
+ return list(frames_like)
175
+ if isinstance(frames_like, torch.Tensor):
176
+ arr = frames_like.detach().cpu().numpy()
177
+ if arr.ndim == 3:
178
+ return [arr]
179
+ if arr.ndim == 4:
180
+ return [x for x in arr]
181
+ return []
182
+ if isinstance(frames_like, np.ndarray):
183
+ if frames_like.ndim == 3:
184
+ return [frames_like]
185
+ if frames_like.ndim == 4:
186
+ return [x for x in frames_like]
187
+ return []
188
+ return []
189
+
190
+ def _iter_env_chain(env, max_depth=16):
191
+ current = env
192
+ seen = set()
193
+ for _ in range(max_depth):
194
+ if current is None:
195
+ return
196
+ env_id = id(current)
197
+ if env_id in seen:
198
+ return
199
+ seen.add(env_id)
200
+ yield current
201
+ current = getattr(current, "env", None)
202
+
203
def _extract_obs_front_frames(env):
    """
    Strict path: only use wrapper-produced obs batch front_rgb_list.
    Returns (front_list, obs_ref_id) or (None, None) if unavailable.
    """
    for wrapper in _iter_env_chain(env):
        for attr in ("_last_obs", "last_obs"):
            candidate = getattr(wrapper, attr, None)
            if isinstance(candidate, dict) and "front_rgb_list" in candidate:
                frames = _to_frame_list(candidate.get("front_rgb_list"))
                return frames, id(candidate)
    return None, None
218
+
219
def _collect_front_frames_from_step_output(step_output):
    """
    Extract front camera frames from a single env.step(...) output.
    Supports both classic step tuple and dense batch tuple.
    """
    if not isinstance(step_output, tuple) or len(step_output) != 5:
        return []
    obs = step_output[0]
    if not isinstance(obs, dict):
        return []
    return _to_frame_list(obs.get("front_rgb_list"))
230
+
231
+
232
+ def _collect_choice_segment_candidates(item, out):
233
+ if isinstance(item, (list, tuple)):
234
+ for child in item:
235
+ _collect_choice_segment_candidates(child, out)
236
+ return
237
+ if isinstance(item, dict):
238
+ for child in item.values():
239
+ _collect_choice_segment_candidates(child, out)
240
+ return
241
+ if item is not None:
242
+ out.append(item)
243
+
244
+
245
def _extract_choice_segment_position_xyz(current_segment):
    """Return the first resolvable candidate position as float64 xyz, else None.

    Flattens *current_segment* (arbitrarily nested) and probes each leaf with
    extract_actor_position_xyz until one yields a position.
    """
    flattened = []
    _collect_choice_segment_candidates(current_segment, flattened)
    for entry in flattened:
        position = extract_actor_position_xyz(entry)
        if position is not None:
            return position.astype(np.float64)
    return None
253
+
254
+
255
+ def _find_actor_segmentation_id(segmentation_id_map, actor):
256
+ if not isinstance(segmentation_id_map, dict):
257
+ return None
258
+ for seg_id, obj in segmentation_id_map.items():
259
+ if obj is actor:
260
+ try:
261
+ return int(seg_id)
262
+ except Exception:
263
+ continue
264
+ return None
265
+
266
+
267
+ def _compute_segmentation_centroid_xy(segmentation, seg_id):
268
+ if segmentation is None:
269
+ return None
270
+ try:
271
+ seg_arr = np.asarray(segmentation)
272
+ except Exception:
273
+ return None
274
+ if seg_arr.ndim > 2:
275
+ seg_arr = np.squeeze(seg_arr)
276
+ if seg_arr.ndim != 2:
277
+ return None
278
+ mask = seg_arr == int(seg_id)
279
+ if not np.any(mask):
280
+ return None
281
+ ys, xs = np.nonzero(mask)
282
+ x = int(np.rint(xs.mean()))
283
+ y = int(np.rint(ys.mean()))
284
+ return [x, y]
285
+
286
def _extract_demonstration_payload(demonstration_data):
    """
    Compatible with both legacy dict payloads and current DemonstrationWrapper tuple batch:
      - dict style: {"language goal": "...", "frames": [...]}
      - tuple/list style: (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)

    Returns (goal_text, frames_list); falls back to ("Unknown Goal", []).
    """
    default_goal = "Unknown Goal"

    if isinstance(demonstration_data, dict):
        goal = (
            demonstration_data.get("language goal")
            or demonstration_data.get("language_goal")
            or demonstration_data.get("task_goal")
        )
        frames = demonstration_data.get("frames")
        if frames is None:
            frames = demonstration_data.get("front_rgb_list")
        return _extract_first_text(goal, default_goal), _ensure_list(frames)

    if isinstance(demonstration_data, (tuple, list)):
        obs_batch = demonstration_data[0] if len(demonstration_data) >= 1 else None
        info_batch = demonstration_data[4] if len(demonstration_data) >= 5 else None
        if info_batch is None and len(demonstration_data) >= 2 and isinstance(demonstration_data[1], dict):
            # Fallback for (obs, info) shaped payloads
            info_batch = demonstration_data[1]

        frames = obs_batch.get("front_rgb_list") if isinstance(obs_batch, dict) else None

        goal = None
        if isinstance(info_batch, dict):
            # First key whose value is not None wins (a present-but-empty
            # string intentionally stops the fallback chain).
            for key in ("task_goal", "language goal", "language_goal"):
                goal = info_batch.get(key)
                if goal is not None:
                    break

        return _extract_first_text(goal, default_goal), _ensure_list(frames)

    return default_goal, []
328
+
329
def _find_best_semantic_match(user_query, options):
    """Pick the option whose 'label' is semantically closest to *user_query*.

    Returns (best_index, cosine_score). Yields (-1, 0.0) when the NLP model is
    unavailable or there are no options, and (0, 0.0) when encoding/scoring
    raises (i.e. default to the first option).
    """
    if _NLP_MODEL is None or not options:
        return -1, 0.0

    labels = [opt.get("label", "") for opt in options]
    query_text = str(user_query or "").strip()

    try:
        query_vec = _NLP_MODEL.encode(query_text, convert_to_tensor=True)
        label_vecs = _NLP_MODEL.encode(labels, convert_to_tensor=True)
        scores = st_util.cos_sim(query_vec, label_vecs)[0]
        winner = torch.argmax(scores).item()
        return winner, scores[winner].item()
    except Exception as exc:
        print(f" [NLP] Semantic match failed ({exc}); defaulting to option 1.")
        return 0, 0.0
350
+
351
+ # --- Core Logic Wrapper ---
352
+
353
+ class OracleSession:
354
    def __init__(self, dataset_root=DEFAULT_DATASET_ROOT, gui_render=False):
        """
        Initialize an empty session; call load_episode() to start an episode.

        dataset_root: optional override root for metadata json files
            (defaults to the ROBOMME_METADATA_ROOT env var, may be None).
        gui_render: If True, uses 'human' render mode (pops up window).
        For Gradio, we usually want False (rgb_array).
        """
        self.dataset_root = Path(dataset_root) if dataset_root else None
        self.gui_render = gui_render # Usually False for web app
        self.render_mode = "human" if gui_render else "rgb_array"

        # Episode objects/metadata — populated by load_episode().
        self.env = None
        self.planner = None
        self.color_map = None
        self.env_id = None
        self.episode_idx = None
        self.language_goal = ""
        self.difficulty = None
        self.seed = None
        self.history = [] # Logs interaction steps

        # State caches
        self.seg_vis = None          # colorized segmentation image (BGR)
        self.seg_raw = None          # raw 2-D segmentation id map
        self.base_frames = []
        self.wrist_frames = []
        self.demonstration_frames = []
        self.available_options = []  # (ui_label, index) tuples for the UI
        self.raw_solve_options = []  # option dicts as produced by get_vqa_options
        # Track frame indices for incremental video updates
        self.last_base_frame_idx = 0
        self.last_wrist_frame_idx = 0
        self.non_demonstration_task_length = None # read from DemonstrationWrapper
        # Track latest obs-batch object and consumed indices to avoid duplicate appends.
        self._last_obs_ref_id = None
        self._last_obs_front_consumed = 0
388
+
389
+ def _resolve_metadata_override_root(self):
390
+ if self.dataset_root:
391
+ return self.dataset_root
392
+ env_root = os.environ.get(ROBOMME_METADATA_ROOT_ENV)
393
+ if env_root:
394
+ return Path(env_root)
395
+ return None
396
+
397
+ def load_episode(self, env_id, episode_idx):
398
+ """Initialize environment for a specific episode."""
399
+ if self.env:
400
+ self.env.close()
401
+
402
+ try:
403
+ metadata_override_root = self._resolve_metadata_override_root()
404
+ builder = BenchmarkEnvBuilder(
405
+ env_id=env_id,
406
+ dataset="train",
407
+ # Gradio uses local oracle solve() directly (not env.step(command_dict)),
408
+ # so we must keep a low-level stepping wrapper chain.
409
+ # "multi_choice" inserts OraclePlannerDemonstrationWrapper, which expects
410
+ # dict commands and may swallow planner low-level action arrays.
411
+ action_space="joint_angle",
412
+ gui_render=self.gui_render,
413
+ #gui_render=True,
414
+ override_metadata_path=metadata_override_root,
415
+ max_steps=3000,
416
+ )
417
+
418
+ episode_num = builder.get_episode_num()
419
+ if episode_num <= 0:
420
+ if metadata_override_root:
421
+ expected = metadata_override_root / f"record_dataset_{env_id}_metadata.json"
422
+ return None, f"Dataset metadata not found or empty: {expected}"
423
+ return None, f"Dataset metadata not found or empty for env '{env_id}' in split 'test'"
424
+
425
+ if episode_idx < 0 or episode_idx >= episode_num:
426
+ return None, f"Episode index out of range for {env_id}: {episode_idx} (valid 0-{episode_num - 1})"
427
+
428
+ seed, difficulty = builder.resolve_episode(episode_idx)
429
+ self.env = builder.make_env_for_episode(episode_idx)
430
+ self.env.reset()
431
+ self.env_id = env_id
432
+ self.episode_idx = episode_idx
433
+ self.difficulty = difficulty
434
+ self.seed = seed
435
+
436
+ # Demonstration data
437
+ demonstration_data = getattr(self.env, "demonstration_data", None)
438
+ self.language_goal, self.demonstration_frames = _extract_demonstration_payload(demonstration_data)
439
+
440
+ # Setup Color Map
441
+ self.color_map = _generate_color_map()
442
+ _sync_table_color(self.env, self.color_map)
443
+
444
+ # Initialize Planner (using FailAware versions)
445
+ if env_id in ("PatternLock", "RouteStick"):
446
+ self.planner = FailAwarePandaStickMotionPlanningSolver(
447
+ self.env, debug=False, vis=self.gui_render,
448
+ base_pose=self.env.unwrapped.agent.robot.pose,
449
+ visualize_target_grasp_pose=False, print_env_info=False,
450
+ joint_vel_limits=0.3,
451
+ )
452
+ else:
453
+ self.planner = FailAwarePandaArmMotionPlanningSolver(
454
+ self.env, debug=False, vis=self.gui_render,
455
+ base_pose=self.env.unwrapped.agent.robot.pose,
456
+ visualize_target_grasp_pose=False, print_env_info=False,
457
+ )
458
+
459
+ self.env.unwrapped.evaluate() # Initial eval check
460
+
461
+ # 从 DemonstrationWrapper 读取 non_demonstration_task_length(如果存在)
462
+ self.non_demonstration_task_length = getattr(self.env, 'non_demonstration_task_length', None)
463
+
464
+ # Reset logs
465
+ self.history = []
466
+
467
+ # Reset frame indices
468
+ self.last_base_frame_idx = 0
469
+ self.last_wrist_frame_idx = 0
470
+ self.base_frames = []
471
+ self.wrist_frames = []
472
+ self._last_obs_ref_id = None
473
+ self._last_obs_front_consumed = 0
474
+
475
+ # Initial Observation
476
+ return self.update_observation()
477
+
478
+ except Exception as e:
479
+ import traceback
480
+ traceback.print_exc()
481
+ return None, f"Error initializing episode: {e}"
482
+
483
+ def update_observation(self, use_segmentation=True):
484
+ """Captures current state, updates segmentation, and generates options."""
485
+ if not self.env:
486
+ return None, "Environment not initialized"
487
+
488
+ # 1. Capture Frames (strict path: only front_rgb_list from wrapper obs batch)
489
+ front_frames, obs_ref_id = _extract_obs_front_frames(self.env)
490
+ self.wrist_frames = []
491
+ if front_frames is not None:
492
+ front_frames = front_frames or []
493
+ if obs_ref_id != self._last_obs_ref_id:
494
+ self._last_obs_ref_id = obs_ref_id
495
+ self._last_obs_front_consumed = 0
496
+ new_front = front_frames[self._last_obs_front_consumed:]
497
+ self._last_obs_front_consumed = len(front_frames)
498
+ if new_front:
499
+ self.base_frames.extend(_prepare_frame(frame) for frame in new_front if frame is not None)
500
+ else:
501
+ self.base_frames = []
502
+ self._last_obs_ref_id = None
503
+ self._last_obs_front_consumed = 0
504
+
505
+ seg_data = _fetch_segmentation(self.env)
506
+
507
+ # 2. Determine Resolution
508
+ seg_hw = (255, 255) # Default
509
+ if self.base_frames and len(self.base_frames) > 0:
510
+ seg_hw = self.base_frames[-1].shape[:2]
511
+ elif seg_data is not None:
512
+ # Try to guess from seg data
513
+ try:
514
+ temp = seg_data
515
+ if hasattr(temp, "cpu"): temp = temp.cpu().numpy()
516
+ temp = np.asarray(temp)
517
+ if temp.ndim > 2: temp = temp[0]
518
+ seg_hw = temp.shape[:2]
519
+ except: pass
520
+
521
+ # 3. Process Segmentation/Image
522
+ if use_segmentation:
523
+ self.seg_vis, self.seg_raw = _prepare_segmentation_visual(seg_data, self.color_map, seg_hw)
524
+ else:
525
+ # If not using segmentation view, use RGB but scale to match seg logic
526
+ seg_vis_from_seg, self.seg_raw = (
527
+ _prepare_segmentation_visual(seg_data, self.color_map, seg_hw)
528
+ if seg_data is not None
529
+ else (None, None)
530
+ )
531
+ if self.base_frames:
532
+ vis_frame = _prepare_frame(self.base_frames[-1])
533
+ vis_frame = cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR) # Keep consistent BGR internally
534
+ if vis_frame.shape[:2] != seg_hw:
535
+ vis_frame = cv2.resize(vis_frame, (seg_hw[1], seg_hw[0]), interpolation=cv2.INTER_LINEAR)
536
+ self.seg_vis = vis_frame
537
+ elif seg_vis_from_seg is not None:
538
+ # 没有 RGB 原始帧时,回退到 segmentation 可视化,避免首屏空白。
539
+ self.seg_vis = seg_vis_from_seg
540
+ else:
541
+ self.seg_vis = np.zeros((seg_hw[0], seg_hw[1], 3), dtype=np.uint8)
542
+
543
+ # 4. Generate Options
544
+ dummy_target = {"obj": None, "name": None, "seg_id": None, "click_point": None, "centroid_point": None}
545
+ self.raw_solve_options = _build_solve_options(self.env, self.planner, dummy_target, self.env_id)
546
+
547
+ # Format for UI
548
+ self.available_options = []
549
+ for i, opt in enumerate(self.raw_solve_options):
550
+ opt_label = str(opt.get("label", f"Option {i + 1}")).strip()
551
+ opt_action = str(opt.get("action", "")).strip()
552
+ if opt_label and opt_action:
553
+ ui_label = f"{opt_label}. {opt_action}"
554
+ else:
555
+ ui_label = opt_label or opt_action or f"Option {i + 1}"
556
+ self.available_options.append((ui_label, i)) # Tuple for Gradio Radio/Dropdown
557
+
558
+ return self.get_pil_image(), "Ready"
559
+
560
+ def get_pil_image(self, use_segmented=True):
561
+ """
562
+ 获取PIL图像
563
+
564
+ Args:
565
+ use_segmented: 如果为True,返回分割视图(seg_vis);如果为False,返回原图(base_frames)
566
+ """
567
+ if use_segmented:
568
+ # 返回分割视图
569
+ if self.seg_vis is None:
570
+ return Image.new('RGB', (255, 255), color='gray')
571
+ # Convert BGR (OpenCV) to RGB (PIL)
572
+ rgb = cv2.cvtColor(self.seg_vis, cv2.COLOR_BGR2RGB)
573
+ return Image.fromarray(rgb)
574
+ else:
575
+ # 返回原图
576
+ if not self.base_frames or len(self.base_frames) == 0:
577
+ return Image.new('RGB', (255, 255), color='gray')
578
+ # 获取最后一帧
579
+ frame = self.base_frames[-1]
580
+ # 准备帧(确保格式正确)
581
+ frame = _prepare_frame(frame)
582
+ # frame 已经是 RGB 格式,直接转换为 PIL Image
583
+ return Image.fromarray(frame)
584
+
585
+ def close(self):
586
+ if self.env:
587
+ self.env.close()
588
+
589
    def _get_front_camera_projection_params(self):
        """Best-effort lookup of front-camera projection parameters.

        Returns (intrinsic, extrinsic, image_shape) where intrinsic is a 3x3
        matrix, extrinsic a 3x4 matrix, and image_shape an (h, w) tuple; any
        of the three may be None when it cannot be resolved. image_shape falls
        back from obs RGB -> cached seg_raw -> last base frame.
        """
        if not self.env:
            return None, None, None

        intrinsic = None
        extrinsic = None
        image_shape = None

        # Fetch a fresh unflattened observation; tolerate any failure.
        try:
            obs = self.env.unwrapped.get_obs(unflattened=True)
        except Exception:
            obs = None

        if isinstance(obs, dict):
            # Camera matrices from sensor_param (flattened then reshaped to
            # 3x3 / 3x4); both are dropped together on any parse failure.
            try:
                cam_param = obs.get("sensor_param", {}).get("base_camera", {})
                intrinsic = np.asarray(cam_param.get("intrinsic_cv")).reshape(-1)[:9].reshape(3, 3)
                extrinsic = np.asarray(cam_param.get("extrinsic_cv")).reshape(-1)[:12].reshape(3, 4)
            except Exception:
                intrinsic = None
                extrinsic = None

            # Image shape from the RGB observation (4-D means batched NHWC).
            try:
                rgb = obs.get("sensor_data", {}).get("base_camera", {}).get("rgb")
                if rgb is not None and hasattr(rgb, "cpu"):
                    rgb = rgb.cpu().numpy()
                rgb = np.asarray(rgb)
                if rgb.ndim == 4:
                    image_shape = (int(rgb.shape[1]), int(rgb.shape[2]))
                elif rgb.ndim == 3:
                    image_shape = (int(rgb.shape[0]), int(rgb.shape[1]))
            except Exception:
                image_shape = None

        # Fallback 1: cached raw segmentation map.
        if image_shape is None and self.seg_raw is not None:
            try:
                seg = np.asarray(self.seg_raw)
                image_shape = (int(seg.shape[0]), int(seg.shape[1]))
            except Exception:
                image_shape = None

        # Fallback 2: last captured base frame.
        if image_shape is None and self.base_frames:
            frame = np.asarray(self.base_frames[-1])
            image_shape = (int(frame.shape[0]), int(frame.shape[1]))

        return intrinsic, extrinsic, image_shape
635
+
636
+ def get_reference_action(self):
637
+ if not self.env:
638
+ return {
639
+ "ok": False,
640
+ "option_idx": None,
641
+ "option_label": "",
642
+ "option_action": "",
643
+ "need_coords": False,
644
+ "coords_xy": None,
645
+ "message": "Environment not initialized.",
646
+ }
647
+
648
+ target_action_text = getattr(self.env.unwrapped, "current_choice_label", "")
649
+ if not isinstance(target_action_text, str) or not target_action_text.strip():
650
+ return {
651
+ "ok": False,
652
+ "option_idx": None,
653
+ "option_label": "",
654
+ "option_action": "",
655
+ "need_coords": False,
656
+ "coords_xy": None,
657
+ "message": "Current step has no ground truth action text.",
658
+ }
659
+
660
+ selected_target = {
661
+ "obj": None,
662
+ "name": None,
663
+ "seg_id": None,
664
+ "click_point": None,
665
+ "centroid_point": None,
666
+ }
667
+ try:
668
+ current_options = _build_solve_options(self.env, self.planner, selected_target, self.env_id)
669
+ except Exception as exc:
670
+ return {
671
+ "ok": False,
672
+ "option_idx": None,
673
+ "option_label": "",
674
+ "option_action": "",
675
+ "need_coords": False,
676
+ "coords_xy": None,
677
+ "message": f"Failed to build options: {exc}",
678
+ }
679
+
680
+ if not current_options:
681
+ return {
682
+ "ok": False,
683
+ "option_idx": None,
684
+ "option_label": "",
685
+ "option_action": "",
686
+ "need_coords": False,
687
+ "coords_xy": None,
688
+ "message": "No available options for current step.",
689
+ }
690
+
691
+ matched_label = map_action_text_to_option_label(target_action_text, current_options)
692
+ if matched_label is None:
693
+ return {
694
+ "ok": False,
695
+ "option_idx": None,
696
+ "option_label": "",
697
+ "option_action": "",
698
+ "need_coords": False,
699
+ "coords_xy": None,
700
+ "message": f"Cannot map ground truth action '{target_action_text}' to option label.",
701
+ }
702
+
703
+ option_idx = find_exact_label_option_index(matched_label, current_options)
704
+ if option_idx < 0:
705
+ return {
706
+ "ok": False,
707
+ "option_idx": None,
708
+ "option_label": "",
709
+ "option_action": "",
710
+ "need_coords": False,
711
+ "coords_xy": None,
712
+ "message": f"Mapped label '{matched_label}' not found in current options.",
713
+ }
714
+
715
+ option = current_options[option_idx]
716
+ option_label = str(option.get("label", "")).strip()
717
+ option_action = str(option.get("action", "")).strip()
718
+ need_coords = bool(option.get("available"))
719
+
720
+ if not need_coords:
721
+ return {
722
+ "ok": True,
723
+ "option_idx": int(option_idx),
724
+ "option_label": option_label,
725
+ "option_action": option_action,
726
+ "need_coords": False,
727
+ "coords_xy": None,
728
+ "message": "Ground truth action resolved.",
729
+ }
730
+
731
+ reference_position = _extract_choice_segment_position_xyz(
732
+ getattr(self.env.unwrapped, "current_segment", None)
733
+ )
734
+ if reference_position is None:
735
+ return {
736
+ "ok": False,
737
+ "option_idx": int(option_idx),
738
+ "option_label": option_label,
739
+ "option_action": option_action,
740
+ "need_coords": True,
741
+ "coords_xy": None,
742
+ "message": "Cannot resolve reference target position from current segment.",
743
+ }
744
+
745
+ best_candidate = select_target_with_position(option.get("available"), reference_position)
746
+ if best_candidate is None or best_candidate.get("obj") is None:
747
+ return {
748
+ "ok": False,
749
+ "option_idx": int(option_idx),
750
+ "option_label": option_label,
751
+ "option_action": option_action,
752
+ "need_coords": True,
753
+ "coords_xy": None,
754
+ "message": "Cannot match reference target to available candidates.",
755
+ }
756
+
757
+ actor = best_candidate.get("obj")
758
+ segmentation_id_map = getattr(self.env.unwrapped, "segmentation_id_map", {}) or {}
759
+ seg_id = _find_actor_segmentation_id(segmentation_id_map, actor)
760
+ coords_xy = None
761
+ if seg_id is not None:
762
+ coords_xy = _compute_segmentation_centroid_xy(self.seg_raw, seg_id)
763
+
764
+ if coords_xy is None:
765
+ world_xyz = best_candidate.get("position")
766
+ if world_xyz is None:
767
+ world_xyz = extract_actor_position_xyz(actor)
768
+ intrinsic, extrinsic, image_shape = self._get_front_camera_projection_params()
769
+ if world_xyz is not None and intrinsic is not None and extrinsic is not None and image_shape is not None:
770
+ coords_xy = project_world_to_pixel(
771
+ world_xyz=world_xyz,
772
+ intrinsic_cv=intrinsic,
773
+ extrinsic_cv=extrinsic,
774
+ image_shape=image_shape,
775
+ )
776
+
777
+ if coords_xy is None:
778
+ return {
779
+ "ok": False,
780
+ "option_idx": int(option_idx),
781
+ "option_label": option_label,
782
+ "option_action": option_action,
783
+ "need_coords": True,
784
+ "coords_xy": None,
785
+ "message": "Failed to compute pixel coordinates for reference target.",
786
+ }
787
+
788
+ coords_xy = [int(coords_xy[0]), int(coords_xy[1])]
789
+ return {
790
+ "ok": True,
791
+ "option_idx": int(option_idx),
792
+ "option_label": option_label,
793
+ "option_action": option_action,
794
+ "need_coords": True,
795
+ "coords_xy": coords_xy,
796
+ "message": f"Ground truth action resolved at ({coords_xy[0]}, {coords_xy[1]}).",
797
+ }
798
+
799
+ def execute_action(self, action_idx, click_coords):
800
+
801
+ # 用户点击EXECUTE
802
+ # ↓
803
+ # execute_step() 调用 session.execute_action()
804
+ # ↓
805
+ # execute_action() 执行 solve()
806
+ # ↓ (在solve()执行过程中,step()可能检测到失败)
807
+ # ↓
808
+ # evaluate(solve_complete_eval=True) 被调用
809
+ # ↓
810
+ # BinFill.evaluate() 检查失败状态
811
+ # - 保存 previous_failure
812
+ # - 调用 sequential_task_check
813
+ # - 如果 previous_failure=True 或 task_failed=True,设置 failureflag=True
814
+ # ↓
815
+ # oracle_logic.py 获取 evaluation 结果
816
+ # - 如果 is_fail=False,额外检查 failureflag 和 current_task_failure
817
+ # - 设置 done = is_success or is_fail
818
+ # ↓
819
+ # execute_step() 检查 done
820
+ # - 如果 done=True,调用 complete_current_task()
821
+ # ↓
822
+ # complete_current_task() 更新任务索引
823
+ # - current_idx: 0 -> 1 (episode: 0 -> 1)
824
+
825
+
826
+ """
827
+ The real step logic.
828
+ """
829
+ if not self.env: return None, "No Env", False
830
+
831
+ # 1. Re-create options with a persistent target dict that we can modify
832
+ target_ref = {"obj": None, "name": None, "seg_id": None, "click_point": None, "centroid_point": None}
833
+ current_options = _build_solve_options(self.env, self.planner, target_ref, self.env_id)
834
+
835
+ if action_idx < 0 or action_idx >= len(current_options):
836
+ return self.get_pil_image(), "Invalid Action Index", False
837
+
838
+ chosen_opt = current_options[action_idx]
839
+
840
+ # 2. Resolve Target (Click -> Object)
841
+ if click_coords:
842
+ # Reuse logic from step() above, applying to target_ref
843
+ cx, cy = click_coords
844
+ h, w = self.seg_raw.shape[:2]
845
+ cx = max(0, min(cx, w-1))
846
+ cy = max(0, min(cy, h-1))
847
+
848
+ seg_id_map = getattr(self.env.unwrapped, "segmentation_id_map", {}) or {}
849
+
850
+ candidates = []
851
+ def _collect(item):
852
+ if isinstance(item, (list, tuple)):
853
+ for x in item: _collect(x)
854
+ elif isinstance(item, dict):
855
+ for x in item.values(): _collect(x)
856
+ else:
857
+ if item: candidates.append(item)
858
+
859
+ avail = chosen_opt.get("available")
860
+ if avail:
861
+ _collect(avail)
862
+ best_cand = None
863
+ min_dist = float('inf')
864
+ for actor in candidates:
865
+ target_ids = [sid for sid, obj in seg_id_map.items() if obj is actor]
866
+ for tid in target_ids:
867
+ tid = int(tid)
868
+ mask = (self.seg_raw == tid)
869
+ if np.any(mask):
870
+ ys, xs = np.nonzero(mask)
871
+ center_x, center_y = xs.mean(), ys.mean()
872
+ dist = (center_x - cx)**2 + (center_y - cy)**2
873
+ if dist < min_dist:
874
+ min_dist = dist
875
+ best_cand = {
876
+ "obj": actor,
877
+ "name": getattr(actor, "name", f"id_{tid}"),
878
+ "seg_id": tid,
879
+ "click_point": (int(cx), int(cy)),
880
+ "centroid_point": (int(center_x), int(center_y))
881
+ }
882
+ if best_cand:
883
+ target_ref.update(best_cand)
884
+ else:
885
+ target_ref["click_point"] = (int(cx), int(cy))
886
+ else:
887
+ target_ref["click_point"] = (int(cx), int(cy))
888
+
889
+ # 3. Execute Solve
890
+ # 异常处理流程:
891
+ # 任何异常发生 (ScrewPlanFailure 或其他异常)
892
+ # ↓
893
+ # oracle_logic.py: 重新抛出异常
894
+ # ↓
895
+ # process_session.py: 捕获并传递到主进程
896
+ # ↓
897
+ # gradio_callbacks.py: 捕获并显示弹窗 (gr.Info)
898
+ status_msg = f"Executing: {chosen_opt.get('label')}"
899
+ before_elapsed_steps = getattr(self.env.unwrapped, "elapsed_steps", None)
900
+ # Collect intermediate front-camera frames during solve() so livestream
901
+ # can show the full execution process instead of only the final frame.
902
+ original_step = self.env.step
903
+ captured_front_frames = []
904
+ stream_frame_callback = getattr(self, "stream_frame_callback", None)
905
+ self._execute_streamed_frame_count = 0
906
+
907
+ def _step_with_capture(action):
908
+ step_output = original_step(action)
909
+ step_front_frames = _collect_front_frames_from_step_output(step_output)
910
+ if step_front_frames:
911
+ prepared_frames = [
912
+ _prepare_frame(frame) for frame in step_front_frames if frame is not None
913
+ ]
914
+ if prepared_frames:
915
+ captured_front_frames.extend(prepared_frames)
916
+ if callable(stream_frame_callback):
917
+ try:
918
+ stream_frame_callback(prepared_frames)
919
+ self._execute_streamed_frame_count += len(prepared_frames)
920
+ except Exception:
921
+ # Keep solve path robust even if streaming callback fails.
922
+ pass
923
+ return step_output
924
+
925
+ self.env.step = _step_with_capture
926
+ try:
927
+ chosen_opt.get("solve")()
928
+ except ScrewPlanFailure as e:
929
+ # Re-raise ScrewPlanFailure so it can be handled in process_session and displayed as popup
930
+ print(f"Screw Plan Failure")
931
+ raise
932
+ except Exception as e:
933
+ # Re-raise all other exceptions so they can be displayed as popup too
934
+ print(f"Execution Error")
935
+ raise
936
+ finally:
937
+ self.env.step = original_step
938
+
939
+ if captured_front_frames:
940
+ self.base_frames.extend(captured_front_frames)
941
+ print(f"[execute_action] captured_front_frames={len(captured_front_frames)}")
942
+ after_elapsed_steps = getattr(self.env.unwrapped, "elapsed_steps", None)
943
+ print(
944
+ "[execute_action] elapsed_steps: "
945
+ f"{before_elapsed_steps} -> {after_elapsed_steps}"
946
+ )
947
+
948
+ # 4. Evaluate
949
+ self.env.unwrapped.evaluate()
950
+ evaluation = self.env.unwrapped.evaluate(solve_complete_eval=True)
951
+
952
+ is_success = _tensor_to_bool(evaluation.get("success", False))
953
+ is_fail = _tensor_to_bool(evaluation.get("fail", False))
954
+
955
+ # 如果evaluate()没有检测到失败,但环境已经设置了failureflag,则使用failureflag
956
+ # 这是因为失败可能在solve()执行过程中的step()里被检测到,但evaluate()可能还没有反映
957
+ failureflag = getattr(self.env.unwrapped, "failureflag", None)
958
+ current_task_failure = getattr(self.env.unwrapped, "current_task_failure", False)
959
+
960
+ if not is_fail:
961
+ if failureflag is not None:
962
+ failureflag_bool = _tensor_to_bool(failureflag)
963
+ if failureflag_bool:
964
+ is_fail = True
965
+ elif current_task_failure:
966
+ is_fail = True
967
+
968
+ if is_success: status_msg += " | SUCCESS"
969
+ if is_fail: status_msg += " | FAILED"
970
+
971
+ # 5. Update State for next step
972
+ img, _ = self.update_observation()
973
+
974
+ done = is_success or is_fail
975
+ return img, status_msg, done
gradio-web/process_session.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 多进程会话管理模块
3
+
4
+ 本模块实现了多进程架构,将每个用户的 OracleSession 运行在独立的工作进程中。
5
+ 这样可以确保重计算任务不会阻塞主进程,多个用户可以并发使用系统。
6
+
7
+ 架构说明:
8
+ 1. ProcessSessionProxy: 主进程中的代理类,提供与 OracleSession 相同的接口
9
+ 2. session_worker_loop: 工作进程中的循环函数,运行实际的 OracleSession
10
+ 3. 进程间通信:通过 multiprocessing.Queue 进行命令和结果的传递
11
+ 4. 视频帧同步:工作进程产生的新帧通过 stream_queue 推送到主进程,由后台线程同步到代理的本地缓存
12
+ """
13
+ import multiprocessing
14
+ import queue
15
+ import threading
16
+ import time
17
+ import traceback
18
+ import numpy as np
19
+ import sys
20
+ import os
21
+
22
+ # 添加父目录到路径(逻辑复制自 oracle_logic.py)
23
+ current_dir = os.path.dirname(os.path.abspath(__file__))
24
+ parent_dir = os.path.dirname(current_dir)
25
+ if parent_dir not in sys.path:
26
+ sys.path.insert(0, parent_dir)
27
+
28
+ from oracle_logic import OracleSession, DEFAULT_DATASET_ROOT
29
+
30
+ # Import ScrewPlanFailure for exception handling
31
+ try:
32
+ from robomme.robomme_env.utils.planner_fail_safe import ScrewPlanFailure
33
+ except ImportError:
34
+ # Fallback if import fails
35
+ ScrewPlanFailure = RuntimeError
36
+
37
+ # Custom exception for screw plan failures (to be caught in gradio_callbacks)
38
+ class ScrewPlanFailureError(RuntimeError):
39
+ """Exception raised when screw plan fails, to be caught and displayed via gr.Info popup"""
40
+ pass
41
+
42
+ # 定义命令常量
43
+ CMD_LOAD_EPISODE = "load_episode"
44
+ CMD_UPDATE_OBSERVATION = "update_observation"
45
+ CMD_GET_PIL_IMAGE = "get_pil_image"
46
+ CMD_EXECUTE_ACTION = "execute_action"
47
+ CMD_GET_REFERENCE_ACTION = "get_reference_action"
48
+ CMD_CLOSE = "close"
49
+
50
+ def _sanitize_options(options):
51
+ """
52
+ 清理选项数据,移除不可序列化的项(如 'solve' 函数)
53
+
54
+ 在跨进程通信时,需要确保所有数据都可以被 pickle 序列化。
55
+ raw_solve_options 中包含的 'solve' 函数无法序列化,需要移除。
56
+ 'available' 字段可能是复杂对象,需要转换为简单的布尔值。
57
+
58
+ Args:
59
+ options: 原始选项列表
60
+
61
+ Returns:
62
+ list: 清理后的选项列表
63
+ """
64
+ clean_opts = []
65
+ if not options:
66
+ return clean_opts
67
+ for opt in options:
68
+ clean_opt = opt.copy()
69
+ if "solve" in clean_opt:
70
+ del clean_opt["solve"]
71
+ if "available" in clean_opt:
72
+ # Only keep truthiness for UI logic
73
+ clean_opt["available"] = bool(clean_opt["available"])
74
+ clean_opts.append(clean_opt)
75
+ return clean_opts
76
+
77
+ def session_worker_loop(cmd_queue, result_queue, stream_queue, dataset_root, gui_render):
78
+ """
79
+ 工作进程主循环
80
+
81
+ 此函数在工作进程中运行,负责:
82
+ 1. 初始化 OracleSession 实例
83
+ 2. 监听来自主进程的命令(通过 cmd_queue)
84
+ 3. 执行命令并返回结果(通过 result_queue)
85
+ 4. 监控视频帧变化,将新帧推送到流队列(通过 stream_queue)
86
+ 5. 处理异常和清理资源
87
+
88
+ Args:
89
+ cmd_queue: 命令队列,主进程发送命令到此队列
90
+ result_queue: 结果队列,工作进程返回命令执行结果到此队列
91
+ stream_queue: 流队列,工作进程推送新视频帧到此队列
92
+ dataset_root: 数据集根目录路径
93
+ gui_render: 是否使用GUI渲染模式
94
+ """
95
+ session = None
96
+ try:
97
+ session = OracleSession(dataset_root=dataset_root, gui_render=gui_render)
98
+ session.stream_frame_callback = lambda frames: stream_queue.put({"base": frames, "wrist": []})
99
+
100
+ while True:
101
+ try:
102
+ # Check for commands
103
+ cmd_data = cmd_queue.get(timeout=0.1)
104
+ except queue.Empty:
105
+ continue
106
+
107
+ cmd = cmd_data["cmd"]
108
+ args = cmd_data.get("args", [])
109
+ kwargs = cmd_data.get("kwargs", {})
110
+
111
+ if cmd == CMD_CLOSE:
112
+ if session:
113
+ session.close()
114
+ break
115
+
116
+ elif cmd == CMD_LOAD_EPISODE:
117
+ # 加载环境episode
118
+ res = session.load_episode(*args, **kwargs)
119
+
120
+ # 更新帧索引跟踪(用于增量同步)
121
+ session.last_base_frame_idx = len(session.base_frames)
122
+ session.last_wrist_frame_idx = len(session.wrist_frames)
123
+
124
+ # 获取演示状态(从 DemonstrationWrapper 获取)
125
+ is_demonstration = False
126
+ if session.env:
127
+ is_demonstration = getattr(session.env, 'current_task_demonstration', False)
128
+
129
+ # 构建状态更新(完整同步,因为这是加载操作)
130
+ state_update = {
131
+ "env_id": session.env_id,
132
+ "episode_idx": session.episode_idx,
133
+ "language_goal": session.language_goal,
134
+ "difficulty": session.difficulty,
135
+ "seed": session.seed,
136
+ "demonstration_frames": session.demonstration_frames,
137
+ "base_frames": session.base_frames, # 加载时完整同步
138
+ "wrist_frames": session.wrist_frames, # 加载时完整同步
139
+ "available_options": session.available_options,
140
+ "raw_solve_options": _sanitize_options(session.raw_solve_options),
141
+ "seg_vis": session.seg_vis,
142
+ "is_demonstration": is_demonstration,
143
+ "non_demonstration_task_length": session.non_demonstration_task_length # 同步非demonstration任务长度
144
+ }
145
+ result_queue.put({"status": "success", "result": res, "state": state_update})
146
+
147
+ elif cmd == CMD_EXECUTE_ACTION:
148
+ # 执行动作(重计算任务)
149
+ try:
150
+ res = session.execute_action(*args, **kwargs)
151
+ except ScrewPlanFailure as e:
152
+ # 捕获 ScrewPlanFailure 并作为特殊状态传递到主进程,用于显示弹窗
153
+ result_queue.put({"status": "screw_plan_failure", "message": str(e)})
154
+ continue
155
+ except Exception as e:
156
+ # 捕获所有其他异常并传递到主进程,用于显示弹窗
157
+ result_queue.put({"status": "execution_error", "message": str(e)})
158
+ continue
159
+
160
+ # 增量帧同步:只发送新增的帧
161
+ new_base = session.base_frames[session.last_base_frame_idx:]
162
+ new_wrist = session.wrist_frames[session.last_wrist_frame_idx:]
163
+ streamed_count = int(getattr(session, "_execute_streamed_frame_count", 0) or 0)
164
+ # Frames already pushed by stream_frame_callback during solve() should not be sent twice.
165
+ if streamed_count > 0 and new_base:
166
+ if streamed_count >= len(new_base):
167
+ new_base = []
168
+ else:
169
+ new_base = new_base[streamed_count:]
170
+
171
+ # 更新帧索引
172
+ session.last_base_frame_idx = len(session.base_frames)
173
+ session.last_wrist_frame_idx = len(session.wrist_frames)
174
+
175
+ # 如果有新帧,推送到流队列
176
+ if new_base or new_wrist:
177
+ stream_queue.put({"base": new_base, "wrist": new_wrist})
178
+
179
+ # 获取演示状态(从 DemonstrationWrapper 获取)
180
+ is_demonstration = False
181
+ if session.env:
182
+ is_demonstration = getattr(session.env, 'current_task_demonstration', False)
183
+
184
+ # 构建状态更新(只更新选项和分割视图,帧通过流队列同步)
185
+ state_update = {
186
+ "available_options": session.available_options,
187
+ "raw_solve_options": _sanitize_options(session.raw_solve_options),
188
+ "seg_vis": session.seg_vis,
189
+ "is_demonstration": is_demonstration
190
+ }
191
+ result_queue.put({"status": "success", "result": res, "state": state_update})
192
+
193
+ elif cmd == CMD_GET_PIL_IMAGE:
194
+ res = session.get_pil_image(*args, **kwargs)
195
+ result_queue.put({"status": "success", "result": res})
196
+
197
+ elif cmd == CMD_UPDATE_OBSERVATION:
198
+ # 更新观察(获取当前环境状态)
199
+ res = session.update_observation(*args, **kwargs)
200
+
201
+ # 增量帧同步
202
+ new_base = session.base_frames[session.last_base_frame_idx:]
203
+ new_wrist = session.wrist_frames[session.last_wrist_frame_idx:]
204
+
205
+ # 更新帧索引
206
+ session.last_base_frame_idx = len(session.base_frames)
207
+ session.last_wrist_frame_idx = len(session.wrist_frames)
208
+
209
+ # 如果有新帧,推送到流队列
210
+ if new_base or new_wrist:
211
+ stream_queue.put({"base": new_base, "wrist": new_wrist})
212
+
213
+ # 获取演示状态(从 DemonstrationWrapper 获取)
214
+ is_demonstration = False
215
+ if session.env:
216
+ is_demonstration = getattr(session.env, 'current_task_demonstration', False)
217
+
218
+ # 构建状态更新
219
+ state_update = {
220
+ "available_options": session.available_options,
221
+ "raw_solve_options": _sanitize_options(session.raw_solve_options),
222
+ "seg_vis": session.seg_vis,
223
+ "is_demonstration": is_demonstration
224
+ }
225
+ result_queue.put({"status": "success", "result": res, "state": state_update})
226
+
227
+ elif cmd == CMD_GET_REFERENCE_ACTION:
228
+ res = session.get_reference_action(*args, **kwargs)
229
+ result_queue.put({"status": "success", "result": res})
230
+
231
+ else:
232
+ result_queue.put({"status": "error", "message": f"Unknown command: {cmd}"})
233
+
234
+ except Exception as e:
235
+ traceback.print_exc()
236
+ result_queue.put({"status": "fatal", "message": str(e)})
237
+
238
+
239
+ class ProcessSessionProxy:
240
+ """
241
+ 进程会话代理类
242
+
243
+ 此类在主进程中运行,提供与 OracleSession 相同的接口。
244
+ 所有方法调用都会被转发到工作进程中的实际 OracleSession 实例。
245
+
246
+ 主要功能:
247
+ 1. 启动和管理工作进程
248
+ 2. 通过队列与工作进程通信
249
+ 3. 维护本地状态缓存(从工作进程同步)
250
+ 4. 后台线程实时同步视频帧
251
+ """
252
+
253
+ def __init__(self, dataset_root=DEFAULT_DATASET_ROOT, gui_render=False):
254
+ """
255
+ 初始化代理对象
256
+
257
+ Args:
258
+ dataset_root: 数据集根目录路径
259
+ gui_render: 是否使用GUI渲染模式
260
+ """
261
+ # 使用 spawn 上下文以获得更清晰的进程隔离
262
+ ctx = multiprocessing.get_context("spawn")
263
+
264
+ # 创建进程间通信队列
265
+ self.cmd_queue = ctx.Queue() # 命令队列:主进程 -> 工作进程
266
+ self.result_queue = ctx.Queue() # 结果队列:工作进程 -> 主进程
267
+ self.stream_queue = ctx.Queue() # 流队列:工作进程 -> 主进程(视频帧)
268
+
269
+ # 启动工作进程
270
+ self.process = ctx.Process(
271
+ target=session_worker_loop,
272
+ args=(self.cmd_queue, self.result_queue, self.stream_queue, dataset_root, gui_render),
273
+ daemon=True
274
+ )
275
+ self.process.start()
276
+
277
+ # 本地状态缓存(从工作进程同步)
278
+ self.env_id = None
279
+ self.episode_idx = None
280
+ self.language_goal = ""
281
+ self.difficulty = None
282
+ self.seed = None
283
+ self.demonstration_frames = []
284
+ self.base_frames = [] # 由后台同步线程持续更新
285
+ self.wrist_frames = [] # 由后台同步线程持续更新
286
+ self.available_options = []
287
+ self.raw_solve_options = []
288
+ self.seg_vis = None
289
+ self.is_demonstration = False # 演示模式标志
290
+ self.non_demonstration_task_length = None # 从工作进程同步的非demonstration任务长度
291
+
292
+ # 帧同步线程:从流队列接收新帧并更新本地缓存
293
+ self.stop_sync = False
294
+ self.sync_thread = threading.Thread(target=self._sync_loop, daemon=True)
295
+ self.sync_thread.start()
296
+
297
+ def _sync_loop(self):
298
+ """
299
+ 后台线程循环:从流队列消费视频帧并更新本地缓存
300
+
301
+ 此线程持续运行,实时接收工作进程推送的新视频帧,
302
+ 并将其追加到本地的 base_frames 和 wrist_frames 列表中。
303
+ UI 刷新逻辑会直接从代理的本地缓存读取帧数据。
304
+ """
305
+ while not self.stop_sync:
306
+ try:
307
+ # Use a short timeout to check stop_sync frequently
308
+ frames = self.stream_queue.get(timeout=0.1)
309
+ new_base = frames.get("base", [])
310
+ new_wrist = frames.get("wrist", [])
311
+
312
+ # Append to local lists
313
+ if new_base:
314
+ self.base_frames.extend(new_base)
315
+ if new_wrist:
316
+ self.wrist_frames.extend(new_wrist)
317
+ except queue.Empty:
318
+ continue
319
+ except Exception:
320
+ break
321
+
322
+ def _send_cmd(self, cmd, *args, **kwargs):
323
+ """
324
+ 发送命令到工作进程并等待结果
325
+
326
+ Args:
327
+ cmd: 命令名称
328
+ *args: 位置参数
329
+ **kwargs: 关键字参数
330
+
331
+ Returns:
332
+ 命令执行结果
333
+
334
+ Raises:
335
+ RuntimeError: 工作进程返回错误或致命错误
336
+ TimeoutError: 工作进程超时(600秒)
337
+ """
338
+ # 发送命令到工作进程
339
+ self.cmd_queue.put({"cmd": cmd, "args": args, "kwargs": kwargs})
340
+ try:
341
+ # 等待结果(重任务如加载/执行可能需要较长时间,设置600秒超时)
342
+ res = self.result_queue.get(timeout=600)
343
+
344
+ # 检查错误状态并转换为异常,以便在 gradio_callbacks 中捕获并显示弹窗
345
+ if res.get("status") == "screw_plan_failure":
346
+ raise ScrewPlanFailureError(f"screw plan failed: {res.get('message', 'unknown error')}")
347
+ if res.get("status") == "execution_error":
348
+ raise RuntimeError(f"Execution error: {res.get('message', 'unknown error')}")
349
+ if res.get("status") == "fatal":
350
+ raise RuntimeError(f"工作进程致命错误: {res.get('message')}")
351
+ if res.get("status") == "error":
352
+ raise RuntimeError(f"命令执行错误: {res.get('message')}")
353
+
354
+ # 更新本地状态缓存(如果工作进程返回了状态更新)
355
+ if "state" in res:
356
+ state = res["state"]
357
+ for k, v in state.items():
358
+ if k in ["base_frames", "wrist_frames"]:
359
+ # 对于帧数据:只有在显式发送时才替换(如加载时)
360
+ # 否则由同步循环处理增量更新
361
+ if v is not None:
362
+ setattr(self, k, v)
363
+ else:
364
+ # 其他状态直接更新
365
+ setattr(self, k, v)
366
+
367
+ return res.get("result")
368
+ except queue.Empty:
369
+ raise TimeoutError("工作进程超时")
370
+
371
+ def load_episode(self, env_id, episode_idx):
372
+ """
373
+ 加载环境episode(在工作进程中执行)
374
+
375
+ Args:
376
+ env_id: 环境ID
377
+ episode_idx: episode索引
378
+
379
+ Returns:
380
+ tuple: (PIL.Image, str) 图像和状态消息
381
+ """
382
+ return self._send_cmd(CMD_LOAD_EPISODE, env_id, episode_idx)
383
+
384
+ def execute_action(self, action_idx, click_coords):
385
+ """
386
+ 执行动作(在工作进程中执行,重计算任务)
387
+
388
+ Args:
389
+ action_idx: 动作索引
390
+ click_coords: 点击坐标 (x, y) 或 None
391
+
392
+ Returns:
393
+ tuple: (PIL.Image, str, bool) 图像、状态消息、是否完成
394
+ """
395
+ return self._send_cmd(CMD_EXECUTE_ACTION, action_idx, click_coords)
396
+
397
+ def get_pil_image(self, use_segmented=True):
398
+ """
399
+ 获取PIL图像(在工作进程中执行)
400
+
401
+ Args:
402
+ use_segmented: 是否使用分割视图
403
+
404
+ Returns:
405
+ PIL.Image: 图像对象
406
+ """
407
+ return self._send_cmd(CMD_GET_PIL_IMAGE, use_segmented=use_segmented)
408
+
409
+ def update_observation(self, use_segmentation=True):
410
+ """
411
+ 更新观察(在工作进程中执行)
412
+
413
+ Args:
414
+ use_segmentation: 是否使用分割视图
415
+
416
+ Returns:
417
+ tuple: (PIL.Image, str) 图像和状态消息
418
+ """
419
+ return self._send_cmd(CMD_UPDATE_OBSERVATION, use_segmentation=use_segmentation)
420
+
421
+ def get_reference_action(self):
422
+ """
423
+ 获取当前步参考动作与坐标(在工作进程中执行)
424
+
425
+ Returns:
426
+ dict: 参考动作结果
427
+ """
428
+ return self._send_cmd(CMD_GET_REFERENCE_ACTION)
429
+
430
+ def close(self):
431
+ """
432
+ 关闭代理并清理资源
433
+
434
+ 此方法会:
435
+ 1. 停止帧同步线程
436
+ 2. 发送关闭命令到工作进程
437
+ 3. 等待工作进程优雅退出(最多1秒)
438
+ 4. 如果进程仍在运行,强制终止
439
+ """
440
+ self.stop_sync = True
441
+ try:
442
+ self.cmd_queue.put({"cmd": CMD_CLOSE})
443
+ except:
444
+ pass
445
+ # 等待工作进程优雅退出
446
+ self.process.join(timeout=1)
447
+ if self.process.is_alive():
448
+ self.process.terminate()
gradio-web/scripts/run_background.sh ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # 后台运行脚本 - 统一管理 HistoryBench Gradio 服务器
3
+ # 使用方法: bash run_background.sh [start|stop|status|restart|logs]
4
+
5
+ # 获取脚本所在目录,然后定位到 gradio 目录
6
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
7
+ GRADIO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
8
+ cd "$GRADIO_DIR"
9
+
10
+ # Micromamba 环境路径
11
+ MICROMAMBA_ENV="/data/hongzefu/maniskillenv1120"
12
+
13
+ # 日志目录(放在 gradio 目录中)
14
+ LOG_DIR="$GRADIO_DIR/logs"
15
+ PID_FILE="$GRADIO_DIR/server.pid"
16
+ # 合并日志文件(包含所有输出:标准输出 + 错误输出)
17
+ LOG_FILE="$LOG_DIR/server.log"
18
+
19
+ # 创建日志目录
20
+ mkdir -p "$LOG_DIR"
21
+
22
+ # 函数:启动服务器
23
+ start_server() {
24
+ # 检查是否已经在运行
25
+ if [ -f "$PID_FILE" ]; then
26
+ OLD_PID=$(cat "$PID_FILE")
27
+ if ps -p "$OLD_PID" > /dev/null 2>&1; then
28
+ echo "⚠️ 服务器已经在运行中 (PID: $OLD_PID)"
29
+ echo " 如需重启,请使用: bash $0 restart"
30
+ return 1
31
+ else
32
+ echo "清理旧的 PID 文件..."
33
+ rm -f "$PID_FILE"
34
+ fi
35
+ fi
36
+
37
+ # 检查 micromamba 环境是否存在
38
+ if [ ! -d "$MICROMAMBA_ENV" ]; then
39
+ echo "❌ 错误: Micromamba 环境不存在: $MICROMAMBA_ENV"
40
+ return 1
41
+ fi
42
+
43
+ # 检查 Python 可执行文件
44
+ PYTHON_EXE="$MICROMAMBA_ENV/bin/python"
45
+ if [ ! -f "$PYTHON_EXE" ]; then
46
+ echo "❌ 错误: Python 可执行文件不存在: $PYTHON_EXE"
47
+ return 1
48
+ fi
49
+
50
+ # 启动服务器
51
+ echo "🚀 正在后台启动服务器..."
52
+ echo " Micromamba 环境: $MICROMAMBA_ENV"
53
+ echo " Python 可执行文件: $PYTHON_EXE"
54
+ echo " 工作目录: $GRADIO_DIR"
55
+ echo " 日志文件: $LOG_FILE"
56
+ echo ""
57
+
58
+ # 使用环境中的 Python 直接运行服务器
59
+ # 使用 nohup 在后台运行,并将所有输出重定向到日志文件
60
+ # 设置环境变量以确保使用环境中的包和正确的输出行为
61
+ # 使用 unbuffered 模式 (-u) 和 PYTHONUNBUFFERED=1 确保输出立即写入,不缓冲
62
+ # 使用 stdbuf -oL -eL 确保行缓冲输出(如果可用)
63
+ # 将标准输出和错误输出合并到一个文件 (2>&1),这样所有日志都会完整显示
64
+ # 使用 >> 追加模式,确保日志不会覆盖
65
+
66
+ # 检查是否可以使用 stdbuf(Linux 系统通常有)
67
+ if command -v stdbuf >/dev/null 2>&1; then
68
+ # 使用 stdbuf 确保行缓冲输出,所有 print 和日志都会实时写入
69
+ nohup env PATH="$MICROMAMBA_ENV/bin:$PATH" \
70
+ PYTHONUNBUFFERED=1 \
71
+ PYTHONIOENCODING=utf-8 \
72
+ stdbuf -oL -eL "$PYTHON_EXE" -u "$GRADIO_DIR/main.py" >> "$LOG_FILE" 2>&1 &
73
+ else
74
+ # 如果没有 stdbuf,使用 Python 的 unbuffered 模式
75
+ # 仍然设置所有必要的环境变量确保输出实时写入
76
+ nohup env PATH="$MICROMAMBA_ENV/bin:$PATH" \
77
+ PYTHONUNBUFFERED=1 \
78
+ PYTHONIOENCODING=utf-8 \
79
+ "$PYTHON_EXE" -u "$GRADIO_DIR/main.py" >> "$LOG_FILE" 2>&1 &
80
+ fi
81
+
82
+ # 保存进程ID
83
+ SERVER_PID=$!
84
+ echo $SERVER_PID > "$PID_FILE"
85
+
86
+ # 等待一下,检查进程是否成功启动
87
+ sleep 3
88
+
89
+ if ps -p "$SERVER_PID" > /dev/null 2>&1; then
90
+ echo "✅ 服务器已成功在后台启动!"
91
+ echo " PID: $SERVER_PID"
92
+ echo " Micromamba 环境: $MICROMAMBA_ENV"
93
+ echo " 日志文件: $LOG_FILE"
94
+ echo ""
95
+ echo "📋 常用命令:"
96
+ echo " 查看实时日志: bash $0 logs"
97
+ echo " 查看状态: bash $0 status"
98
+ echo " 停止服务器: bash $0 stop"
99
+ echo ""
100
+ echo "💡 提示:"
101
+ echo " - 所有输出都保存在 $LOG_FILE(包括所有 print、uvicorn 日志等)"
102
+ echo " - 日志实时写入,与前台运行完全一致"
103
+ echo " - 即使关闭SSH连接,服务器也会继续运行"
104
+ echo " - 使用 PYTHONUNBUFFERED=1 和 stdbuf 确保日志实时写入"
105
+ echo ""
106
+ echo "🌐 服务器启动后,请查看日志文件获取访问地址:"
107
+ echo " bash $0 logs"
108
+ return 0
109
+ else
110
+ echo "❌ 服务器启动失败!"
111
+ echo " 请查看完整日志: $LOG_FILE"
112
+ rm -f "$PID_FILE"
113
+ return 1
114
+ fi
115
+ }
116
+
117
+ # 函数:停止服务器
118
+ stop_server() {
119
+ # 检查PID文件是否存在
120
+ if [ ! -f "$PID_FILE" ]; then
121
+ echo "⚠️ 未找到 PID 文件,服务器可能未运行"
122
+ return 1
123
+ fi
124
+
125
+ # 读取PID
126
+ SERVER_PID=$(cat "$PID_FILE")
127
+
128
+ # 检查进程是否存在
129
+ if ! ps -p "$SERVER_PID" > /dev/null 2>&1; then
130
+ echo "⚠️ 进程 $SERVER_PID 不存在,可能已经停止"
131
+ rm -f "$PID_FILE"
132
+ return 1
133
+ fi
134
+
135
+ # 停止进程
136
+ echo "🛑 正在停止服务器 (PID: $SERVER_PID)..."
137
+ kill "$SERVER_PID"
138
+
139
+ # 等待进程结束(最多等待10秒)
140
+ for i in {1..10}; do
141
+ if ! ps -p "$SERVER_PID" > /dev/null 2>&1; then
142
+ echo "✅ 服务器已成功停止"
143
+ rm -f "$PID_FILE"
144
+ return 0
145
+ fi
146
+ sleep 1
147
+ done
148
+
149
+ # 如果还在运行,强制杀死
150
+ if ps -p "$SERVER_PID" > /dev/null 2>&1; then
151
+ echo "⚠️ 进程未响应,强制终止..."
152
+ kill -9 "$SERVER_PID"
153
+ sleep 1
154
+ if ! ps -p "$SERVER_PID" > /dev/null 2>&1; then
155
+ echo "✅ 服务器已强制停止"
156
+ rm -f "$PID_FILE"
157
+ return 0
158
+ else
159
+ echo "❌ 无法停止服务器,请手动检查"
160
+ return 1
161
+ fi
162
+ fi
163
+ }
164
+
165
+ # 函数:查看服务器状态
166
+ status_server() {
167
+ echo "📊 服务器状态信息"
168
+ echo "=========================================="
169
+
170
+ # 检查PID文件
171
+ if [ ! -f "$PID_FILE" ]; then
172
+ echo "❌ 服务器未运行 (未找到 PID 文件)"
173
+ return 1
174
+ fi
175
+
176
+ SERVER_PID=$(cat "$PID_FILE")
177
+
178
+ # 检查进程是否存在
179
+ if ps -p "$SERVER_PID" > /dev/null 2>&1; then
180
+ echo "✅ 服务器正在运行"
181
+ echo " PID: $SERVER_PID"
182
+ echo ""
183
+
184
+ # 显示进程信息
185
+ echo "📋 进程信息:"
186
+ ps -p "$SERVER_PID" -o pid,ppid,user,%cpu,%mem,etime,cmd
187
+ echo ""
188
+
189
+ # 显示日志文件信息
190
+ if [ -f "$LOG_FILE" ]; then
191
+ LOG_SIZE=$(du -h "$LOG_FILE" | cut -f1)
192
+ LOG_LINES=$(wc -l < "$LOG_FILE" 2>/dev/null || echo "0")
193
+ echo "📄 日志文件信息:"
194
+ echo " 文件: $LOG_FILE"
195
+ echo " 大小: $LOG_SIZE"
196
+ echo " 行数: $LOG_LINES"
197
+ echo " 最后修改: $(stat -c %y "$LOG_FILE" 2>/dev/null || stat -f %Sm "$LOG_FILE" 2>/dev/null || echo "未知")"
198
+ fi
199
+
200
+ # 显示最后几行日志
201
+ if [ -f "$LOG_FILE" ]; then
202
+ echo ""
203
+ echo "📝 最近的日志输出 (最后10行):"
204
+ echo "----------------------------------------"
205
+ tail -n 10 "$LOG_FILE"
206
+ fi
207
+ return 0
208
+ else
209
+ echo "❌ 服务器未运行 (进程 $SERVER_PID 不存在)"
210
+ echo " 清理 PID 文件..."
211
+ rm -f "$PID_FILE"
212
+ return 1
213
+ fi
214
+ }
215
+
216
+ # 函数:重启服务器
217
+ restart_server() {
218
+ echo "🔄 正在重启服务器..."
219
+ stop_server
220
+ sleep 2
221
+ start_server
222
+ }
223
+
224
+ # 函数:查看日志
225
+ view_logs() {
226
+ if [ ! -f "$LOG_FILE" ]; then
227
+ echo "⚠️ 日志文件不存在: $LOG_FILE"
228
+ return 1
229
+ fi
230
+
231
+ echo "📝 查看服务器日志 (按 Ctrl+C 退出)"
232
+ echo "=========================================="
233
+ tail -f "$LOG_FILE"
234
+ }
235
+
236
+ # 函数:显示帮助信息
237
+ show_help() {
238
+ echo "HistoryBench 服务器管理脚本"
239
+ echo ""
240
+ echo "使用方法:"
241
+ echo " bash $0 [命令]"
242
+ echo ""
243
+ echo "可用命令:"
244
+ echo " start - 启动服务器(后台运行)"
245
+ echo " stop - 停止服务器"
246
+ echo " status - 查看服务器状态"
247
+ echo " restart - 重启服务器"
248
+ echo " logs - 查看实时日志(按 Ctrl+C 退出)"
249
+ echo " help - 显示此帮助信息"
250
+ echo ""
251
+ echo "示例:"
252
+ echo " bash $0 start # 启动服务器"
253
+ echo " bash $0 status # 查看状态"
254
+ echo " bash $0 logs # 查看日志"
255
+ echo " bash $0 stop # 停止服务器"
256
+ echo ""
257
+ }
258
+
259
+ # 主逻辑:根据命令行参数执行相应操作
260
+ case "${1:-help}" in
261
+ start)
262
+ start_server
263
+ ;;
264
+ stop)
265
+ stop_server
266
+ ;;
267
+ status)
268
+ status_server
269
+ ;;
270
+ restart)
271
+ restart_server
272
+ ;;
273
+ logs)
274
+ view_logs
275
+ ;;
276
+ help|--help|-h)
277
+ show_help
278
+ ;;
279
+ *)
280
+ echo "❌ 未知命令: $1"
281
+ echo ""
282
+ show_help
283
+ exit 1
284
+ ;;
285
+ esac
286
+
287
+ exit $?
gradio-web/scripts/后台运行说明.md ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HistoryBench 后台运行说明
2
+
3
+ 本文档说明如何在后台运行 HistoryBench Gradio 服务器。
4
+
5
+ ## 命令概览
6
+
7
+ ```bash
8
+ # 启动服务器
9
+ bash scripts/run_background.sh start
10
+
11
+ # 查看状态
12
+ bash scripts/run_background.sh status
13
+
14
+ # 查看日志
15
+ bash scripts/run_background.sh logs
16
+
17
+ # 停止服务器
18
+ bash scripts/run_background.sh stop
19
+
20
+ # 重启服务器
21
+ bash scripts/run_background.sh restart
22
+
23
+ # 查看帮助
24
+ bash scripts/run_background.sh help
25
+ ```
26
+
27
+ ## 快速开始
28
+
29
+ ### 启动服务器(后台运行)
30
+
31
+ ```bash
32
+ cd /data/hongzefu/historybench-v5.6.7/gradio
33
+ bash scripts/run_background.sh start
34
+ ```
35
+
36
+ ### 查看服务器状态
37
+
38
+ ```bash
39
+ bash scripts/run_background.sh status
40
+ ```
41
+
42
+ ### 查看实时日志
43
+
44
+ ```bash
45
+ # 方法1: 使用脚本命令(推荐)
46
+ bash scripts/run_background.sh logs
47
+
48
+ # 方法2: 直接使用 tail 命令
49
+ tail -f logs/server.log
50
+
51
+ # 方法3: 查看最后100行日志
52
+ tail -n 100 logs/server.log
53
+ ```
54
+
55
+ ### 停止服务器
56
+
57
+ ```bash
58
+ bash scripts/run_background.sh stop
59
+ ```
60
+
61
+ ### 重启服务器
62
+
63
+ ```bash
64
+ bash scripts/run_background.sh restart
65
+ ```
66
+
67
+ ### 查看帮助信息
68
+
69
+ ```bash
70
+ bash scripts/run_background.sh help
71
+ ```
72
+
73
+ ## 脚本功能说明
74
+
75
+ ### run_background.sh
76
+
77
+ 统一的后台运行管理脚本,支持以下命令:
78
+
79
+ #### 可用命令
80
+
81
+ - **`start`** - 启动服务器(后台运行)
82
+ - 自动检测服务器是否已在运行
83
+ - 使用指定的 Micromamba 环境运行
84
+ - 将所有输出(标准输出和错误输出)保存到日志文件
85
+ - 使用 `nohup` 确保即使关闭 SSH 连接也能继续运行
86
+ - 自动保存进程 ID 到 `server.pid` 文件
87
+ - 启动后自动验证进程是否成功运行
88
+
89
+ - **`stop`** - 停止服务器
90
+ - 优雅地停止服务器进程
91
+ - 如果进程无响应,会自动强制终止
92
+ - 自动清理 PID 文件
93
+
94
+ - **`status`** - 查看服务器状态
95
+ - 显示进程信息(PID、CPU、内存、运行时间等)
96
+ - 显示日志文件信息(大小、行数、最后修改时间)
97
+ - 显示最近的日志输出(最后10行)
98
+
99
+ - **`restart`** - 重启服务器
100
+ - 先停止服务器,然后重新启动
101
+
102
+ - **`logs`** - 查看实时日志
103
+ - 实时显示日志输出(类似 `tail -f`)
104
+ - 按 Ctrl+C 退出
105
+
106
+ - **`help`** - 显示帮助信息
107
+
108
+ ### 配置信息
109
+
110
+ 脚本使用以下默认配置:
111
+
112
+ - **Micromamba 环境**: `/data/hongzefu/maniskillenv1114`
113
+ - **工作目录**: `/data/hongzefu/historybench-v5.6.7/gradio`
114
+ - **日志目录**: `logs/`
115
+ - **日志文件**: `logs/server.log`
116
+ - **PID 文件**: `server.pid`
117
+
118
+ ### 修改配置
119
+
120
+ 如果需要修改配置,请编辑 `run_background.sh` 脚本中的以下变量:
121
+
122
+ ```bash
123
+ # Micromamba 环境路径
124
+ MICROMAMBA_ENV="/data/hongzefu/maniskillenv1114"
125
+
126
+ # 日志目录
127
+ LOG_DIR="$GRADIO_DIR/logs"
128
+
129
+ # PID 文件
130
+ PID_FILE="$GRADIO_DIR/server.pid"
131
+ ```
132
+
133
+ ## 使用场景
134
+
135
+ ### 场景1: 首次启动服务器
136
+
137
+ ```bash
138
+ cd /data/hongzefu/historybench-v5.6.7/gradio
139
+ bash scripts/run_background.sh start
140
+ ```
141
+
142
+ 启动后,查看日志获取服务器访问地址:
143
+
144
+ ```bash
145
+ bash scripts/run_background.sh logs
146
+ ```
147
+
148
+ ### 场景2: 检查服务器是否运行
149
+
150
+ ```bash
151
+ bash scripts/run_background.sh status
152
+ ```
153
+
154
+ 或者手动检查:
155
+
156
+ ```bash
157
+ # 检查 PID 文件
158
+ cat server.pid
159
+
160
+ # 检查进程
161
+ ps aux | grep main.py
162
+ ```
163
+
164
+ ### 场景3: 重启服务器
165
+
166
+ ```bash
167
+ bash scripts/run_background.sh restart
168
+ ```
169
+
170
+ ### 场景4: 查看错误日志
171
+
172
+ ```bash
173
+ # 查看最后50行日志
174
+ tail -n 50 logs/server.log
175
+
176
+ # 搜索错误信息
177
+ grep -i error logs/server.log
178
+
179
+ # 搜索警告信息
180
+ grep -i warning logs/server.log
181
+
182
+ # 或使用脚本查看实时日志
183
+ bash scripts/run_background.sh logs
184
+ ```
185
+
186
+ ## 注意事项
187
+
188
+ 1. **脚本位置**: 脚本位于 `scripts/run_background.sh`,可以从任何位置运行,使用绝对路径或相对路径
189
+
190
+ 2. **日志文件位置**: 日志文件保存在 `gradio/logs/server.log`,包含所有标准输出和错误输出
191
+
192
+ 3. **进程持久化**: 使用 `nohup` 确保即使关闭 SSH 连接,服务器也会继续运行
193
+
194
+ 4. **日志实时写入**: 使用 `PYTHONUNBUFFERED=1` 和 `-u` 参数确保日志实时写入,方便调试
195
+
196
+ 5. **端口冲突**: 如果端口被占用,服务器会自动查找下一个可用端口(从 7860 开始)
197
+
198
+ 6. **环境变量**: 脚本会自动设置 `PATH` 环境变量,确保使用 Micromamba 环境中的 Python 和依赖包
199
+
200
+ ## 故障排查
201
+
202
+ ### 问题1: 服务器启动失败
203
+
204
+ **检查步骤**:
205
+
206
+ 1. 查看日志文件:
207
+ ```bash
208
+ tail -n 100 logs/server.log
209
+ ```
210
+
211
+ 2. 检查 Micromamba 环境是否存在:
212
+ ```bash
213
+ ls -la /data/hongzefu/maniskillenv1114
214
+ ```
215
+
216
+ 3. 检查 Python 可执行文件:
217
+ ```bash
218
+ /data/hongzefu/maniskillenv1114/bin/python --version
219
+ ```
220
+
221
+ ### 问题2: 端口已被占用
222
+
223
+ 服务器会自动查找可用端口,但如果你想手动指定端口,可以修改 `main.py` 中的 `find_free_port()` 函数。
224
+
225
+ ### 问题3: 无法访问服务器
226
+
227
+ 1. 检查服务器是否正在运行:
228
+ ```bash
229
+ bash scripts/run_background.sh status
230
+ ```
231
+
232
+ 2. 查看日志获取正确的访问地址:
233
+ ```bash
234
+ tail -n 50 logs/server.log | grep -E "(http://|SERVER STARTING)"
235
+ ```
236
+
237
+ 3. 检查防火墙设置(如果需要从外部访问)
238
+
239
+ ### 问题4: 进程意外退出
240
+
241
+ 1. 查看日志文件中的错误信息:
242
+ ```bash
243
+ grep -i error logs/server.log | tail -n 20
244
+ ```
245
+
246
+ 2. 检查系统资源(内存、磁盘空间等)
247
+
248
+ 3. 检查依赖包是否完整安装
249
+
250
+ ## 相关脚本
251
+
252
+ - `run_background.sh` - 统一管理脚本
253
+ - 支持 `start`、`stop`、`status`、`restart`、`logs`、`help` 命令
254
+ - 功能完整,使用方便
255
+
256
+ ## 技术细节
257
+
258
+ ### 后台运行机制
259
+
260
+ 脚本使用以下技术实现后台运行:
261
+
262
+ 1. **nohup**: 防止进程在终端关闭时被终止
263
+ 2. **重定向输出**: `>> "$LOG_FILE" 2>&1` 将所有输出保存到日志文件
264
+ 3. **PID 文件**: 保存进程 ID,方便后续管理
265
+ 4. **环境变量**: 设置 `PATH` 和 `PYTHONUNBUFFERED` 确保正确运行
266
+
267
+ ### 日志管理
268
+
269
+ - **所有输出都会被捕获**:包括所有 `print()` 语句、uvicorn 的访问日志、错误日志等
270
+ - 日志文件位置:`logs/server.log`
271
+ - 实时写入机制:
272
+ - 使用 `PYTHONUNBUFFERED=1` 和 `-u` 参数确保 Python 输出不缓冲
273
+ - 使用 `stdbuf -oL -eL` 确保行缓冲输出(如果系统支持)
274
+ - 所有标准输出和错误输出都重定向到日志文件(`2>&1`)
275
+ - 日志内容:
276
+ - 所有 `print()` 输出
277
+ - uvicorn 服务器日志(启动信息、请求日志等)
278
+ - FastAPI 应用日志
279
+ - Gradio 界面日志
280
+ - 错误和异常信息
281
+ - 日志文件会持续增长,建议定期清理或使用日志轮转工具
282
+
283
+ ## 联系与支持
284
+
285
+ 如有问题,请查看:
286
+ - 日志文件: `logs/server.log`
287
+ - 项目文档: 项目根目录的 README 文件
288
+
gradio-web/state_manager.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 状态管理模块
3
+ 管理所有全局状态和Session生命周期
4
+
5
+ 本模块负责:
6
+ 1. 创建和管理 ProcessSessionProxy 实例(每个用户一个)
7
+ 2. 存储任务索引等UI状态
8
+ 3. 提供线程安全的访问接口
9
+ 4. 清理会话资源(当用户重复登录时,自动清理旧会话的进程和状态)
10
+
11
+ 注意:GLOBAL_SESSIONS 中存储的是 ProcessSessionProxy 对象,而不是 OracleSession。
12
+ 实际的 OracleSession 运行在独立的工作进程中,通过代理对象进行通信。
13
+ 当同一用户第二次登录时,系统会自动清理旧会话的所有资源(进程、RAM、VRAM、状态数据等)。
14
+ """
15
+ import uuid
16
+ import threading
17
+ import traceback
18
+ import time
19
+ from process_session import ProcessSessionProxy
20
+
21
+ # --- 全局会话存储 ---
22
+ # 存储所有用户的 ProcessSessionProxy 实例
23
+ # 每个用户登录时会创建一个代理,代理会启动一个独立的工作进程运行 OracleSession
24
+ GLOBAL_SESSIONS = {}
25
+
26
+ # --- 任务索引存储(用于进度显示) ---
27
+ # 存储每个session的任务索引和总任务数,用于直接读取Progress
28
+ TASK_INDEX_MAP = {} # {uid: {"task_index": int, "total_tasks": int}}
29
+
30
+ # --- UI阶段存储 ---
31
+ # 存储每个session的UI阶段:"watching_demo" 或 "executing_task"
32
+ UI_PHASE_MAP = {} # {uid: "watching_demo" | "executing_task"}
33
+
34
+ # --- Execute 次数跟踪 ---
35
+ # 跟踪每个会话每个任务的 execute 次数
36
+ # 键格式: "{uid}:{env_id}:{episode_idx}"
37
+ EXECUTE_COUNTS = {} # {task_key: count}
38
+
39
+ # --- 任务开始时间跟踪 ---
40
+ # 跟踪每个任务的开始时间
41
+ # 键格式: "{uid}:{env_id}:{episode_idx}"
42
+ # 值: ISO 格式的时间字符串
43
+ TASK_START_TIMES = {} # {task_key: "2025-12-28T14:01:25.372278"}
44
+
45
+ # --- Session活动时间跟踪 ---
46
+ # 跟踪每个session的最后活动时间(用于超时检测)
47
+ SESSION_LAST_ACTIVITY = {} # {uid: timestamp} - timestamp是time.time()返回的浮点数
48
+ SESSION_TIMEOUT_WARNED = {} # {uid: bool} - 跟踪已警告的session,避免重复警告
49
+
50
+ # --- 播放按钮状态跟踪 ---
51
+ # 跟踪每个session的播放按钮是否已被点击(用于execute按钮条件控制)
52
+ PLAY_BUTTON_CLICKED = {} # {uid: bool} - 跟踪播放按钮是否已被点击
53
+
54
+ # 线程锁,用于保护全局状态的访问
55
+ _state_lock = threading.Lock()
56
+
57
+
58
def get_session(uid):
    """Look up the ProcessSessionProxy registered for a session id.

    Args:
        uid: Session identifier.

    Returns:
        The ProcessSessionProxy for *uid* (same interface as OracleSession),
        or None if no such session exists.
    """
    with _state_lock:
        proxy = GLOBAL_SESSIONS.get(uid)
    return proxy
70
+
71
+
72
def create_session():
    """Create a new session and return its uid.

    Steps performed:
      1. Generate a unique session id (UUID4).
      2. Build a ProcessSessionProxy, which spawns a dedicated worker
         process running the actual OracleSession.
      3. Register the proxy in GLOBAL_SESSIONS.
      4. Seed SESSION_LAST_ACTIVITY with the current time.

    Returns:
        str: The newly created session id.
    """
    new_uid = str(uuid.uuid4())
    proxy = ProcessSessionProxy()
    with _state_lock:
        GLOBAL_SESSIONS[new_uid] = proxy
        SESSION_LAST_ACTIVITY[new_uid] = time.time()
    return new_uid
92
+
93
+
94
def get_task_index(uid):
    """Return the stored task-index record for *uid*, or None if absent."""
    with _state_lock:
        record = TASK_INDEX_MAP.get(uid)
    return record
98
+
99
+
100
def set_task_index(uid, task_index, total_tasks):
    """Record the current task index and total task count for a session."""
    record = {"task_index": task_index, "total_tasks": total_tasks}
    with _state_lock:
        TASK_INDEX_MAP[uid] = record
107
+
108
+
109
def get_ui_phase(uid):
    """Return the UI phase for *uid*; defaults to "watching_demo"."""
    with _state_lock:
        phase = UI_PHASE_MAP.get(uid, "watching_demo")
    return phase
113
+
114
+
115
def set_ui_phase(uid, phase):
    """Store the UI phase for a session.

    Args:
        uid: Session id.
        phase: Either "watching_demo" or "executing_task".
    """
    with _state_lock:
        UI_PHASE_MAP[uid] = phase
124
+
125
+
126
def reset_ui_phase(uid):
    """Reset the session's UI phase back to the initial "watching_demo" stage."""
    set_ui_phase(uid, "watching_demo")
130
+
131
+
132
def set_play_button_clicked(uid, clicked=True):
    """Record whether the play button has been pressed for this session.

    Args:
        uid: Session id.
        clicked: Flag value to store (defaults to True).
    """
    with _state_lock:
        PLAY_BUTTON_CLICKED[uid] = clicked
142
+
143
+
144
def get_play_button_clicked(uid):
    """Tell whether the play button was pressed for this session.

    Args:
        uid: Session id.

    Returns:
        bool: True once recorded as clicked, False otherwise.
    """
    with _state_lock:
        clicked = PLAY_BUTTON_CLICKED.get(uid, False)
    return clicked
156
+
157
+
158
def reset_play_button_clicked(uid):
    """Forget the play-button state for a session (no-op if not recorded)."""
    with _state_lock:
        PLAY_BUTTON_CLICKED.pop(uid, None)
168
+
169
+
170
+ def _get_task_key(uid, env_id, episode_idx):
171
+ """��成任务键(用于跟踪 execute 次数)"""
172
+ return f"{uid}:{env_id}:{episode_idx}"
173
+
174
+
175
def get_execute_count(uid, env_id, episode_idx):
    """Return how many times the given task has been executed.

    Args:
        uid: Session id.
        env_id: Environment id.
        episode_idx: Episode index.

    Returns:
        int: Recorded execute count (0 when the task has no record).
    """
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        count = EXECUTE_COUNTS.get(key, 0)
    return count
190
+
191
+
192
def increment_execute_count(uid, env_id, episode_idx):
    """Increment and return the execute count for the given task.

    Args:
        uid: Session id.
        env_id: Environment id.
        episode_idx: Episode index.

    Returns:
        int: The count after incrementing.
    """
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        new_count = EXECUTE_COUNTS.get(key, 0) + 1
        EXECUTE_COUNTS[key] = new_count
    return new_count
209
+
210
+
211
def reset_execute_count(uid, env_id, episode_idx):
    """Reset the execute count for the given task back to zero."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        EXECUTE_COUNTS[key] = 0
223
+
224
+
225
def get_task_start_time(uid, env_id, episode_idx):
    """Return the recorded start time for a task.

    Returns:
        str | None: ISO-format timestamp, or None when no start time is stored.
    """
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        started = TASK_START_TIMES.get(key)
    return started
240
+
241
+
242
def set_task_start_time(uid, env_id, episode_idx, start_time):
    """Store the start time for a task.

    Args:
        uid: Session id.
        env_id: Environment id.
        episode_idx: Episode index.
        start_time: ISO-format timestamp string.
    """
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES[key] = start_time
255
+
256
+
257
def clear_task_start_time(uid, env_id, episode_idx):
    """Drop the stored start time for a task (no-op when absent)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES.pop(key, None)
270
+
271
+
272
def cleanup_session(uid):
    """Release every resource associated with a session.

    Cleans up, in order:
      1. The ProcessSessionProxy (terminates the worker process, freeing
         RAM/VRAM) and its entry in GLOBAL_SESSIONS.
      2. Per-session UI state: task index, UI phase, play-button flag.
      3. Activity tracking: last-activity timestamp and timeout-warning flag.

    EXECUTE_COUNTS is deliberately left untouched — it is keyed per task,
    not per session (reset via reset_execute_count on task switch).

    Args:
        uid: Session id to clean up. Falsy values are ignored.
    """
    if not uid:
        return

    with _state_lock:
        # 1. Close the ProcessSessionProxy (terminates the worker process).
        session = GLOBAL_SESSIONS.get(uid)
        if session:
            try:
                print(f"Cleaning up session {uid}: closing ProcessSessionProxy...")
                session.close()
                print(f"Session {uid}: ProcessSessionProxy closed successfully")
            except Exception as e:
                # Best-effort: log and continue so remaining state still gets cleared.
                print(f"Error closing ProcessSessionProxy for {uid}: {e}")
                traceback.print_exc()

        # 2. Remove the proxy from the global registry.
        if uid in GLOBAL_SESSIONS:
            del GLOBAL_SESSIONS[uid]
            print(f"Session {uid}: removed from GLOBAL_SESSIONS")

        # 3. Clear the task-index record.
        if uid in TASK_INDEX_MAP:
            del TASK_INDEX_MAP[uid]
            print(f"Session {uid}: task index cleaned up")

        # 4. Clear the UI phase.
        if uid in UI_PHASE_MAP:
            del UI_PHASE_MAP[uid]

        # Clear the play-button flag.
        # NOTE(review): the log line below says "UI phase" although it fires
        # when the play-button flag is removed — confirm intended wording.
        if uid in PLAY_BUTTON_CLICKED:
            del PLAY_BUTTON_CLICKED[uid]
            print(f"Session {uid}: UI phase cleaned up")

        # 5. Clear the last-activity timestamp.
        if uid in SESSION_LAST_ACTIVITY:
            del SESSION_LAST_ACTIVITY[uid]
            print(f"Session {uid}: last activity time cleaned up")

        # 6. Clear the timeout-warning flag.
        if uid in SESSION_TIMEOUT_WARNED:
            del SESSION_TIMEOUT_WARNED[uid]
            print(f"Session {uid}: timeout warning flag cleaned up")

        # EXECUTE_COUNTS is intentionally not cleared here (per-task keying);
        # use reset_execute_count when switching tasks.

        print(f"Session {uid}: all resources cleaned up successfully")
332
+
333
+
334
def update_session_activity(uid):
    """Mark a session as active right now.

    Refreshes the last-activity timestamp and clears any pending
    timeout warning for the session.

    Args:
        uid: Session id.
    """
    now = time.time()
    with _state_lock:
        SESSION_LAST_ACTIVITY[uid] = now
        SESSION_TIMEOUT_WARNED.pop(uid, None)
346
+
347
+
348
def get_session_activity(uid):
    """Return the session's last-activity timestamp.

    Args:
        uid: Session id.

    Returns:
        float | None: time.time()-style timestamp, or None if unknown.
    """
    with _state_lock:
        last_seen = SESSION_LAST_ACTIVITY.get(uid)
    return last_seen
360
+
361
+
362
def check_and_cleanup_timeout_sessions():
    """Scan all sessions and clean up the ones that have timed out.

    Two-stage policy:
      1. A session idle for more than SESSION_TIMEOUT seconds is flagged
         (warned) once.
      2. A flagged session that stays idle for GRACE more seconds is
         cleaned up via cleanup_session().

    Fix vs. original: removed the dead ``timeout_sessions`` accumulator
    (appended to but never read) and named the duplicated magic ``5``
    grace constant. Behavior is otherwise unchanged.
    """
    # Local import to avoid a module-level import cycle with config.
    from config import SESSION_TIMEOUT

    # Extra seconds of grace after the warning before actual cleanup.
    grace_period = 5

    current_time = time.time()
    warned_sessions_to_cleanup = []

    with _state_lock:
        # Snapshot the active uids so we never hold the lock over the scan.
        active_uids = list(GLOBAL_SESSIONS.keys())

    for uid in active_uids:
        # Short-lived lock per session keeps other callers responsive.
        with _state_lock:
            last_activity = SESSION_LAST_ACTIVITY.get(uid)
            is_warned = SESSION_TIMEOUT_WARNED.get(uid, False)

        if last_activity is None:
            # No activity record yet (session may have just been created).
            continue

        elapsed = current_time - last_activity

        if elapsed > SESSION_TIMEOUT:
            if not is_warned:
                # First time over the limit: flag it and log a warning.
                with _state_lock:
                    SESSION_TIMEOUT_WARNED[uid] = True
                print(f"Session {uid}: 超时警告 - 已超过 {SESSION_TIMEOUT} 秒未活动")
            elif elapsed > SESSION_TIMEOUT + grace_period:
                # Already warned and still idle past the grace period.
                warned_sessions_to_cleanup.append(uid)

    for uid in warned_sessions_to_cleanup:
        print(f"Session {uid}: 超时清理 - 已超过 {SESSION_TIMEOUT + grace_period} 秒未活动,开始清理资源")
        cleanup_session(uid)
        # cleanup_session also removes the SESSION_LAST_ACTIVITY and
        # SESSION_TIMEOUT_WARNED entries for this uid.
409
+
410
+
411
+ # 后台监控线程相关变量
412
+ _timeout_monitor_thread = None
413
+ _timeout_monitor_running = False
414
+ _timeout_monitor_lock = threading.Lock()
415
+
416
+
417
def _timeout_monitor_loop():
    """Main loop of the background monitor thread.

    Runs check_and_cleanup_timeout_sessions() roughly every 5 seconds
    until _timeout_monitor_running is cleared. Exceptions are logged and
    swallowed so one bad scan does not kill the monitor.
    """
    global _timeout_monitor_running
    while _timeout_monitor_running:
        try:
            check_and_cleanup_timeout_sessions()
        except Exception as e:
            print(f"Error in timeout monitor loop: {e}")
            traceback.print_exc()

        # Sleep ~5 s total in 0.1 s slices (50 x 0.1) so a stop request
        # is noticed promptly instead of after a full 5 s sleep.
        for _ in range(50):
            if not _timeout_monitor_running:
                break
            time.sleep(0.1)
435
+
436
+
437
def start_timeout_monitor():
    """Start the background session-timeout monitor thread.

    Call once at application startup. Idempotent: if the monitor is
    already running this is a no-op.
    """
    global _timeout_monitor_thread, _timeout_monitor_running

    with _timeout_monitor_lock:
        if _timeout_monitor_running:
            print("Timeout monitor is already running")
            return

        _timeout_monitor_running = True
        # Daemon thread: never blocks interpreter shutdown.
        _timeout_monitor_thread = threading.Thread(
            target=_timeout_monitor_loop,
            daemon=True,
            name="SessionTimeoutMonitor"
        )
        _timeout_monitor_thread.start()
        print("Session timeout monitor started")
458
+
459
def stop_timeout_monitor():
    """Stop the background session-timeout monitor thread.

    Call at application shutdown. Waits up to 2 seconds for the loop to
    exit (the loop polls its flag every 0.1 s, so 2 s is ample).
    """
    global _timeout_monitor_thread, _timeout_monitor_running

    with _timeout_monitor_lock:
        if not _timeout_monitor_running:
            return

        _timeout_monitor_running = False
        if _timeout_monitor_thread:
            _timeout_monitor_thread.join(timeout=2.0)
        print("Session timeout monitor stopped")
gradio-web/test/conftest.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+
10
def _find_repo_root(start_file: str | Path) -> Path:
    """Walk upward from *start_file* to the directory holding pyproject.toml.

    Raises:
        FileNotFoundError: If no ancestor contains a pyproject.toml.
    """
    path = Path(start_file).resolve()
    base = path if path.is_dir() else path.parent
    for directory in (base, *base.parents):
        if (directory / "pyproject.toml").exists():
            return directory
    raise FileNotFoundError(f"Could not find repo root from {path}")
17
+
18
+
19
REPO_ROOT = _find_repo_root(__file__)
SRC_ROOT = REPO_ROOT / "src"
GRADIO_ROOT = REPO_ROOT / "gradio"

# Prepend repo root, src/ and gradio/ to sys.path so tests can import
# the application modules without installation.
# NOTE(review): this commit places the app under "gradio-web" — confirm
# whether GRADIO_ROOT should point there instead of "gradio".
for p in (str(REPO_ROOT), str(SRC_ROOT), str(GRADIO_ROOT)):
    if p not in sys.path:
        sys.path.insert(0, p)
26
+
27
+
28
@pytest.fixture(scope="session")
def repo_root() -> Path:
    """Session-scoped fixture exposing the detected repository root."""
    return REPO_ROOT
31
+
32
+
33
@pytest.fixture
def reload_module():
    """Fixture returning a helper that imports *name* and reloads it.

    Reloading ensures module-level state mutated by earlier tests does
    not leak into the current test.
    """
    def _reload(name: str):
        module = importlib.import_module(name)
        return importlib.reload(module)

    return _reload
gradio-web/test/test_episode98_removed_behavior.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import time
4
+
5
+
6
def test_load_next_task_wrapper_treats_episode98_as_normal(monkeypatch, reload_module):
    """Episode 98 gets no special handling: the wrapper delegates to _load_status_task."""
    callbacks = reload_module("gradio_callbacks")

    # Sentinel return value proves the delegation path was taken unchanged.
    expected = ("SENTINEL",)

    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(
        callbacks.user_manager,
        "next_episode_same_env",
        lambda uid: {"is_done_all": False, "current_task": {"env_id": "BinFill", "episode_idx": 98}},
    )
    monkeypatch.setattr(callbacks, "_load_status_task", lambda uid, status: expected)

    result = callbacks.load_next_task_wrapper("uid1")

    assert result == expected
22
+
23
+
24
def test_restart_episode_wrapper_reloads_same_episode(monkeypatch, reload_module):
    """Restarting reloads the exact same episode via a single _load_status_task call."""
    callbacks = reload_module("gradio_callbacks")

    load_calls = []
    expected = ("RESTARTED",)

    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(
        callbacks.user_manager,
        "get_session_status",
        lambda uid: {"is_done_all": False, "current_task": {"env_id": "BinFill", "episode_idx": 98}},
    )

    def _fake_load_status_task(uid, status):
        # Capture arguments so we can assert on the reloaded task below.
        load_calls.append((uid, status))
        return expected

    monkeypatch.setattr(callbacks, "_load_status_task", _fake_load_status_task)

    result = callbacks.restart_episode_wrapper("uid1")

    assert len(load_calls) == 1
    assert load_calls[0][1]["current_task"] == {"env_id": "BinFill", "episode_idx": 98}
    assert result == expected
48
+
49
+
50
def test_restart_episode_wrapper_missing_status_returns_login_failed(monkeypatch, reload_module):
    """When the session status is missing, restart reports a failure message."""
    callbacks = reload_module("gradio_callbacks")

    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks.user_manager, "get_session_status", lambda uid: None)

    result = callbacks.restart_episode_wrapper("uid1")

    # Fourth element of the UI tuple carries the status message.
    assert "Failed to restart episode" in result[3]
59
+
60
+
61
def test_execute_step_failed_episode98_still_advances(monkeypatch, reload_module):
    """A FAILED execution on episode 98 still records 'failed' and advances tasks."""
    callbacks = reload_module("gradio_callbacks")

    class _FakeSession:
        # Minimal session double: one option, always-FAILED execution.
        def __init__(self):
            self.env_id = "BinFill"
            self.episode_idx = 98
            self.base_frames = []
            self.raw_solve_options = [{"available": False}]
            self.available_options = [("run", 0)]
            self.difficulty = "hard"
            self.language_goal = "goal"
            self.seed = 123
            self.non_demonstration_task_length = None

        def update_observation(self, use_segmentation=False):
            return None

        def get_pil_image(self, use_segmented=False):
            return "IMG"

        def execute_action(self, option_idx, click_coords):
            # Always report a failed-and-done step.
            return "IMG", "FAILED", True

    fake_session = _FakeSession()
    complete_calls = []

    monkeypatch.setattr(callbacks, "get_session_activity", lambda uid: time.time())
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks, "get_session", lambda uid: fake_session)
    monkeypatch.setattr(callbacks, "increment_execute_count", lambda uid, env_id, episode_idx: 1)

    def _fake_complete_current_task(*args, **kwargs):
        # Record the completion payload and advance to a different task.
        payload = dict(kwargs)
        if args:
            payload["uid"] = args[0]
        complete_calls.append(payload)
        return {"is_done_all": False, "current_task": {"env_id": "MoveCube", "episode_idx": 7}}

    monkeypatch.setattr(callbacks.user_manager, "complete_current_task", _fake_complete_current_task)

    result = callbacks.execute_step("uid1", 0, "No need for coordinates")

    assert len(complete_calls) == 1
    assert complete_calls[0]["episode_idx"] == 98
    assert complete_calls[0]["status"] == "failed"
    assert result[2] == "BinFill (Episode 98)"
gradio-web/test/test_execute_stream_frames.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
+ class _FakeUnwrapped:
7
+ def __init__(self):
8
+ self.segmentation_id_map = {}
9
+ self.elapsed_steps = 0
10
+
11
+ def evaluate(self, solve_complete_eval=False):
12
+ return {"success": False, "fail": False}
13
+
14
+
15
class _FakeEnv:
    """Env double whose step() emits frames with a per-step constant pixel value.

    Each step increments elapsed_steps and returns an observation whose
    front frame is filled with the step index, letting the test trace
    exactly which frames were captured.
    """

    def __init__(self):
        self.unwrapped = _FakeUnwrapped()
        self._step_idx = 0
        self._last_obs = None

    def step(self, action):
        self._step_idx += 1
        self.unwrapped.elapsed_steps = self._step_idx
        # Frame value == step index, so pixel values encode step order.
        frame = np.full((8, 8, 3), self._step_idx, dtype=np.uint8)
        obs = {"front_rgb_list": frame}
        self._last_obs = obs
        return obs, 0.0, False, False, {}
28
+
29
+
30
def test_execute_action_captures_intermediate_front_frames(monkeypatch, reload_module):
    """execute_action must record every intermediate front frame produced by solve()."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(
        oracle_logic,
        "_fetch_segmentation",
        lambda env: np.zeros((1, 8, 8), dtype=np.int64),
    )
    # One option whose solve() drives the fake env through 3 steps.
    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [
            {"label": "a", "action": "run", "solve": lambda: [env.step(None) for _ in range(3)]}
        ],
    )

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = _FakeEnv()
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, status, done = session.execute_action(0, None)

    # Captured during solve(): 1,2,3. update_observation may append the last frame again.
    pixel_trace = [int(frame[0, 0, 0]) for frame in session.base_frames]
    assert pixel_trace[:3] == [1, 2, 3]
    assert len(pixel_trace) >= 3
    assert status.startswith("Executing: a")
    assert done is False
gradio-web/test/test_live_obs_refresh.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+
7
+ class _FakeSession:
8
+ def __init__(self, frames, env_id="BinFill"):
9
+ self.base_frames = frames
10
+ self.env_id = env_id
11
+
12
+
13
def test_refresh_live_obs_skips_when_not_execution_phase(monkeypatch, reload_module):
    """Outside execution playback, refresh_live_obs emits an update without a value."""
    callbacks = reload_module("gradio_callbacks")
    monkeypatch.setattr(callbacks, "get_session", lambda uid: _FakeSession([]))

    update = callbacks.refresh_live_obs("uid-1", "action_keypoint")

    assert update.get("__type__") == "update"
    # No "value" key => the displayed image is left untouched.
    assert "value" not in update
21
+
22
+
23
def test_refresh_live_obs_updates_image_from_latest_frame(monkeypatch, reload_module):
    """During execution playback, frames drain FIFO with 2x downsampling."""
    callbacks = reload_module("gradio_callbacks")
    # Distinct constant-valued frames let pixel checks identify each frame.
    frame0 = np.zeros((8, 8, 3), dtype=np.uint8)
    frame1 = np.full((8, 8, 3), 11, dtype=np.uint8)
    frame2 = np.full((8, 8, 3), 22, dtype=np.uint8)
    frame3 = np.full((8, 8, 3), 33, dtype=np.uint8)
    frame4 = np.full((8, 8, 3), 44, dtype=np.uint8)
    session = _FakeSession([frame0])
    monkeypatch.setattr(callbacks, "get_session", lambda uid: session)

    # Reset queue state at execute start (cursor anchored at current base_frames length).
    callbacks.switch_to_execute_phase("uid-2")
    session.base_frames.extend([frame1, frame2, frame3, frame4])

    # Downsample x2 + FIFO => first frame1, then frame3.
    update1 = callbacks.refresh_live_obs("uid-2", "execution_playback")
    update2 = callbacks.refresh_live_obs("uid-2", "execution_playback")
    update3 = callbacks.refresh_live_obs("uid-2", "execution_playback")

    assert update1.get("__type__") == "update"
    assert update1.get("interactive") is False
    assert isinstance(update1.get("value"), Image.Image)
    assert update1["value"].getpixel((0, 0)) == (11, 11, 11)

    assert update2.get("__type__") == "update"
    assert update2.get("interactive") is False
    assert isinstance(update2.get("value"), Image.Image)
    assert update2["value"].getpixel((0, 0)) == (33, 33, 33)

    # Queue drained, so no further value update.
    assert update3.get("__type__") == "update"
    assert "value" not in update3
55
+
56
+
57
def test_switch_phase_keeps_live_obs_visible_and_toggles_interactive(reload_module):
    """Phase switches return 6 component updates toggling interactivity only."""
    callbacks = reload_module("gradio_callbacks")

    # Entering execute phase disables interaction on image and buttons.
    to_exec = callbacks.switch_to_execute_phase("uid-3")
    assert len(to_exec) == 6
    assert to_exec[0].get("interactive") is False
    assert to_exec[4].get("interactive") is False
    assert to_exec[5].get("interactive") is False

    # Returning to action phase re-enables them.
    to_action = callbacks.switch_to_action_phase()
    assert len(to_action) == 6
    assert to_action[0].get("interactive") is True
    assert to_action[4].get("interactive") is True
    assert to_action[5].get("interactive") is True
gradio-web/test/test_option_label_format.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
+ class _FakeUnwrapped:
7
+ def __init__(self):
8
+ self.segmentation_id_map = {}
9
+
10
+
11
class _FakeEnv:
    """Env double exposing one black 8x8 front frame and no wrist frames."""

    def __init__(self):
        self.unwrapped = _FakeUnwrapped()
        self.frames = [np.zeros((8, 8, 3), dtype=np.uint8)]
        self.wrist_frames = []
16
+
17
+
18
class _FakeObsWrapperEnv:
    """Env double mimicking an obs-wrapper: frames come from _last_obs lists."""

    def __init__(self, front_rgb_list, wrist_rgb_list):
        self.unwrapped = _FakeUnwrapped()
        # Shape matches the obs dict that the wrapper under test reads.
        self._last_obs = {
            "front_rgb_list": front_rgb_list,
            "wrist_rgb_list": wrist_rgb_list,
        }
25
+
26
+
27
+
28
def test_available_options_use_label_plus_action(monkeypatch, reload_module):
    """Option labels are rendered as '<label>. <action>' regardless of availability."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(
        oracle_logic,
        "_fetch_segmentation",
        lambda env: np.zeros((1, 8, 8), dtype=np.int64),
    )
    # Two options: one with an available target, one without.
    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [
            {"label": "a", "action": "pick up the cube", "available": [1]},
            {"label": "b", "action": "put it down", "available": []},
        ],
    )

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    session.env = _FakeEnv()
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, msg = session.update_observation()

    assert msg == "Ready"
    assert session.available_options == [
        ("a. pick up the cube", 0),
        ("b. put it down", 1),
    ]
    assert session.raw_solve_options[0]["label"] == "a"
59
+
60
+
61
def test_update_observation_no_seg_vis_base_fallback(monkeypatch, reload_module):
    """With no frames and segmentation off, get_pil_image falls back to a 255x255 image."""
    oracle_logic = reload_module("oracle_logic")

    # BGR-ordered segmentation visual (B=10, G=20, R=30) — unused in this path.
    seg_vis = np.zeros((6, 6, 3), dtype=np.uint8)
    seg_vis[:, :, 0] = 10  # B
    seg_vis[:, :, 1] = 20  # G
    seg_vis[:, :, 2] = 30  # R

    monkeypatch.setattr(
        oracle_logic,
        "_fetch_segmentation",
        lambda env: np.zeros((1, 6, 6), dtype=np.int64),
    )
    monkeypatch.setattr(
        oracle_logic,
        "_prepare_segmentation_visual",
        lambda seg, color_map, hw: (seg_vis, np.zeros((6, 6), dtype=np.int64)),
    )
    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [],
    )

    session = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    # Env with empty frame lists, built on the fly via type().
    session.env = type(
        "_NoFrameEnv",
        (),
        {"unwrapped": _FakeUnwrapped(), "frames": [], "wrist_frames": []},
    )()
    session.planner = object()
    session.env_id = "BinFill"
    session.color_map = {}

    _img, msg = session.update_observation(use_segmentation=False)

    assert msg == "Ready"
    assert len(session.base_frames) == 0

    pil_img = session.get_pil_image(use_segmented=False)
    assert pil_img.size == (255, 255)
102
+
103
+
104
def test_update_observation_uses_only_front_rgb_list(monkeypatch, reload_module):
    """Frames must come from the wrapper's ``front_rgb_list``, not ``env.frames``."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(
        oracle_logic, "_fetch_segmentation", lambda env: np.zeros((1, 8, 8), dtype=np.int64)
    )
    monkeypatch.setattr(
        oracle_logic, "_build_solve_options", lambda env, planner, selected_target, env_id: []
    )

    first = np.full((8, 8, 3), 11, dtype=np.uint8)
    second = np.full((8, 8, 3), 22, dtype=np.uint8)

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = _FakeObsWrapperEnv(front_rgb_list=[first, second], wrist_rgb_list=[])
    sess.planner = object()
    sess.env_id = "BinFill"
    sess.color_map = {}

    _img, msg = sess.update_observation(use_segmentation=False)

    assert msg == "Ready"
    assert len(sess.base_frames) == 2
    assert len(sess.wrist_frames) == 0
    # The most recent frame is the last entry of front_rgb_list.
    assert sess.base_frames[-1][0, 0, 0] == 22
133
+
134
+
135
def test_update_observation_does_not_duplicate_same_last_obs(monkeypatch, reload_module):
    """Re-reading an unchanged ``_last_obs`` must not append duplicate frames."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(
        oracle_logic, "_fetch_segmentation", lambda env: np.zeros((1, 8, 8), dtype=np.int64)
    )
    monkeypatch.setattr(
        oracle_logic, "_build_solve_options", lambda env, planner, selected_target, env_id: []
    )

    env = _FakeObsWrapperEnv(
        front_rgb_list=[
            np.full((8, 8, 3), 10, dtype=np.uint8),
            np.full((8, 8, 3), 20, dtype=np.uint8),
        ],
        wrist_rgb_list=[],
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = env
    sess.planner = object()
    sess.env_id = "BinFill"
    sess.color_map = {}

    # Two refreshes over the same observation payload -> still only two frames.
    sess.update_observation(use_segmentation=False)
    sess.update_observation(use_segmentation=False)
    assert len(sess.base_frames) == 2

    # A genuinely new observation appends exactly one more frame.
    env._last_obs = {
        "front_rgb_list": [np.full((8, 8, 3), 30, dtype=np.uint8)],
        "wrist_rgb_list": [],
    }
    sess.update_observation(use_segmentation=False)
    assert len(sess.base_frames) == 3
    assert sess.base_frames[-1][0, 0, 0] == 30
168
+
169
+
170
def test_update_observation_does_not_fallback_to_env_frames(monkeypatch, reload_module):
    """``env.frames`` is never used as a frame source when no wrapper obs exists."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(
        oracle_logic, "_fetch_segmentation", lambda env: np.zeros((1, 8, 8), dtype=np.int64)
    )
    monkeypatch.setattr(
        oracle_logic, "_build_solve_options", lambda env, planner, selected_target, env_id: []
    )

    env = _FakeEnv()
    env.frames = [np.full((8, 8, 3), 99, dtype=np.uint8)]

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = env
    sess.planner = object()
    sess.env_id = "BinFill"
    sess.color_map = {}

    _img, msg = sess.update_observation(use_segmentation=False)

    assert msg == "Ready"
    # env.frames content (value 99) must not leak into the session frame buffer.
    assert sess.base_frames == []
gradio-web/test/test_oracle_builder_integration.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+
6
+ class _DummyPlanner:
7
+ def __init__(self, *args, **kwargs):
8
+ self.args = args
9
+ self.kwargs = kwargs
10
+
11
+
12
+ class _FakeRobot:
13
+ def __init__(self):
14
+ self.pose = object()
15
+
16
+
17
+ class _FakeAgent:
18
+ def __init__(self):
19
+ self.robot = _FakeRobot()
20
+
21
+
22
+ class _FakeUnwrapped:
23
+ def __init__(self):
24
+ self.agent = _FakeAgent()
25
+ self.segmentation_id_map = {}
26
+
27
+ def evaluate(self, solve_complete_eval=False):
28
+ return {"success": False, "fail": False}
29
+
30
+
31
+ class _FakeEnv:
32
+ def __init__(self):
33
+ self.unwrapped = _FakeUnwrapped()
34
+ self.demonstration_data = {"language goal": "test goal", "frames": ["f1", "f2"]}
35
+ self.non_demonstration_task_length = 7
36
+ self.frames = []
37
+ self.wrist_frames = []
38
+ self.closed = False
39
+
40
+ def reset(self):
41
+ return None
42
+
43
+ def close(self):
44
+ self.closed = True
45
+
46
+
47
+ class _FakeEnvTupleDemo(_FakeEnv):
48
+ def __init__(self):
49
+ super().__init__()
50
+ self.demonstration_data = (
51
+ {"front_rgb_list": ["tuple_f1", "tuple_f2"]},
52
+ None,
53
+ None,
54
+ None,
55
+ {"task_goal": ["tuple goal", "backup goal"]},
56
+ )
57
+
58
+
59
+ class _BuilderSuccess:
60
+ last_init_kwargs = None
61
+
62
+ def __init__(self, **kwargs):
63
+ type(self).last_init_kwargs = kwargs
64
+
65
+ def get_episode_num(self):
66
+ return 3
67
+
68
+ def resolve_episode(self, episode_idx):
69
+ return 123, "hard"
70
+
71
+ def make_env_for_episode(self, episode_idx):
72
+ return _FakeEnv()
73
+
74
+
75
+ class _BuilderTupleDemo(_BuilderSuccess):
76
+ def make_env_for_episode(self, episode_idx):
77
+ return _FakeEnvTupleDemo()
78
+
79
+
80
+ class _BuilderNoMetadata:
81
+ def __init__(self, **kwargs):
82
+ self.kwargs = kwargs
83
+
84
+ def get_episode_num(self):
85
+ return 0
86
+
87
+
88
+ class _BuilderRaiseOnMake:
89
+ def __init__(self, **kwargs):
90
+ self.kwargs = kwargs
91
+
92
+ def get_episode_num(self):
93
+ return 1
94
+
95
+ def resolve_episode(self, episode_idx):
96
+ return None, None
97
+
98
+ def make_env_for_episode(self, episode_idx):
99
+ raise RuntimeError("boom")
100
+
101
+
102
def test_load_episode_uses_benchmark_builder(monkeypatch, reload_module):
    """``load_episode`` wires env/planner through the builder and records metadata."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setenv("ROBOMME_METADATA_ROOT", "/tmp/meta-root")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderSuccess)
    monkeypatch.setattr(oracle_logic, "FailAwarePandaArmMotionPlanningSolver", _DummyPlanner)
    monkeypatch.setattr(oracle_logic, "FailAwarePandaStickMotionPlanningSolver", _DummyPlanner)
    monkeypatch.setattr(
        oracle_logic.OracleSession, "update_observation", lambda self: ("IMG", "Ready")
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    img, msg = sess.load_episode("BinFill", 1)

    assert (img, msg) == ("IMG", "Ready")
    assert sess.env_id == "BinFill"
    assert sess.episode_idx == 1
    assert sess.seed == 123
    assert sess.difficulty == "hard"
    assert sess.language_goal == "test goal"
    assert sess.demonstration_frames == ["f1", "f2"]

    # The builder must be constructed with the canonical benchmark kwargs.
    init_kwargs = _BuilderSuccess.last_init_kwargs
    assert init_kwargs["dataset"] == "train"
    assert init_kwargs["action_space"] == "joint_angle"
    assert init_kwargs["gui_render"] is False
    assert init_kwargs["max_steps"] == 3000
    assert init_kwargs["override_metadata_path"] == Path("/tmp/meta-root")
129
+
130
+
131
def test_load_episode_metadata_missing_returns_stable_error(monkeypatch, reload_module):
    """A builder reporting zero episodes yields a stable metadata error message."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setenv("ROBOMME_METADATA_ROOT", "/tmp/custom-metadata")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderNoMetadata)

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    img, msg = sess.load_episode("RouteStick", 0)

    assert img is None
    for fragment in (
        "Dataset metadata not found or empty",
        "record_dataset_RouteStick_metadata.json",
    ):
        assert fragment in msg
143
+
144
+
145
def test_load_episode_out_of_range_returns_stable_error(monkeypatch, reload_module):
    """An episode index beyond the builder's count yields a range error message."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderSuccess)

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    img, msg = sess.load_episode("BinFill", 99)

    assert img is None
    assert "Episode index out of range" in msg
    # The builder reports 3 episodes, so the valid range is 0-2.
    assert "valid 0-2" in msg
156
+
157
+
158
def test_load_episode_init_failure_is_caught(monkeypatch, reload_module):
    """Exceptions from env construction are converted into an error message."""
    oracle_logic = reload_module("oracle_logic")
    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderRaiseOnMake)

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    img, msg = sess.load_episode("BinFill", 0)

    assert img is None
    assert msg.startswith("Error initializing episode:")
168
+
169
+
170
def test_load_episode_supports_tuple_demonstration_data(monkeypatch, reload_module):
    """Tuple-shaped demonstration data still yields goal text and front frames."""
    oracle_logic = reload_module("oracle_logic")

    monkeypatch.setattr(oracle_logic, "BenchmarkEnvBuilder", _BuilderTupleDemo)
    monkeypatch.setattr(oracle_logic, "FailAwarePandaArmMotionPlanningSolver", _DummyPlanner)
    monkeypatch.setattr(oracle_logic, "FailAwarePandaStickMotionPlanningSolver", _DummyPlanner)
    monkeypatch.setattr(
        oracle_logic.OracleSession, "update_observation", lambda self: ("IMG", "Ready")
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    img, msg = sess.load_episode("BinFill", 0)

    assert (img, msg) == ("IMG", "Ready")
    # First task_goal entry and the tuple's front_rgb_list are extracted.
    assert sess.language_goal == "tuple goal"
    assert sess.demonstration_frames == ["tuple_f1", "tuple_f2"]
gradio-web/test/test_oracle_imports.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+
6
def test_oracle_logic_imports_without_historybench(reload_module):
    """``oracle_logic`` must import cleanly and contain no historybench references."""
    oracle_logic = reload_module("oracle_logic")
    assert oracle_logic is not None

    source = Path(oracle_logic.__file__).resolve().read_text(encoding="utf-8")
    assert "historybench" not in source
13
+
14
+
15
def test_oracle_logic_exports_builder_and_vqa(reload_module):
    """The module must re-export the builder class and the VQA option helper."""
    oracle_logic = reload_module("oracle_logic")
    for exported in ("BenchmarkEnvBuilder", "get_vqa_options"):
        assert hasattr(oracle_logic, exported)
gradio-web/test/test_precheck_execute_inputs.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+
5
+
6
+ class _FakeSession:
7
+ def __init__(self, available=True):
8
+ self.raw_solve_options = [{"available": available}]
9
+
10
+
11
def test_precheck_execute_inputs_requires_action(monkeypatch, reload_module):
    """Executing with no selected action must raise a user-facing error."""
    callbacks = reload_module("gradio_callbacks")
    monkeypatch.setattr(callbacks, "get_session", lambda uid: _FakeSession(available=False))
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)

    with pytest.raises(Exception) as excinfo:
        callbacks.precheck_execute_inputs("uid-1", None, "No need for coordinates")
    assert "No action selected" in str(excinfo.value)
20
+
21
+
22
def test_precheck_execute_inputs_requires_coords_when_option_needs_it(monkeypatch, reload_module):
    """Placeholder coordinates are rejected when the option requires a click."""
    callbacks = reload_module("gradio_callbacks")
    monkeypatch.setattr(callbacks, "get_session", lambda uid: _FakeSession(available=True))
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)

    with pytest.raises(Exception) as excinfo:
        callbacks.precheck_execute_inputs(
            "uid-1", 0, "please click the keypoint selection image"
        )
    assert "before execute" in str(excinfo.value)
33
+
34
+
35
def test_precheck_execute_inputs_accepts_valid_coords(monkeypatch, reload_module):
    """A well-formed coordinate string passes the precheck silently."""
    callbacks = reload_module("gradio_callbacks")
    monkeypatch.setattr(callbacks, "get_session", lambda uid: _FakeSession(available=True))
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)

    assert callbacks.precheck_execute_inputs("uid-1", 0, "11, 22") is None
43
+
44
+
45
def test_precheck_execute_inputs_session_error(monkeypatch, reload_module):
    """A missing session surfaces as a 'Session Error' exception."""
    callbacks = reload_module("gradio_callbacks")
    monkeypatch.setattr(callbacks, "get_session", lambda uid: None)
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)

    with pytest.raises(Exception) as excinfo:
        callbacks.precheck_execute_inputs("uid-missing", 0, "1, 2")
    assert "Session Error" in str(excinfo.value)
gradio-web/test/test_process_session_sanitize.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+
4
def test_sanitize_options_removes_solve_and_boolifies_available(reload_module):
    """``_sanitize_options`` drops callables and coerces availability to bool."""
    process_session = reload_module("process_session")

    raw = [
        {
            "label": "a",
            "action": "pick",
            "available": ["obj1"],
            "solve": lambda: None,
            "extra": 123,
        },
        {
            "label": "b",
            "action": "place",
            "available": [],
            "solve": lambda: None,
        },
    ]

    cleaned = process_session._sanitize_options(raw)

    assert len(cleaned) == 2
    # The non-serializable solve callable is stripped from every option.
    assert all("solve" not in option for option in cleaned)
    assert cleaned[0]["available"] is True
    assert cleaned[1]["available"] is False
    assert cleaned[0]["label"] == "a"
    assert cleaned[0]["action"] == "pick"
    assert cleaned[0]["extra"] == 123
33
+
34
+
35
def test_sanitize_options_handles_empty_input(reload_module):
    """Both ``None`` and ``[]`` sanitize to an empty list."""
    process_session = reload_module("process_session")
    for empty in (None, []):
        assert process_session._sanitize_options(empty) == []
gradio-web/test/test_reference_action_callbacks.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from PIL import Image
3
+
4
+
5
class _FakeSession:
    """Session stub returning a canned reference payload and a black canvas."""

    def __init__(self, reference_payload):
        self._reference_payload = reference_payload

    def get_reference_action(self):
        return self._reference_payload

    def get_pil_image(self, use_segmented=True):
        # A uniformly black 24x24 RGB image so marker pixels are detectable.
        return Image.new("RGB", (24, 24), color=(0, 0, 0))


class _FakeOptionSession:
    """Session stub with one coords-needing option for the option-select callback."""

    def __init__(self):
        self.raw_solve_options = [{"available": [object()]}]
        self.available_options = [("pick", 0)]
20
+
21
+
22
def test_on_reference_action_success_updates_option_and_coords(monkeypatch, reload_module):
    """A successful reference action marks the image and pre-fills the UI fields."""
    callbacks = reload_module("gradio_callbacks")

    payload = {
        "ok": True,
        "option_idx": 2,
        "option_label": "c",
        "option_action": "press the button",
        "need_coords": True,
        "coords_xy": [5, 6],
        "message": "ok",
    }
    sess = _FakeSession(payload)
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks, "get_session", lambda uid: sess)

    img, option_update, coords_text, log_html = callbacks.on_reference_action("uid-1")

    assert isinstance(img, Image.Image)
    # The keypoint marker must be drawn onto the (originally black) canvas.
    assert img.getpixel((5, 6)) != (0, 0, 0)
    assert option_update.get("value") == 2
    assert coords_text == "5, 6"
    assert "Ground Truth Action" in log_html
47
+
48
+
49
def test_on_reference_action_session_missing(monkeypatch, reload_module):
    """With no session the callback degrades to an error log and neutral outputs."""
    callbacks = reload_module("gradio_callbacks")

    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks, "get_session", lambda uid: None)

    img, option_update, coords_text, log_html = callbacks.on_reference_action("uid-missing")

    assert img is None
    assert option_update.get("__type__") == "update"
    assert coords_text == "No need for coordinates"
    assert "Session Error" in log_html
61
+
62
+
63
def test_on_reference_action_error_message_from_reference(monkeypatch, reload_module):
    """A failed reference payload propagates its message into the log output."""
    callbacks = reload_module("gradio_callbacks")

    sess = _FakeSession({"ok": False, "message": "bad ref"})
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks, "get_session", lambda uid: sess)

    _img, _opt, _coords, log_html = callbacks.on_reference_action("uid-1")
    assert "bad ref" in log_html
72
+
73
+
74
def test_on_option_select_keeps_valid_coords_when_option_needs_coords(monkeypatch, reload_module):
    """Selecting a coords-needing option preserves an already-valid coordinate."""
    callbacks = reload_module("gradio_callbacks")

    sess = _FakeOptionSession()
    monkeypatch.setattr(callbacks, "update_session_activity", lambda uid: None)
    monkeypatch.setattr(callbacks, "get_session", lambda uid: sess)

    coords_text, img_update = callbacks.on_option_select("uid-1", 0, "12, 34")

    assert coords_text == "12, 34"
    assert img_update.get("interactive") is True
gradio-web/test/test_reference_action_oracle.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
+ class _FakePose:
7
+ def __init__(self, p):
8
+ self.p = np.asarray(p, dtype=np.float64)
9
+
10
+
11
+ class _FakeActor:
12
+ def __init__(self, name: str, p):
13
+ self.name = name
14
+ self.pose = _FakePose(p)
15
+
16
+
17
+ class _FakeUnwrapped:
18
+ def __init__(self, choice_label: str, current_segment=None, seg_map=None):
19
+ self.current_choice_label = choice_label
20
+ self.current_segment = current_segment
21
+ self.segmentation_id_map = seg_map or {}
22
+
23
+ def get_obs(self, unflattened=True):
24
+ raise RuntimeError("not needed for centroid path")
25
+
26
+
27
+ class _FakeEnv:
28
+ def __init__(self, unwrapped):
29
+ self.unwrapped = unwrapped
30
+
31
+
32
def test_get_reference_action_maps_choice_and_returns_centroid_coords(monkeypatch, reload_module):
    """The GT choice maps to an option and yields the segment centroid (x, y)."""
    oracle_logic = reload_module("oracle_logic")

    cube = _FakeActor("cube", [0.1, 0.2, 0.3])
    env = _FakeEnv(
        _FakeUnwrapped(
            choice_label="pick up the cube",
            current_segment=cube,
            seg_map={7: cube},
        )
    )

    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [
            {"label": "a", "action": "pick up the cube", "available": [cube]}
        ],
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = env
    sess.planner = object()
    sess.env_id = "BinFill"
    # Paint segment id 7 over rows 2-4, cols 6-8 -> centroid at (x=7, y=3).
    sess.seg_raw = np.zeros((10, 10), dtype=np.int64)
    sess.seg_raw[2:5, 6:9] = 7

    result = sess.get_reference_action()

    assert result["ok"] is True
    assert result["option_idx"] == 0
    assert result["option_label"] == "a"
    assert result["need_coords"] is True
    assert result["coords_xy"] == [7, 3]
65
+
66
+
67
def test_get_reference_action_for_non_parameter_option(monkeypatch, reload_module):
    """Options without a target parameter report ``need_coords=False``."""
    oracle_logic = reload_module("oracle_logic")

    env = _FakeEnv(_FakeUnwrapped(choice_label="press the button"))

    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [
            {"label": "c", "action": "press the button"}
        ],
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = env
    sess.planner = object()
    sess.env_id = "ButtonUnmask"

    result = sess.get_reference_action()

    assert result["ok"] is True
    assert result["option_idx"] == 0
    assert result["need_coords"] is False
    assert result["coords_xy"] is None
92
+
93
+
94
def test_get_reference_action_when_choice_text_cannot_match(monkeypatch, reload_module):
    """An unmatched ground-truth label produces a failure payload."""
    oracle_logic = reload_module("oracle_logic")

    env = _FakeEnv(_FakeUnwrapped(choice_label="unknown action"))

    monkeypatch.setattr(
        oracle_logic,
        "_build_solve_options",
        lambda env, planner, selected_target, env_id: [
            {"label": "a", "action": "pick up the cube", "available": []}
        ],
    )

    sess = oracle_logic.OracleSession(dataset_root=None, gui_render=False)
    sess.env = env
    sess.planner = object()
    sess.env_id = "BinFill"

    result = sess.get_reference_action()

    assert result["ok"] is False
    assert result["option_idx"] is None
    assert "Cannot map ground truth action" in result["message"]
gradio-web/test/test_ui_native_layout_contract.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+
4
def test_native_ui_has_no_legacy_runtime_js_or_card_shell_tokens(reload_module):
    """The native layout ships no sync JS and none of the legacy card-shell CSS."""
    ui_layout = reload_module("ui_layout")

    assert ui_layout.SYNC_JS.strip() == ""

    css = ui_layout.CSS
    assert ".native-card" in css

    # Tokens from the removed floating-card/JS-shell implementation.
    for token in (
        "card-shell-hit",
        "card-shell-button",
        "floating-card",
        "applyCardShellOnce",
        "media_card_anchor",
        "action_selection_card_anchor",
        "next_task_btn_card_anchor",
        "MutationObserver",
    ):
        assert token not in css
24
+
25
+
26
def test_native_ui_config_contains_phase_machine_and_precheck_chain(reload_module):
    """The generated Gradio config exposes the phase-machine ids and API chain."""
    ui_layout = reload_module("ui_layout")
    demo = ui_layout.create_ui_blocks()

    try:
        cfg = demo.get_config_file()
        components = cfg.get("components", [])

        elem_ids = set()
        for comp in components:
            elem_id = comp.get("props", {}).get("elem_id")
            if elem_id:
                elem_ids.add(elem_id)

        required_ids = {
            "header_task",
            "loading_overlay_group",
            "main_layout_row",
            "media_card",
            "log_card",
            "right_top_row",
            "right_action_col",
            "right_log_col",
            "control_panel_group",
            "video_phase_group",
            "action_phase_group",
            "demo_video",
            "live_obs",
            "action_radio",
            "coords_box",
            "exec_btn",
            "reference_action_btn",
            "restart_episode_btn",
            "next_task_btn",
        }
        missing = required_ids - elem_ids
        assert not missing, f"missing required elem_ids: {sorted(missing)}"

        values = [
            comp.get("props", {}).get("value")
            for comp in components
            if "value" in comp.get("props", {})
        ]
        # No legacy anchor placeholders and only the new loading copy.
        assert all("_anchor" not in str(v) for v in values)
        assert any(
            "Logging in and setting up environment... Please wait." in str(v)
            for v in values
        )
        assert all("Loading environment, please wait..." not in str(v) for v in values)

        log_output_comp = next(
            comp
            for comp in components
            if comp.get("props", {}).get("elem_id") == "log_output"
        )
        assert log_output_comp.get("props", {}).get("max_lines") is None

        api_names = [dep.get("api_name") for dep in cfg.get("dependencies", [])]
        for api_name in (
            "precheck_execute_inputs",
            "switch_to_execute_phase",
            "execute_step",
            "switch_to_action_phase",
        ):
            assert api_name in api_names
    finally:
        demo.close()
gradio-web/test/test_ui_phase_machine_runtime_e2e.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import importlib
5
+ import socket
6
+ import threading
7
+ import time
8
+ from urllib.error import URLError
9
+ from urllib.request import urlopen
10
+
11
+ import numpy as np
12
+ import pytest
13
+ from PIL import Image
14
+
15
+
16
+ gr = pytest.importorskip("gradio")
17
+ pytest.importorskip("fastapi")
18
+ pytest.importorskip("uvicorn")
19
+ pytest.importorskip("playwright.sync_api")
20
+
21
+ import uvicorn
22
+ from fastapi import FastAPI
23
+ from playwright.sync_api import sync_playwright
24
+
25
+
26
+ def _free_port() -> int:
27
+ with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
28
+ sock.bind(("127.0.0.1", 0))
29
+ return int(sock.getsockname()[1])
30
+
31
+
32
+ def _wait_http_ready(url: str, timeout_s: float = 20.0) -> None:
33
+ end = time.time() + timeout_s
34
+ while time.time() < end:
35
+ try:
36
+ with urlopen(url, timeout=1.0) as resp: # noqa: S310 - local test URL only
37
+ if int(getattr(resp, "status", 200)) < 500:
38
+ return
39
+ except URLError:
40
+ time.sleep(0.2)
41
+ except Exception:
42
+ time.sleep(0.2)
43
+ raise RuntimeError(f"Server did not become ready: {url}")
44
+
45
+
46
def _read_header_task_value(page) -> str | None:
    """Read the current header-task text from the live page, or ``None`` if unset."""
    script = """() => {
        const root = document.getElementById('header_task');
        if (!root) return null;
        const input = root.querySelector('input');
        if (input && typeof input.value === 'string') {
            const value = input.value.trim();
            return value || null;
        }
        const selected = root.querySelector('.single-select');
        if (!selected) return null;
        const text = (selected.textContent || '').trim();
        return text || null;
    }"""
    return page.evaluate(script)
62
+
63
+
64
+ @pytest.fixture
65
+ def phase_machine_ui_url():
66
+ state = {"precheck_calls": 0}
67
+ demo_video_url = "https://interactive-examples.mdn.mozilla.net/media/cc0-videos/flower.mp4"
68
+
69
+ with gr.Blocks(title="Native phase machine test") as demo:
70
+ phase_state = gr.State("init")
71
+
72
+ with gr.Column(visible=True, elem_id="login_group") as login_group:
73
+ login_btn = gr.Button("Login", elem_id="login_btn")
74
+
75
+ with gr.Column(visible=False, elem_id="main_interface") as main_interface:
76
+ with gr.Column(visible=False, elem_id="video_phase_group") as video_phase_group:
77
+ video_display = gr.Video(value=None, elem_id="demo_video", autoplay=True)
78
+
79
+ with gr.Column(visible=False, elem_id="action_phase_group") as action_phase_group:
80
+ img_display = gr.Image(value=np.zeros((24, 24, 3), dtype=np.uint8), elem_id="live_obs")
81
+
82
+ with gr.Column(visible=False, elem_id="control_panel_group") as control_panel_group:
83
+ options_radio = gr.Radio(choices=[("pick", 0)], value=0, elem_id="action_radio")
84
+ coords_box = gr.Textbox(value="please click the keypoint selection image", elem_id="coords_box")
85
+ with gr.Column(visible=False, elem_id="action_buttons_row") as action_buttons_row:
86
+ exec_btn = gr.Button("EXECUTE", elem_id="exec_btn")
87
+ next_task_btn = gr.Button("Next Task", elem_id="next_task_btn")
88
+
89
+ log_output = gr.Markdown("", elem_id="log_output")
90
+
91
+ def login_fn():
92
+ return (
93
+ gr.update(visible=False),
94
+ gr.update(visible=True),
95
+ gr.update(visible=True),
96
+ gr.update(value=demo_video_url, visible=True),
97
+ gr.update(visible=False),
98
+ gr.update(visible=False),
99
+ gr.update(visible=False),
100
+ gr.update(value="please click the keypoint selection image"),
101
+ "demo_video",
102
+ )
103
+
104
+ def on_video_end_fn():
105
+ return (
106
+ gr.update(visible=False),
107
+ gr.update(visible=True),
108
+ gr.update(visible=True),
109
+ gr.update(visible=True),
110
+ "action_keypoint",
111
+ )
112
+
113
+ def precheck_fn(_option_idx, _coords):
114
+ state["precheck_calls"] += 1
115
+ if state["precheck_calls"] == 1:
116
+ raise gr.Error("please click the keypoint selection image before execute!")
117
+
118
+ def to_execute_fn():
119
+ return (
120
+ gr.update(interactive=False),
121
+ gr.update(interactive=False),
122
+ gr.update(interactive=False),
123
+ gr.update(interactive=False),
124
+ "execution_playback",
125
+ )
126
+
127
+ def execute_fn():
128
+ time.sleep(0.8)
129
+ return (
130
+ "executed",
131
+ gr.update(interactive=True),
132
+ gr.update(interactive=True),
133
+ )
134
+
135
+ def to_action_fn():
136
+ return (
137
+ gr.update(interactive=True),
138
+ gr.update(interactive=True),
139
+ gr.update(interactive=True),
140
+ gr.update(interactive=True),
141
+ "action_keypoint",
142
+ )
143
+
144
+ login_btn.click(
145
+ fn=login_fn,
146
+ outputs=[
147
+ login_group,
148
+ main_interface,
149
+ video_phase_group,
150
+ video_display,
151
+ action_phase_group,
152
+ control_panel_group,
153
+ action_buttons_row,
154
+ coords_box,
155
+ phase_state,
156
+ ],
157
+ queue=False,
158
+ )
159
+
160
+ video_display.end(
161
+ fn=on_video_end_fn,
162
+ outputs=[video_phase_group, action_phase_group, control_panel_group, action_buttons_row, phase_state],
163
+ queue=False,
164
+ )
165
+
166
+ exec_btn.click(
167
+ fn=precheck_fn,
168
+ inputs=[options_radio, coords_box],
169
+ outputs=[],
170
+ queue=False,
171
+ ).then(
172
+ fn=to_execute_fn,
173
+ outputs=[
174
+ options_radio,
175
+ exec_btn,
176
+ next_task_btn,
177
+ img_display,
178
+ phase_state,
179
+ ],
180
+ queue=False,
181
+ ).then(
182
+ fn=execute_fn,
183
+ outputs=[log_output, next_task_btn, exec_btn],
184
+ queue=False,
185
+ ).then(
186
+ fn=to_action_fn,
187
+ outputs=[options_radio, exec_btn, next_task_btn, img_display, phase_state],
188
+ queue=False,
189
+ )
190
+
191
+ port = _free_port()
192
+ host = "127.0.0.1"
193
+ root_url = f"http://{host}:{port}/"
194
+
195
+ app = FastAPI(title="native-phase-machine-test")
196
+ app = gr.mount_gradio_app(app, demo, path="/")
197
+
198
+ config = uvicorn.Config(app, host=host, port=port, log_level="error")
199
+ server = uvicorn.Server(config)
200
+ thread = threading.Thread(target=server.run, daemon=True)
201
+ thread.start()
202
+ _wait_http_ready(root_url)
203
+
204
+ try:
205
+ yield root_url, state
206
+ finally:
207
+ server.should_exit = True
208
+ thread.join(timeout=10)
209
+ demo.close()
210
+
211
+
212
def test_phase_machine_runtime_flow_and_execute_precheck(phase_machine_ui_url):
    """Browser-level walk of the phase machine wired in the fixture.

    Sequence under test:
      1. login -> only the demo video group is visible;
      2. dispatching 'ended' on the video -> action/control panels appear;
      3. first EXECUTE click -> precheck raises gr.Error, UI left unchanged;
      4. second EXECUTE click -> buttons disabled while execute_fn sleeps,
         then re-enabled once the .then chain completes.
    """
    root_url, state = phase_machine_ui_url

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        page = browser.new_page(viewport={"width": 1280, "height": 900})
        page.goto(root_url, wait_until="domcontentloaded")

        # Give the Gradio frontend time to hydrate before interacting.
        page.wait_for_timeout(2500)
        page.wait_for_selector("#login_btn", timeout=20000)
        page.click("#login_btn")

        page.wait_for_function(
            """() => {
            const el = document.getElementById('demo_video');
            return !!el && getComputedStyle(el).display !== 'none';
        }"""
        )

        # Snapshot visibility right after login: video on, action/control off.
        phase_after_login = page.evaluate(
            """() => {
            const visible = (id) => {
                const el = document.getElementById(id);
                if (!el) return false;
                const st = getComputedStyle(el);
                return st.display !== 'none' && st.visibility !== 'hidden' && el.getClientRects().length > 0;
            };
            return {
                video: visible('demo_video'),
                action: visible('live_obs'),
                control: visible('action_radio'),
            };
        }"""
        )
        assert phase_after_login == {
            "video": True,
            "action": False,
            "control": False,
        }

        # Simulate the demo video finishing via a bubbling 'ended' event.
        page.wait_for_selector("#demo_video video", timeout=5000)
        did_dispatch_end = page.evaluate(
            """() => {
            const videoEl = document.querySelector('#demo_video video');
            if (!videoEl) return false;
            videoEl.dispatchEvent(new Event('ended', { bubbles: true }));
            return true;
        }"""
        )
        assert did_dispatch_end

        # The 'ended' event must reveal both the action and control panels.
        page.wait_for_function(
            """() => {
            const action = document.getElementById('live_obs');
            const control = document.getElementById('action_radio');
            if (!action || !control) return false;
            return getComputedStyle(action).display !== 'none' && getComputedStyle(control).display !== 'none';
        }"""
        )

        # First EXECUTE click: precheck_fn raises gr.Error on call #1.
        did_click_exec = page.evaluate(
            """() => {
            const btn = document.getElementById('exec_btn');
            if (!btn) return false;
            btn.click();
            return true;
        }"""
        )
        assert did_click_exec
        page.wait_for_timeout(300)

        # The failed precheck aborts the chain: action panel stays visible.
        phase_after_failed_precheck = page.evaluate(
            """() => {
            const visible = (id) => {
                const el = document.getElementById(id);
                if (!el) return false;
                return getComputedStyle(el).display !== 'none';
            };
            return {
                action: visible('live_obs'),
            };
        }"""
        )
        assert phase_after_failed_precheck == {"action": True}

        # Second EXECUTE click: precheck passes and the chain runs.
        did_click_exec = page.evaluate(
            """() => {
            const btn = document.getElementById('exec_btn');
            if (!btn) return false;
            btn.click();
            return true;
        }"""
        )
        assert did_click_exec

        # While execute_fn sleeps, both buttons must be disabled.
        page.wait_for_function(
            """() => {
            const resolveButton = (id) => {
                return document.querySelector(`#${id} button`) || document.querySelector(`button#${id}`);
            };
            const execBtn = resolveButton('exec_btn');
            const nextBtn = resolveButton('next_task_btn');
            return !!execBtn && !!nextBtn && execBtn.disabled === true && nextBtn.disabled === true;
        }"""
        )

        interactive_snapshot = page.evaluate(
            """() => {
            const resolveButton = (id) => {
                return document.querySelector(`#${id} button`) || document.querySelector(`button#${id}`);
            };
            const execBtn = resolveButton('exec_btn');
            const nextBtn = resolveButton('next_task_btn');
            return {
                execDisabled: execBtn ? execBtn.disabled : null,
                nextDisabled: nextBtn ? nextBtn.disabled : null,
            };
        }"""
        )
        assert interactive_snapshot["execDisabled"] is True
        assert interactive_snapshot["nextDisabled"] is True

        # After to_action_fn the controls come back (6s timeout > 0.8s sleep).
        page.wait_for_function(
            """() => {
            const execBtn = document.querySelector('button#exec_btn') || document.querySelector('#exec_btn button');
            const action = document.getElementById('live_obs');
            if (!execBtn || !action) return false;
            return execBtn.disabled === false && getComputedStyle(action).display !== 'none';
        }""",
            timeout=6000,
        )

        final_interactive_snapshot = page.evaluate(
            """() => {
            const resolveButton = (id) => {
                return document.querySelector(`#${id} button`) || document.querySelector(`button#${id}`);
            };
            const execBtn = resolveButton('exec_btn');
            const nextBtn = resolveButton('next_task_btn');
            return {
                execDisabled: execBtn ? execBtn.disabled : null,
                nextDisabled: nextBtn ? nextBtn.disabled : null,
            };
        }"""
        )
        assert final_interactive_snapshot["execDisabled"] is False
        assert final_interactive_snapshot["nextDisabled"] is False

        browser.close()

    # precheck_fn must have been hit by both clicks (fail + pass).
    assert state["precheck_calls"] >= 2
363
+
364
+
365
def test_unified_loading_overlay_init_flow(monkeypatch):
    """The loading overlay shows the canonical copy (not the legacy one) and
    hides once init_app finishes, after which the header task is populated."""
    ui_layout = importlib.reload(importlib.import_module("ui_layout"))

    canonical_copy = "Logging in and setting up environment... Please wait."
    legacy_copy = "Loading environment, please wait..."
    fake_obs = np.zeros((24, 24, 3), dtype=np.uint8)
    fake_obs_img = Image.fromarray(fake_obs)
    calls = {"init": 0}

    def fake_show_loading_info():
        return gr.update(visible=True)

    def fake_init_app(_request=None):
        # Slow init on purpose so the overlay is observable before it hides.
        calls["init"] += 1
        time.sleep(0.8)
        return (
            "uid-init",
            gr.update(visible=True),  # main_interface
            gr.update(value=fake_obs_img, interactive=False),  # img_display
            "ready",  # log_output
            gr.update(choices=[("pick", 0)], value=None),  # options_radio
            "goal",  # goal_box
            "No need for coordinates",  # coords_box
            gr.update(value=None, visible=False),  # video_display
            "PickXtimes (Episode 1)",  # task_info_box
            "Completed: 0",  # progress_info_box
            gr.update(interactive=True),  # restart_episode_btn
            gr.update(interactive=True),  # next_task_btn
            gr.update(interactive=True),  # exec_btn
            gr.update(visible=False),  # video_phase_group
            gr.update(visible=True),  # action_phase_group
            gr.update(visible=True),  # control_panel_group
            gr.update(value="hint"),  # task_hint_display
            gr.update(visible=False),  # loading_overlay
            gr.update(interactive=True),  # reference_action_btn
        )

    monkeypatch.setattr(ui_layout, "show_loading_info", fake_show_loading_info)
    monkeypatch.setattr(ui_layout, "init_app", fake_init_app)

    demo = ui_layout.create_ui_blocks()

    # Serve the real UI on a free localhost port.
    port = _free_port()
    host = "127.0.0.1"
    root_url = f"http://{host}:{port}/"

    app = FastAPI(title="native-unified-loading-overlay-test")
    app = gr.mount_gradio_app(app, demo, path="/")

    config = uvicorn.Config(app, host=host, port=port, log_level="error")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    _wait_http_ready(root_url)

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page(viewport={"width": 1280, "height": 900})
            page.goto(root_url, wait_until="domcontentloaded")

            # Overlay must appear quickly while fake_init_app is sleeping.
            page.wait_for_selector("#loading_overlay_group", state="visible", timeout=2500)

            overlay_text = page.evaluate(
                """() => {
                const el = document.getElementById('loading_overlay_group');
                return el ? (el.textContent || '') : '';
            }"""
            )
            assert canonical_copy in overlay_text
            assert legacy_copy not in page.content()

            # Overlay hides and the main UI shows once init completes.
            page.wait_for_selector("#loading_overlay_group", state="hidden", timeout=15000)
            page.wait_for_selector("#main_interface_root", state="visible", timeout=15000)
            page.wait_for_function(
                """() => {
                const root = document.getElementById('header_task');
                const input = root ? root.querySelector('input') : null;
                return !!input && input.value.trim() === 'PickXtimes';
            }""",
                timeout=5000,
            )
            assert _read_header_task_value(page) == "PickXtimes"

            browser.close()
    finally:
        server.should_exit = True
        thread.join(timeout=10)
        demo.close()

    assert calls["init"] >= 1
456
+
457
+
458
def test_header_task_shows_env_after_init(monkeypatch):
    """Auto-login via the ?user= query param must populate the header task
    dropdown with the env name parsed from task_info_box."""
    ui_layout = importlib.reload(importlib.import_module("ui_layout"))

    fake_obs = np.zeros((24, 24, 3), dtype=np.uint8)
    fake_obs_img = Image.fromarray(fake_obs)

    def fake_init_app(request=None):
        _ = request
        return (
            "uid-auto",
            gr.update(visible=True),  # main_interface
            gr.update(value=fake_obs_img, interactive=False),  # img_display
            "ready",  # log_output
            gr.update(choices=[("pick", 0)], value=None),  # options_radio
            "goal",  # goal_box
            "No need for coordinates",  # coords_box
            gr.update(value=None, visible=False),  # video_display
            "PickXtimes (Episode 1)",  # task_info_box
            "Completed: 0",  # progress_info_box
            gr.update(interactive=True),  # restart_episode_btn
            gr.update(interactive=True),  # next_task_btn
            gr.update(interactive=True),  # exec_btn
            gr.update(visible=False),  # video_phase_group
            gr.update(visible=True),  # action_phase_group
            gr.update(visible=True),  # control_panel_group
            gr.update(value="hint"),  # task_hint_display
            gr.update(visible=False),  # loading_overlay
            gr.update(interactive=True),  # reference_action_btn
        )

    monkeypatch.setattr(ui_layout, "init_app", fake_init_app)

    demo = ui_layout.create_ui_blocks()

    # Serve the UI on a free localhost port.
    port = _free_port()
    host = "127.0.0.1"
    root_url = f"http://{host}:{port}/"

    app = FastAPI(title="header-task-url-auto-login-test")
    app = gr.mount_gradio_app(app, demo, path="/")

    config = uvicorn.Config(app, host=host, port=port, log_level="error")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    _wait_http_ready(root_url)

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page(viewport={"width": 1280, "height": 900})
            # ?user=user1 triggers the auto-login path on page load.
            page.goto(f"{root_url}?user=user1", wait_until="domcontentloaded")
            page.wait_for_selector("#main_interface_root", state="visible", timeout=15000)
            page.wait_for_function(
                """() => {
                const root = document.getElementById('header_task');
                const input = root ? root.querySelector('input') : null;
                return !!input && input.value.trim() === 'PickXtimes';
            }""",
                timeout=5000,
            )
            assert _read_header_task_value(page) == "PickXtimes"
            browser.close()
    finally:
        server.should_exit = True
        thread.join(timeout=10)
        demo.close()
525
+
526
+
527
@pytest.mark.parametrize(
    "task_info_text,expected_header_value",
    [
        # Case normalization: lowercase env id maps back to canonical form.
        ("pickxtimes (Episode 1)", "PickXtimes"),
        # Fallback: an env id unknown to the dropdown is shown verbatim.
        ("EnvFromSessionOnly (Episode 1)", "EnvFromSessionOnly"),
    ],
)
def test_header_task_env_normalization_and_fallback(monkeypatch, task_info_text, expected_header_value):
    """The header task dropdown normalizes env names parsed from the
    task_info_box text, falling back to the raw session value."""
    ui_layout = importlib.reload(importlib.import_module("ui_layout"))

    fake_obs = np.zeros((24, 24, 3), dtype=np.uint8)
    fake_obs_img = Image.fromarray(fake_obs)

    def fake_init_app(_request=None):
        return (
            "uid-auto",
            gr.update(visible=True),  # main_interface
            gr.update(value=fake_obs_img, interactive=False),  # img_display
            "ready",  # log_output
            gr.update(choices=[("pick", 0)], value=None),  # options_radio
            "goal",  # goal_box
            "No need for coordinates",  # coords_box
            gr.update(value=None, visible=False),  # video_display
            task_info_text,  # task_info_box (parametrized input under test)
            "Completed: 0",  # progress_info_box
            gr.update(interactive=True),  # restart_episode_btn
            gr.update(interactive=True),  # next_task_btn
            gr.update(interactive=True),  # exec_btn
            gr.update(visible=False),  # video_phase_group
            gr.update(visible=True),  # action_phase_group
            gr.update(visible=True),  # control_panel_group
            gr.update(value="hint"),  # task_hint_display
            gr.update(visible=False),  # loading_overlay
            gr.update(interactive=True),  # reference_action_btn
        )

    monkeypatch.setattr(ui_layout, "init_app", fake_init_app)

    demo = ui_layout.create_ui_blocks()

    # Serve the UI on a free localhost port.
    port = _free_port()
    host = "127.0.0.1"
    root_url = f"http://{host}:{port}/"

    app = FastAPI(title="header-task-normalization-fallback-test")
    app = gr.mount_gradio_app(app, demo, path="/")

    config = uvicorn.Config(app, host=host, port=port, log_level="error")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    _wait_http_ready(root_url)

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page(viewport={"width": 1280, "height": 900})
            page.goto(root_url, wait_until="domcontentloaded")
            page.wait_for_selector("#main_interface_root", state="visible", timeout=15000)
            page.wait_for_function(
                """(expectedValue) => {
                const root = document.getElementById('header_task');
                const input = root ? root.querySelector('input') : null;
                return !!input && input.value.trim() === expectedValue;
            }""",
                arg=expected_header_value,
                timeout=5000,
            )
            assert _read_header_task_value(page) == expected_header_value
            browser.close()
    finally:
        server.should_exit = True
        thread.join(timeout=10)
        demo.close()
601
+
602
+
603
def test_phase_machine_runtime_local_video_path_end_transition():
    """The real gradio_callbacks.on_video_end_transition must swap the
    video group out for the action/control groups when 'ended' fires,
    using a locally-served video file and a fully faked session layer."""
    import gradio_callbacks as cb

    # NOTE(review): `gr.get_video` is not part of the public Gradio API in
    # most releases — confirm this helper exists in the pinned gradio version.
    demo_video_path = gr.get_video("world.mp4")
    fake_obs = np.zeros((24, 24, 3), dtype=np.uint8)

    class FakeSession:
        # Minimal stand-in for the per-user session consumed by the callbacks.
        def __init__(self):
            self.env_id = "VideoUnmask"
            self.language_goal = "place cube on target"
            self.available_options = [("pick", 0)]
            self.raw_solve_options = [{"available": False}]
            self.demonstration_frames = [fake_obs.copy() for _ in range(4)]

        def load_episode(self, env_id, episode_idx):
            self.env_id = env_id
            return fake_obs.copy(), f"loaded {env_id}:{episode_idx}"

        def get_pil_image(self, use_segmented=False):
            _ = use_segmented
            return fake_obs.copy()

    # Save the real module attributes so the finally-block can restore them
    # (this test patches module globals directly, no monkeypatch fixture).
    originals = {
        "get_session": cb.get_session,
        "reset_play_button_clicked": cb.reset_play_button_clicked,
        "reset_execute_count": cb.reset_execute_count,
        "set_task_start_time": cb.set_task_start_time,
        "set_ui_phase": cb.set_ui_phase,
        "save_video": cb.save_video,
    }

    fake_session = FakeSession()

    cb.get_session = lambda uid: fake_session
    cb.reset_play_button_clicked = lambda uid: None
    cb.reset_execute_count = lambda uid, env_id, ep_num: None
    cb.set_task_start_time = lambda uid, env_id, ep_num, start_time: None
    cb.set_ui_phase = lambda uid, phase: None
    cb.save_video = lambda frames, suffix="": demo_video_path

    try:
        # Minimal Blocks UI mirroring the component layout the real
        # callbacks expect (same elem_ids and output ordering).
        with gr.Blocks(title="Native phase machine local video test") as demo:
            uid_state = gr.State(value="uid-local-video")
            with gr.Column(visible=False, elem_id="main_interface") as main_interface:
                with gr.Column(visible=False, elem_id="video_phase_group") as video_phase_group:
                    video_display = gr.Video(value=None, elem_id="demo_video", autoplay=False)

                with gr.Column(visible=True, elem_id="action_phase_group") as action_phase_group:
                    img_display = gr.Image(value=fake_obs.copy(), elem_id="live_obs")

                with gr.Column(visible=True, elem_id="control_panel_group") as control_panel_group:
                    options_radio = gr.Radio(choices=[("pick", 0)], value=None, elem_id="action_radio")

                log_output = gr.Markdown("", elem_id="log_output")
                goal_box = gr.Textbox("")
                coords_box = gr.Textbox("No need for coordinates")
                task_info_box = gr.Textbox("")
                progress_info_box = gr.Textbox("")
                task_hint_display = gr.Textbox("")
                with gr.Column(visible=False) as loading_overlay:
                    gr.Markdown("Loading...")

                restart_episode_btn = gr.Button("restart", interactive=False)
                next_task_btn = gr.Button("next", interactive=False)
                exec_btn = gr.Button("execute", interactive=False)
                reference_action_btn = gr.Button("reference", interactive=False)

            def load_fn():
                # Drive the real _load_status_task with a minimal status dict.
                status = {
                    "current_task": {"env_id": "VideoUnmask", "episode_idx": 1},
                    "completed_count": 0,
                }
                return cb._load_status_task("uid-local-video", status)

            demo.load(
                fn=load_fn,
                outputs=[
                    uid_state,
                    main_interface,
                    img_display,
                    log_output,
                    options_radio,
                    goal_box,
                    coords_box,
                    video_display,
                    task_info_box,
                    progress_info_box,
                    restart_episode_btn,
                    next_task_btn,
                    exec_btn,
                    video_phase_group,
                    action_phase_group,
                    control_panel_group,
                    task_hint_display,
                    loading_overlay,
                    reference_action_btn,
                ],
                queue=False,
            )

            # The real callback under test.
            video_display.end(
                fn=cb.on_video_end_transition,
                inputs=[uid_state],
                outputs=[video_phase_group, action_phase_group, control_panel_group, log_output],
                queue=False,
            )

        # Serve on a free localhost port.
        port = _free_port()
        host = "127.0.0.1"
        root_url = f"http://{host}:{port}/"

        app = FastAPI(title="native-phase-machine-local-video-test")
        app = gr.mount_gradio_app(app, demo, path="/")

        config = uvicorn.Config(app, host=host, port=port, log_level="error")
        server = uvicorn.Server(config)
        thread = threading.Thread(target=server.run, daemon=True)
        thread.start()
        _wait_http_ready(root_url)

        try:
            with sync_playwright() as p:
                browser = p.chromium.launch(headless=True)
                page = browser.new_page(viewport={"width": 1280, "height": 900})
                page.goto(root_url, wait_until="domcontentloaded")
                page.wait_for_selector("#main_interface", state="visible", timeout=20000)

                page.wait_for_selector("#demo_video video", timeout=5000)
                # After load: video phase visible, action/control hidden.
                phase_after_login = page.evaluate(
                    """() => {
                    const visible = (id) => {
                        const el = document.getElementById(id);
                        if (!el) return false;
                        const st = getComputedStyle(el);
                        return st.display !== 'none' && st.visibility !== 'hidden' && el.getClientRects().length > 0;
                    };
                    return {
                        video: visible('demo_video'),
                        action: visible('live_obs'),
                        control: visible('action_radio'),
                    };
                }"""
                )
                assert phase_after_login == {
                    "video": True,
                    "action": False,
                    "control": False,
                }

                # Dispatch 'ended' so on_video_end_transition runs server-side.
                did_dispatch_end = page.evaluate(
                    """() => {
                    const videoEl = document.querySelector('#demo_video video');
                    if (!videoEl) return false;
                    videoEl.dispatchEvent(new Event('ended', { bubbles: true }));
                    return true;
                }"""
                )
                assert did_dispatch_end

                # The transition must hide the video and show action/control.
                page.wait_for_function(
                    """() => {
                    const visible = (id) => {
                        const el = document.getElementById(id);
                        if (!el) return false;
                        const st = getComputedStyle(el);
                        return st.display !== 'none' && st.visibility !== 'hidden' && el.getClientRects().length > 0;
                    };
                    return visible('live_obs') && visible('action_radio') && !visible('demo_video');
                }""",
                    timeout=2000,
                )

                browser.close()
        finally:
            server.should_exit = True
            thread.join(timeout=10)
            demo.close()
    finally:
        # Restore the patched gradio_callbacks attributes.
        for name, value in originals.items():
            setattr(cb, name, value)
gradio-web/test/test_user_manager_random_flow.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+
5
+
6
+ def _write_metadata(root, env_id: str, episodes: list[int]) -> None:
7
+ root.mkdir(parents=True, exist_ok=True)
8
+ payload = {
9
+ "env_id": env_id,
10
+ "records": [
11
+ {"task": env_id, "episode": ep, "seed": 1000 + ep, "difficulty": "easy"}
12
+ for ep in episodes
13
+ ],
14
+ }
15
+ (root / f"record_dataset_{env_id}_metadata.json").write_text(
16
+ json.dumps(payload), encoding="utf-8"
17
+ )
18
+
19
+
20
+ def test_fixed_users_login_and_random_task_pool(monkeypatch, reload_module, tmp_path):
21
+ metadata_root = tmp_path / "metadata"
22
+ _write_metadata(metadata_root, "EnvA", [0, 1, 2])
23
+ _write_metadata(metadata_root, "EnvB", [10, 11])
24
+ monkeypatch.setenv("ROBOMME_METADATA_ROOT", str(metadata_root))
25
+
26
+ user_manager_mod = reload_module("user_manager")
27
+ monkeypatch.setattr(user_manager_mod.random, "choice", lambda seq: seq[0])
28
+ manager = user_manager_mod.UserManager()
29
+
30
+ success, _msg, status = manager.init_session("uid1")
31
+ assert success
32
+ assert status["current_task"]["env_id"] in {"EnvA", "EnvB"}
33
+ assert status["current_task"]["episode_idx"] in {0, 1, 2, 10, 11}
34
+ assert status["is_done_all"] is False
35
+
36
+
37
+ def test_switch_env_and_next_episode_stays_in_same_env(monkeypatch, reload_module, tmp_path):
38
+ metadata_root = tmp_path / "metadata"
39
+ _write_metadata(metadata_root, "EnvA", [0, 1, 2])
40
+ _write_metadata(metadata_root, "EnvB", [10, 11])
41
+ monkeypatch.setenv("ROBOMME_METADATA_ROOT", str(metadata_root))
42
+
43
+ user_manager_mod = reload_module("user_manager")
44
+ monkeypatch.setattr(user_manager_mod.random, "choice", lambda seq: seq[-1])
45
+ manager = user_manager_mod.UserManager()
46
+
47
+ success, _msg, _status = manager.init_session("uid2")
48
+ assert success
49
+
50
+ switched = manager.switch_env_and_random_episode("uid2", "EnvA")
51
+ assert switched is not None
52
+ assert switched["current_task"]["env_id"] == "EnvA"
53
+ assert switched["current_task"]["episode_idx"] in {0, 1, 2}
54
+
55
+ nxt = manager.next_episode_same_env("uid2")
56
+ assert nxt is not None
57
+ assert nxt["current_task"]["env_id"] == "EnvA"
58
+ assert nxt["current_task"]["episode_idx"] in {0, 1, 2}
59
+
60
+
61
+ def test_complete_current_task_increments_completed_count(monkeypatch, reload_module, tmp_path):
62
+ metadata_root = tmp_path / "metadata"
63
+ _write_metadata(metadata_root, "EnvA", [0, 1])
64
+ monkeypatch.setenv("ROBOMME_METADATA_ROOT", str(metadata_root))
65
+
66
+ user_manager_mod = reload_module("user_manager")
67
+ monkeypatch.setattr(user_manager_mod.random, "choice", lambda seq: seq[0])
68
+ manager = user_manager_mod.UserManager()
69
+
70
+ success, _msg, status = manager.init_session("uid3")
71
+ assert success
72
+ assert status["completed_count"] == 0
73
+
74
+ updated = manager.complete_current_task(
75
+ "uid3",
76
+ env_id=status["current_task"]["env_id"],
77
+ episode_idx=status["current_task"]["episode_idx"],
78
+ status="success",
79
+ )
80
+ assert updated is not None
81
+ assert updated["completed_count"] == 1
82
+ assert updated["is_done_all"] is False
83
+
84
+
85
+ def test_init_session_fails_when_metadata_root_missing(monkeypatch, reload_module, tmp_path):
86
+ missing_root = tmp_path / "missing-metadata-root"
87
+ monkeypatch.setenv("ROBOMME_METADATA_ROOT", str(missing_root))
88
+
89
+ user_manager_mod = reload_module("user_manager")
90
+ manager = user_manager_mod.UserManager()
91
+
92
+ success, msg, status = manager.init_session("uid-missing")
93
+
94
+ assert success is False
95
+ assert "No available environments" in msg
96
+ assert status is None
gradio-web/ui_layout.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Native Gradio UI layout.
3
+ Sequential media phases: Demo Video -> Action+Keypoint.
4
+ Two-column layout: Keypoint Selection | Right Panel.
5
+ """
6
+
7
+ import ast
8
+
9
+ import gradio as gr
10
+
11
+ from config import (
12
+ CONTROL_PANEL_SCALE,
13
+ KEYPOINT_SELECTION_SCALE,
14
+ RIGHT_TOP_ACTION_SCALE,
15
+ RIGHT_TOP_LOG_SCALE,
16
+ )
17
+ from gradio_callbacks import (
18
+ execute_step,
19
+ init_app,
20
+ load_next_task_wrapper,
21
+ on_map_click,
22
+ on_option_select,
23
+ on_reference_action,
24
+ on_video_end_transition,
25
+ precheck_execute_inputs,
26
+ refresh_live_obs,
27
+ restart_episode_wrapper,
28
+ show_loading_info,
29
+ switch_env_wrapper,
30
+ switch_to_action_phase,
31
+ switch_to_execute_phase,
32
+ )
33
+ from user_manager import user_manager
34
+
35
+
36
+ PHASE_INIT = "init"
37
+ PHASE_DEMO_VIDEO = "demo_video"
38
+ PHASE_ACTION_KEYPOINT = "action_keypoint"
39
+ PHASE_EXECUTION_PLAYBACK = "execution_playback"
40
+
41
+
42
+ # Deprecated: no runtime JS logic in native Gradio mode.
43
+ SYNC_JS = ""
44
+
45
+
46
+ CSS = f"""
47
+ .native-card {{
48
+ }}
49
+
50
+ #loading_overlay_group {{
51
+ position: fixed !important;
52
+ inset: 0 !important;
53
+ z-index: 9999 !important;
54
+ background: rgba(255, 255, 255, 0.92) !important;
55
+ text-align: center !important;
56
+ }}
57
+
58
+ #loading_overlay_group > div {{
59
+ min-height: 100%;
60
+ display: flex;
61
+ align-items: center;
62
+ justify-content: center;
63
+ }}
64
+
65
+ #loading_overlay_group h3 {{
66
+ margin: 0 !important;
67
+ }}
68
+
69
+ button#reference_action_btn:not(:disabled),
70
+ #reference_action_btn:not(:disabled),
71
+ #reference_action_btn button:not(:disabled) {{
72
+ background: #1f8b4c !important;
73
+ border-color: #1f8b4c !important;
74
+ color: #ffffff !important;
75
+ }}
76
+
77
+ button#reference_action_btn:not(:disabled):hover,
78
+ #reference_action_btn:not(:disabled):hover,
79
+ #reference_action_btn button:not(:disabled):hover {{
80
+ background: #19713d !important;
81
+ border-color: #19713d !important;
82
+ }}
83
+ """
84
+
85
+
86
+ def extract_first_goal(goal_text):
87
+ """Extract first goal from goal text that may be a list representation."""
88
+ if not goal_text:
89
+ return ""
90
+ text = goal_text.strip()
91
+ if text.startswith("[") and text.endswith("]"):
92
+ try:
93
+ goals = ast.literal_eval(text)
94
+ if isinstance(goals, list) and goals:
95
+ return str(goals[0]).strip()
96
+ except Exception:
97
+ pass
98
+ return text.split("\n")[0].strip()
99
+
100
+
101
+ def _phase_from_updates(main_interface_update, video_phase_update):
102
+ if isinstance(main_interface_update, dict) and main_interface_update.get("visible") is False:
103
+ return PHASE_INIT
104
+ if isinstance(video_phase_update, dict) and video_phase_update.get("visible") is True:
105
+ return PHASE_DEMO_VIDEO
106
+ return PHASE_ACTION_KEYPOINT
107
+
108
+
109
+ def _with_phase_from_load(load_result):
110
+ phase = _phase_from_updates(load_result[1], load_result[13])
111
+ return (*load_result, phase)
112
+
113
+
114
+ def create_ui_blocks():
115
+ """构建 Gradio Blocks,并完成页面阶段状态(phase)的联动绑定。"""
116
+
117
+ def render_header_task(task_text):
118
+ clean_task = str(task_text or "").strip()
119
+ if not clean_task:
120
+ return None
121
+ if clean_task.lower().startswith("current task:"):
122
+ clean_task = clean_task.split(":", 1)[1].strip()
123
+ marker = " (Episode "
124
+ if marker in clean_task:
125
+ clean_task = clean_task.split(marker, 1)[0].strip()
126
+ return " ".join(clean_task.splitlines()).strip() or None
127
+
128
+ def render_header_goal(goal_text):
129
+ first_goal = extract_first_goal(goal_text or "")
130
+ return first_goal if first_goal else "—"
131
+
132
+ with gr.Blocks(title="Oracle Planner Interface") as demo:
133
+ demo.theme = gr.themes.Soft()
134
+ demo.css = CSS
135
+
136
+ gr.Markdown("## RoboMME Human Evaluation", elem_id="header_title")
137
+ with gr.Row():
138
+ with gr.Column(scale=1):
139
+ header_task_box = gr.Dropdown(
140
+ choices=list(user_manager.env_choices),
141
+ value=render_header_task(""),
142
+ label="Current Task",
143
+ show_label=True,
144
+ interactive=True,
145
+ elem_id="header_task",
146
+ )
147
+ with gr.Column(scale=2):
148
+ header_goal_box = gr.Textbox(
149
+ value=render_header_goal(""),
150
+ label="Goal",
151
+ show_label=True,
152
+ interactive=False,
153
+ lines=1,
154
+ elem_id="header_goal",
155
+ )
156
+
157
+ with gr.Column(visible=True, elem_id="loading_overlay_group") as loading_overlay:
158
+ gr.Markdown("### Logging in and setting up environment... Please wait.")
159
+
160
+ uid_state = gr.State(value=None)
161
+ ui_phase_state = gr.State(value=PHASE_INIT)
162
+ live_obs_timer = gr.Timer(value=0.1, active=True)
163
+
164
+ task_info_box = gr.Textbox(visible=False, elem_id="task_info_box")
165
+ progress_info_box = gr.Textbox(visible=False)
166
+ goal_box = gr.Textbox(visible=False)
167
+
168
+ with gr.Column(visible=False, elem_id="main_interface_root") as main_interface:
169
+ with gr.Row(elem_id="main_layout_row"):
170
+ with gr.Column(scale=KEYPOINT_SELECTION_SCALE):
171
+ with gr.Column(elem_classes=["native-card"], elem_id="media_card"):
172
+ with gr.Column(visible=False, elem_id="video_phase_group") as video_phase_group:
173
+ video_display = gr.Video(
174
+ label="Demonstration Video",
175
+ interactive=False,
176
+ elem_id="demo_video",
177
+ autoplay=True,
178
+ show_label=True,
179
+ visible=True,
180
+ )
181
+
182
+ with gr.Column(visible=False, elem_id="action_phase_group") as action_phase_group:
183
+ img_display = gr.Image(
184
+ label="Keypoint Selection",
185
+ interactive=False,
186
+ type="pil",
187
+ elem_id="live_obs",
188
+ show_label=True,
189
+ buttons=[],
190
+ sources=[],
191
+ )
192
+
193
+ with gr.Column(scale=CONTROL_PANEL_SCALE):
194
+ with gr.Column(visible=False, elem_id="control_panel_group") as control_panel_group:
195
+ with gr.Row(elem_id="right_top_row", equal_height=False):
196
+ with gr.Column(scale=RIGHT_TOP_ACTION_SCALE, elem_id="right_action_col"):
197
+ with gr.Column(elem_classes=["native-card"], elem_id="action_selection_card"):
198
+ options_radio = gr.Radio(
199
+ choices=[],
200
+ label=" Action Selection",
201
+ type="value",
202
+ show_label=True,
203
+ elem_id="action_radio",
204
+ )
205
+ coords_box = gr.Textbox(
206
+ label="Coords",
207
+ value="",
208
+ interactive=False,
209
+ show_label=False,
210
+ visible=False,
211
+ elem_id="coords_box",
212
+ )
213
+
214
+ with gr.Column(scale=RIGHT_TOP_LOG_SCALE, elem_id="right_log_col"):
215
+ with gr.Column(elem_classes=["native-card"], elem_id="log_card"):
216
+ log_output = gr.Textbox(
217
+ value="",
218
+ lines=4,
219
+ max_lines=None,
220
+ show_label=True,
221
+ interactive=False,
222
+ elem_id="log_output",
223
+ label="System Log",
224
+ )
225
+
226
+ with gr.Row(elem_id="action_buttons_row"):
227
+ with gr.Column(elem_classes=["native-card", "native-button-card"], elem_id="exec_btn_card"):
228
+ exec_btn = gr.Button("EXECUTE", variant="stop", size="lg", elem_id="exec_btn")
229
+
230
+ with gr.Column(
231
+ elem_classes=["native-card", "native-button-card"],
232
+ elem_id="reference_btn_card",
233
+ ):
234
+ reference_action_btn = gr.Button(
235
+ "Ground Truth Action",
236
+ variant="secondary",
237
+ interactive=False,
238
+ elem_id="reference_action_btn",
239
+ )
240
+
241
+ with gr.Column(
242
+ elem_classes=["native-card", "native-button-card"],
243
+ elem_id="restart_episode_btn_card",
244
+ ):
245
+ restart_episode_btn = gr.Button(
246
+ "restart episode",
247
+ variant="secondary",
248
+ interactive=False,
249
+ elem_id="restart_episode_btn",
250
+ )
251
+
252
+ with gr.Column(
253
+ elem_classes=["native-card", "native-button-card"],
254
+ elem_id="next_task_btn_card",
255
+ ):
256
+ next_task_btn = gr.Button(
257
+ "change episode",
258
+ variant="primary",
259
+ interactive=False,
260
+ elem_id="next_task_btn",
261
+ )
262
+
263
+ with gr.Column(visible=True, elem_classes=["native-card"], elem_id="task_hint_card"):
264
+ task_hint_display = gr.Textbox(
265
+ value="",
266
+ lines=8,
267
+ max_lines=16,
268
+ show_label=True,
269
+ label="Task Hint",
270
+ interactive=True,
271
+ elem_id="task_hint_display",
272
+ )
273
+
274
+ def _normalize_env_choice(env_value, choices):
275
+ if env_value is None:
276
+ return None
277
+ env_text = str(env_value).strip()
278
+ if not env_text:
279
+ return None
280
+ lower_map = {}
281
+ for choice in choices:
282
+ choice_text = str(choice).strip()
283
+ if choice_text:
284
+ lower_map.setdefault(choice_text.lower(), choice_text)
285
+ return lower_map.get(env_text.lower(), env_text)
286
+
287
def _build_header_task_update(task_text, fallback_env=None):
    """Build the gr.update for the header env dropdown.

    The env parsed out of task_text is preferred; fallback_env is used
    when parsing yields nothing.  A selected env that is not already in
    the known pool is appended to the dropdown choices so the current
    value is always displayable.
    """
    known_envs = list(user_manager.env_choices)
    selected = _normalize_env_choice(render_header_task(task_text), known_envs)
    if selected is None:
        selected = _normalize_env_choice(fallback_env, known_envs)

    dropdown_choices = list(known_envs)
    if selected and selected not in dropdown_choices:
        dropdown_choices.append(selected)
    return gr.update(choices=dropdown_choices, value=selected)
298
+
299
def sync_header_from_task(task_text, goal_text):
    """Mirror the hidden task/goal textboxes into the header widgets."""
    return _build_header_task_update(task_text), render_header_goal(goal_text)
301
+
302
def sync_header_from_goal(goal_text, task_text, current_header_task):
    """Like sync_header_from_task, but keeps the current header dropdown
    value as a fallback when task_text does not parse to a known env."""
    return _build_header_task_update(task_text, fallback_env=current_header_task), render_header_goal(goal_text)
304
+
305
def init_app_with_phase(request: gr.Request):
    """Initial page load: run init_app, then append the derived UI phase."""
    return _with_phase_from_load(init_app(request))

def load_next_task_with_phase(uid):
    """'change episode' click: load the next task, then append the UI phase."""
    return _with_phase_from_load(load_next_task_wrapper(uid))

def restart_episode_with_phase(uid):
    """'restart episode' click: restart the episode, then append the UI phase."""
    return _with_phase_from_load(restart_episode_wrapper(uid))

def switch_env_with_phase(uid, selected_env):
    """Header dropdown change: switch env, then append the UI phase."""
    return _with_phase_from_load(switch_env_wrapper(uid, selected_env))
316
+
317
+ task_info_box.change(
318
+ fn=sync_header_from_task,
319
+ inputs=[task_info_box, goal_box],
320
+ outputs=[header_task_box, header_goal_box],
321
+ )
322
+ goal_box.change(
323
+ fn=sync_header_from_goal,
324
+ inputs=[goal_box, task_info_box, header_task_box],
325
+ outputs=[header_task_box, header_goal_box],
326
+ )
327
+
328
+ header_task_box.input(fn=show_loading_info, outputs=[loading_overlay]).then(
329
+ fn=switch_env_with_phase,
330
+ inputs=[uid_state, header_task_box],
331
+ outputs=[
332
+ uid_state,
333
+ main_interface,
334
+ img_display,
335
+ log_output,
336
+ options_radio,
337
+ goal_box,
338
+ coords_box,
339
+ video_display,
340
+ task_info_box,
341
+ progress_info_box,
342
+ restart_episode_btn,
343
+ next_task_btn,
344
+ exec_btn,
345
+ video_phase_group,
346
+ action_phase_group,
347
+ control_panel_group,
348
+ task_hint_display,
349
+ loading_overlay,
350
+ reference_action_btn,
351
+ ui_phase_state,
352
+ ],
353
+ ).then(
354
+ fn=sync_header_from_task,
355
+ inputs=[task_info_box, goal_box],
356
+ outputs=[header_task_box, header_goal_box],
357
+ )
358
+
359
+ next_task_btn.click(fn=show_loading_info, outputs=[loading_overlay]).then(
360
+ fn=load_next_task_with_phase,
361
+ inputs=[uid_state],
362
+ outputs=[
363
+ uid_state,
364
+ main_interface,
365
+ img_display,
366
+ log_output,
367
+ options_radio,
368
+ goal_box,
369
+ coords_box,
370
+ video_display,
371
+ task_info_box,
372
+ progress_info_box,
373
+ restart_episode_btn,
374
+ next_task_btn,
375
+ exec_btn,
376
+ video_phase_group,
377
+ action_phase_group,
378
+ control_panel_group,
379
+ task_hint_display,
380
+ loading_overlay,
381
+ reference_action_btn,
382
+ ui_phase_state,
383
+ ],
384
+ ).then(
385
+ fn=sync_header_from_task,
386
+ inputs=[task_info_box, goal_box],
387
+ outputs=[header_task_box, header_goal_box],
388
+ )
389
+
390
+ restart_episode_btn.click(fn=show_loading_info, outputs=[loading_overlay]).then(
391
+ fn=restart_episode_with_phase,
392
+ inputs=[uid_state],
393
+ outputs=[
394
+ uid_state,
395
+ main_interface,
396
+ img_display,
397
+ log_output,
398
+ options_radio,
399
+ goal_box,
400
+ coords_box,
401
+ video_display,
402
+ task_info_box,
403
+ progress_info_box,
404
+ restart_episode_btn,
405
+ next_task_btn,
406
+ exec_btn,
407
+ video_phase_group,
408
+ action_phase_group,
409
+ control_panel_group,
410
+ task_hint_display,
411
+ loading_overlay,
412
+ reference_action_btn,
413
+ ui_phase_state,
414
+ ],
415
+ ).then(
416
+ fn=sync_header_from_task,
417
+ inputs=[task_info_box, goal_box],
418
+ outputs=[header_task_box, header_goal_box],
419
+ )
420
+
421
+ video_display.end(
422
+ fn=on_video_end_transition,
423
+ inputs=[uid_state],
424
+ outputs=[video_phase_group, action_phase_group, control_panel_group, log_output],
425
+ queue=False,
426
+ show_progress="hidden",
427
+ ).then(
428
+ fn=lambda: PHASE_ACTION_KEYPOINT,
429
+ outputs=[ui_phase_state],
430
+ queue=False,
431
+ show_progress="hidden",
432
+ )
433
+ video_display.stop(
434
+ fn=on_video_end_transition,
435
+ inputs=[uid_state],
436
+ outputs=[video_phase_group, action_phase_group, control_panel_group, log_output],
437
+ queue=False,
438
+ show_progress="hidden",
439
+ ).then(
440
+ fn=lambda: PHASE_ACTION_KEYPOINT,
441
+ outputs=[ui_phase_state],
442
+ queue=False,
443
+ show_progress="hidden",
444
+ )
445
+
446
+ img_display.select(
447
+ fn=on_map_click,
448
+ inputs=[uid_state, options_radio],
449
+ outputs=[img_display, coords_box],
450
+ )
451
+
452
+ options_radio.change(
453
+ fn=on_option_select,
454
+ inputs=[uid_state, options_radio, coords_box],
455
+ outputs=[coords_box, img_display],
456
+ )
457
+
458
+ reference_action_btn.click(
459
+ fn=on_reference_action,
460
+ inputs=[uid_state],
461
+ outputs=[img_display, options_radio, coords_box, log_output],
462
+ )
463
+
464
+ exec_btn.click(
465
+ fn=precheck_execute_inputs,
466
+ inputs=[uid_state, options_radio, coords_box],
467
+ outputs=[],
468
+ show_progress="hidden",
469
+ ).then(
470
+ fn=switch_to_execute_phase,
471
+ inputs=[uid_state],
472
+ outputs=[
473
+ options_radio,
474
+ exec_btn,
475
+ restart_episode_btn,
476
+ next_task_btn,
477
+ img_display,
478
+ reference_action_btn,
479
+ ],
480
+ show_progress="hidden",
481
+ ).then(
482
+ fn=lambda: PHASE_EXECUTION_PLAYBACK,
483
+ outputs=[ui_phase_state],
484
+ show_progress="hidden",
485
+ ).then(
486
+ fn=execute_step,
487
+ inputs=[uid_state, options_radio, coords_box],
488
+ outputs=[img_display, log_output, task_info_box, progress_info_box, restart_episode_btn, next_task_btn, exec_btn],
489
+ show_progress="hidden",
490
+ ).then(
491
+ fn=switch_to_action_phase,
492
+ inputs=[uid_state],
493
+ outputs=[
494
+ options_radio,
495
+ exec_btn,
496
+ restart_episode_btn,
497
+ next_task_btn,
498
+ img_display,
499
+ reference_action_btn,
500
+ ],
501
+ show_progress="hidden",
502
+ ).then(
503
+ fn=lambda: PHASE_ACTION_KEYPOINT,
504
+ outputs=[ui_phase_state],
505
+ show_progress="hidden",
506
+ )
507
+
508
+ live_obs_timer.tick(
509
+ fn=refresh_live_obs,
510
+ inputs=[uid_state, ui_phase_state],
511
+ outputs=[img_display],
512
+ queue=False,
513
+ show_progress="hidden",
514
+ )
515
+
516
+ demo.load(
517
+ fn=init_app_with_phase,
518
+ inputs=[],
519
+ outputs=[
520
+ uid_state,
521
+ main_interface,
522
+ img_display,
523
+ log_output,
524
+ options_radio,
525
+ goal_box,
526
+ coords_box,
527
+ video_display,
528
+ task_info_box,
529
+ progress_info_box,
530
+ restart_episode_btn,
531
+ next_task_btn,
532
+ exec_btn,
533
+ video_phase_group,
534
+ action_phase_group,
535
+ control_panel_group,
536
+ task_hint_display,
537
+ loading_overlay,
538
+ reference_action_btn,
539
+ ui_phase_state,
540
+ ],
541
+ ).then(
542
+ fn=sync_header_from_task,
543
+ inputs=[task_info_box, goal_box],
544
+ outputs=[header_task_box, header_goal_box],
545
+ )
546
+
547
+ return demo
gradio-web/user_manager.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import threading
5
+ from pathlib import Path
6
+
7
+ from state_manager import clear_task_start_time, get_task_start_time
8
+
9
+
10
+ METADATA_FILE_GLOB = "record_dataset_*_metadata.json"
11
+
12
+
13
class UserManager:
    """In-memory manager of per-session env/episode task assignment.

    The (env_id -> episode list) pool is loaded once from the metadata JSON
    files at construction time.  Per-session progress lives only in
    ``self.session_progress`` keyed by the Gradio session uid; nothing is
    persisted to disk.
    """

    def __init__(self):
        # Directory containing this file; anchor for the default metadata root.
        self.base_dir = Path(__file__).resolve().parent
        # Guards session_progress against concurrent Gradio callbacks.
        # NOTE(review): threading.Lock is NOT reentrant -- the private helpers
        # below are called with the lock already held and must not re-acquire
        # it, and lock-holding code must not call the public methods.
        self.lock = threading.Lock()

        # env_id -> sorted list of available episode indices.
        self.env_to_episodes = self._load_env_episode_pool()
        self.env_choices = sorted(self.env_to_episodes.keys())

        # Session-local progress only (no disk persistence)
        self.session_progress = {}

    def _resolve_metadata_root(self) -> Path:
        """Return the metadata directory, honoring $ROBOMME_METADATA_ROOT."""
        env_root = os.environ.get("ROBOMME_METADATA_ROOT")
        if env_root:
            return Path(env_root)
        # Default layout: <repo>/src/robomme/env_metadata/train next to gradio-web/.
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    def _load_env_episode_pool(self):
        """Scan metadata files and build {env_id: sorted episode indices}.

        Unreadable files and malformed records are skipped with a warning so
        a single bad file cannot empty the whole pool.
        """
        env_to_episode_set = {}
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.exists():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}

        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except Exception as exc:
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue

            # Records may omit "task"; fall back to the file-level env_id.
            fallback_env = str(payload.get("env_id") or "").strip()
            for record in payload.get("records", []):
                env_id = str(record.get("task") or fallback_env or "").strip()
                episode = record.get("episode")
                if not env_id or episode is None:
                    continue
                try:
                    episode_idx = int(episode)
                except (TypeError, ValueError):
                    continue
                env_to_episode_set.setdefault(env_id, set()).add(episode_idx)

        env_to_episodes = {
            env_id: sorted(episodes)
            for env_id, episodes in env_to_episode_set.items()
            if episodes
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _ensure_session_entry(self, uid):
        """Create an empty progress record for uid if absent (caller holds lock)."""
        if uid not in self.session_progress:
            self.session_progress[uid] = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }

    def _set_current_random_task(self, uid, preferred_env=None):
        """Assign a random (env, episode) pair to uid (caller holds lock).

        preferred_env wins when it exists in the pool, otherwise an env is
        drawn uniformly.  Returns False when the pool (or the chosen env's
        episode list) is empty.
        """
        if not self.env_choices:
            return False
        self._ensure_session_entry(uid)

        env_id = preferred_env if preferred_env in self.env_to_episodes else random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id, [])
        if not episodes:
            return False

        episode_idx = int(random.choice(episodes))
        self.session_progress[uid]["current_env_id"] = env_id
        self.session_progress[uid]["current_episode_idx"] = episode_idx
        return True

    def init_session(self, uid):
        """Ensure uid has a session with an assigned task.

        Returns (ok, message, status_dict_or_None).
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if progress.get("current_env_id") is None or progress.get("current_episode_idx") is None:
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None

        # get_session_status re-acquires the (non-reentrant) lock, so it is
        # deliberately called only after the `with` block above has released it.
        return True, "Session initialized", self.get_session_status(uid)

    def get_session_status(self, uid):
        """Return a status snapshot dict for uid (None for an empty uid).

        Lazily assigns a random task when the session has none yet.  The keys
        total_tasks / current_index / tasks exist only for backward
        compatibility with the earlier fixed-task-list flow.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if (
                (progress.get("current_env_id") is None or progress.get("current_episode_idx") is None)
                and self.env_choices
            ):
                self._set_current_random_task(uid)
                progress = self.session_progress[uid]

            current_task = None
            if progress.get("current_env_id") is not None and progress.get("current_episode_idx") is not None:
                current_task = {
                    "env_id": progress["current_env_id"],
                    "episode_idx": int(progress["current_episode_idx"]),
                }

            completed_count = int(progress.get("completed_count", 0))

            return {
                "uid": uid,
                "total_tasks": len(self.env_choices),  # compatibility only
                "current_index": completed_count,  # compatibility only
                "completed_count": completed_count,
                "current_task": current_task,
                "is_done_all": False,
                "tasks": [],  # compatibility only
                "env_choices": list(self.env_choices),
            }

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Bump the completion counter and clear the task's start-time record.

        Extra keyword arguments are accepted and ignored for caller
        compatibility.  Returns the refreshed status dict (None for empty uid).
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            self.session_progress[uid]["completed_count"] = int(self.session_progress[uid]["completed_count"]) + 1

        if env_id is not None and episode_idx is not None:
            # Read the start time before clearing it; the value is discarded
            # here -- presumably the timing bookkeeping happens in
            # state_manager, TODO confirm against its implementation.
            _ = get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)

        return self.get_session_status(uid)

    def switch_env_and_random_episode(self, uid, env_id):
        """Switch uid to env_id with a freshly sampled episode; None on failure."""
        if not uid or env_id not in self.env_to_episodes:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None

        return self.get_session_status(uid)

    def next_episode_same_env(self, uid):
        """Re-sample an episode, preferring the session's current env; None on failure."""
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            current_env = self.session_progress[uid].get("current_env_id")
            if current_env not in self.env_to_episodes:
                # Current env missing or stale: fall back to a fully random task.
                if not self._set_current_random_task(uid):
                    return None
            else:
                if not self._set_current_random_task(uid, preferred_env=current_env):
                    return None

        return self.get_session_status(uid)
176
+
177
+
178
+ user_manager = UserManager()
gradio-web/verify_video_names.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 验证和修复视频文件名,确保与 env_id 正确对应
4
+
5
+ 注意:该脚本是离线校验工具,不参与当前 Gradio 运行时任务分配逻辑。
6
+ """
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
def get_all_env_ids():
    """Return all unique env_ids from user_tasks_og.json, sorted.

    The JSON maps user ids to lists of task dicts, each carrying an
    'env_id' key.  The file is decoded explicitly as UTF-8 instead of the
    locale default, so non-ASCII content is read consistently everywhere.
    """
    tasks_file = 'user_tasks_og.json'
    with open(tasks_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Deduplicate across every user's task list.
    env_ids = {task['env_id'] for tasks in data.values() for task in tasks}
    return sorted(env_ids)
22
+
23
def verify_video_files(videos_dir='videos'):
    """Check that every env_id has a correctly named <env_id>.mp4 (lowercase).

    Prints a per-env report plus a summary and returns three lists:
    (correct, missing, case-mismatched).  Fix: the missing-directory path now
    returns the same 3-tuple shape ([], [], []) instead of None, so callers
    that unpack the result no longer crash.
    """
    env_ids = get_all_env_ids()
    videos_path = Path(videos_dir)

    if not videos_path.exists():
        print(f"错误: 目录 {videos_dir} 不存在")
        return [], [], []

    # Map lowercased names to paths for case-insensitive lookup.
    existing_files = {f.name.lower(): f for f in videos_path.glob('*.mp4')}

    print("=" * 80)
    print("视频文件名验证结果")
    print("=" * 80)
    print(f"{'Env ID':<25} {'期望文件名':<35} {'状态':<10}")
    print("-" * 80)

    correct_files = []
    missing_files = []
    incorrect_files = []

    for env_id in env_ids:
        expected_filename = env_id.lower() + '.mp4'
        expected_lower = expected_filename.lower()

        if expected_lower in existing_files:
            actual_file = existing_files[expected_lower]
            if actual_file.name == expected_filename:
                status = "✓ 正确"
                correct_files.append((env_id, expected_filename))
            else:
                # Same name ignoring case, but on-disk casing differs.
                status = f"⚠ 大小写不匹配: {actual_file.name}"
                incorrect_files.append((env_id, expected_filename, actual_file.name))
        else:
            status = "✗ 缺失"
            missing_files.append((env_id, expected_filename))

        print(f"{env_id:<25} {expected_filename:<35} {status:<10}")

    print("=" * 80)
    print(f"\n总结:")
    print(f"  ✓ 正确匹配: {len(correct_files)} 个")
    print(f"  ✗ 缺失文件: {len(missing_files)} 个")
    print(f"  ⚠ 需要修复: {len(incorrect_files)} 个")

    if incorrect_files:
        print(f"\n需要重命名的文件:")
        for env_id, expected, actual in incorrect_files:
            print(f"  {actual} -> {expected}")

    if missing_files:
        print(f"\n缺失的视频文件 (这些 env_id 没有对应的视频):")
        for env_id, expected in missing_files:
            print(f"  {env_id} -> {expected}")

    return correct_files, missing_files, incorrect_files
80
+
81
def fix_incorrect_names(videos_dir='videos', dry_run=True):
    """Rename case-mismatched video files to their expected lowercase names.

    Returns the list of (old_name, new_name) pairs that were renamed (or
    would be renamed in dry-run mode).  Fixes: failed renames are no longer
    counted in the result, and the missing-directory path returns [] instead
    of None so the return type is consistent.
    """
    env_ids = get_all_env_ids()
    videos_path = Path(videos_dir)

    if not videos_path.exists():
        print(f"错误: 目录 {videos_dir} 不存在")
        return []

    # Case-insensitive index of the existing .mp4 files.
    existing_files = {f.name.lower(): f for f in videos_path.glob('*.mp4')}

    fixed = []
    for env_id in env_ids:
        expected_filename = env_id.lower() + '.mp4'
        actual_file = existing_files.get(expected_filename.lower())
        # Only files whose on-disk casing differs need renaming.
        if actual_file is None or actual_file.name == expected_filename:
            continue

        new_path = actual_file.parent / expected_filename
        if dry_run:
            print(f"[DRY RUN] 将重命名: {actual_file.name} -> {expected_filename}")
            fixed.append((actual_file.name, expected_filename))
        else:
            try:
                actual_file.rename(new_path)
            except OSError as e:
                # A failed rename is reported but not counted as fixed.
                print(f"✗ 重命名失败 {actual_file.name}: {e}")
            else:
                print(f"✓ 已重命名: {actual_file.name} -> {expected_filename}")
                fixed.append((actual_file.name, expected_filename))

    if not fixed:
        print("没有需要修复的文件名")
    elif dry_run:
        print(f"\n[DRY RUN 模式] 共 {len(fixed)} 个文件需要重命名")
        print("运行时不加 --dry-run 参数以执行实际重命名")

    return fixed
120
+
121
if __name__ == '__main__':
    import sys

    if '--fix' in sys.argv:
        # --fix alone performs the real rename; add --dry-run to preview.
        # (The previous "or '--fix' not in sys.argv" clause was always False
        # inside this branch -- dead code removed.)
        fix_incorrect_names(dry_run='--dry-run' in sys.argv)
    else:
        verify_video_files()
pyproject.toml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "robomme"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "readme.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "mani-skill",
9
+ "opencv-python>=4.11.0.86",
10
+ "setuptools==80.9.0",
11
+ "torch==2.9.1",
12
+ "torchvision==0.24.1",
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ dev = ["opencv-python", "pytest"]
17
+
18
+ [tool.uv.sources]
19
+ mani-skill = { git = "https://github.com/YinpeiDai/ManiSkill.git", rev = "07be6fbc66350ddca200abfb0a11b692f078f7fd" }
20
+
21
+ [build-system]
22
+ requires = ["hatchling"]
23
+ build-backend = "hatchling.build"
24
+
25
+ [tool.hatch.build.targets.wheel]
26
+ packages = ["src/robomme"]
27
+
28
+ [tool.pytest.ini_options]
29
+ markers = [
30
+ "slow: slow-running tests",
31
+ "gpu: tests requiring GPU/display/headless rendering stack",
32
+ "dataset: tests that generate/use temporary datasets",
33
+ "lightweight: tests that do not require generated dataset",
34
+ ]
readme.md ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RoboMME: A Robotic Benchmark for Memory-Augmented Manipulation
2
+
3
+ ![Robomme bench](assets/robomme_bench.jpg)
4
+
5
+ ## 📢 Announcements
6
+
7
+ [03/2026] We are thrilled to release RoboMME, the first large-scale robotic benchmark dedicated to memory-augmented manipulation! Spanning 4 cognitively motivated task suites with 16 carefully designed tasks, RoboMME pushes robots to remember, reason, and act.
8
+
9
+ ## 📦 Installation
10
+
11
+ After cloning the repo, install [uv](https://docs.astral.sh/uv/getting-started/installation/), then run:
12
+
13
+ ```bash
14
+ uv sync
15
+ uv pip install -e .
16
+ ```
17
+
18
+ ## 🚀 Quick Start
19
+
20
+ Start an environment with a specified setup:
21
+
22
+ ```bash
23
+ uv run scripts/run_example.py
24
+ ```
25
+
26
+ This generates a rollout video in the `sample_run_videos` directory.
27
+
28
+ We provide four action types: `joint_action`, `ee_pose`, `waypoint`, and `multi_choice`, e.g., predict continuous actions with `joint_action` or `ee_pose`, discrete waypoint actions with `waypoint`, or use `multi_choice` for VideoQA-style problems.
29
+
30
+ ## 📁 Benchmark
31
+
32
+ ### 🤖 Tasks
33
+
34
+ We have four task suites, each with 4 tasks:
35
+
36
+ | Suite | Focus | Task ID |
37
+ | ---------- | ----------------- | --------------------------------------------------------------------- |
38
+ | Counting | Temporal memory | BinFill, PickXtimes, SwingXtimes, StopCube |
39
+ | Permanence | Spatial memory | VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap |
40
+ | Reference | Object memory | PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder |
41
+ | Imitation | Procedural memory | MoveCube, InsertPeg, PatternLock, RouteStick |
42
+
43
+ All tasks are defined in `src/robomme/robomme_env`. A detailed description can be found in our paper appendix.
44
+
45
+ ### 📥 Training Data
46
+
47
+ Training data can be downloaded [here](https://huggingface.co/datasets/Yinpei/robomme_data). There are 1,600 demonstrations in total (100 per task). The HDF5 format is described in [doc/h5_data_format.md](doc/h5_data_format.md).
48
+
49
+ After downloading, replay the dataset for a sanity check:
50
+
51
+ ```bash
52
+ uv run scripts/dataset_replay.py --h5-data-dir <your_downloaded_data_dir>
53
+ ```
54
+
55
+ ### 📊 Evaluation
56
+
57
+ To evaluate on the test set, set the `dataset` argument of `BenchmarkEnvBuilder`:
58
+
59
+ ```python
60
+ task_id = "PickXtimes"
61
+ episode_idx = 0
62
+ env_builder = BenchmarkEnvBuilder(
63
+ env_id=task_id,
64
+ dataset="test",
65
+ ...
66
+ )
67
+
68
+ env = env_builder.make_env_for_episode(episode_idx)
69
+ obs, info = env.reset() # initial step
70
+ ...
71
+ obs, _, terminated, truncated, info = env.step(action) # each step
72
+ ```
73
+ The train split has 100 episodes. The val/test splits each have 50 episodes. All seeds are fixed for benchmarking.
74
+
75
+ The environment input/output format is described in [doc/env_format.md](doc/env_format.md).
76
+
77
+ > Currently, environment spawning is set up only for imitation learning. We are working on extending it to support more general parallel environments for reinforcement learning in the future.
78
+
79
+ ### 🔧 Data Generation
80
+
81
You can also re-generate your own HDF5 data via parallel processing with the data-generation scripts under `scripts/dev/`:

```bash
uv run scripts/dev/<data_generation_script>.py
```
86
+
87
+
88
+ ## 🧠 Model Training
89
+
90
+ ### 🌟 MME-VLA-Suite
91
+
92
+ The [MME Policy Learning](https://github.com/RoboMME/robomme_policy_learning) repo provides MME-VLA model training and evaluation used in our paper. It contains a family of memory-augmented VLA models built on [pi05](https://github.com/Physical-Intelligence/openpi) backbone and our implementation of [MemER](https://jen-pan.github.io/memer/).
93
+
94
+ ### 📚 Prior Methods
95
+
96
+ **MemER**: The [MME Policy Learning](https://github.com/RoboMME/robomme_policy_learning) repo also provides our implementation of the [MemER](https://jen-pan.github.io/memer/), using the same GroundSG policy model as in MME-VLA.
97
+
98
+ **SAM2Act+**: The [RoboMME_SAM2Act](https://github.com/RoboMME/SAM2Act) repo provides our implementation adapted from the [SAM2Act](https://github.com/sam2act/sam2act) repo.
99
+
100
+ **MemoryVLA**: The [RoboMME_MemoryVLA](https://github.com/RoboMME/MemoryVLA) repo provides our implementation adapted from the [MemoryVLA](https://github.com/shihao1895/MemoryVLA) repo.
101
+
102
+ **Diffusion Policy**: The [RoboMME_DP](https://github.com/RoboMME/DP) repo provides our implementation adapted from the [diffusion_policy](https://github.com/real-stanford/diffusion_policy) repo.
103
+
104
+
105
+
106
+ ## 🏆 Submit Your Models
107
+ Want to add your model? Download the [dataset](https://huggingface.co/datasets/Yinpei/robomme_data) from Hugging Face, run evaluation using our [eval scripts](scripts/evaluation.py), then submit a PR with your results by adding `<your_model>.md` to the `doc/submission/` [directory](https://github.com/RoboMME/robomme_benchmark/tree/main/doc/submission). We will review it and update our leaderboard.
108
+
109
+
110
+ ## 🔧 Troubleshooting
111
+
112
+ **Q1: RuntimeError: Create window failed: Renderer does not support display.**
113
+
114
+ A1: Use a physical display or set up a virtual display for GUI rendering (e.g. install a VNC server and set the `DISPLAY` variable correctly).
115
+
116
+ **Q2: Failure related to Vulkan installation.**
117
+
118
+ A2: We recommend reinstalling the NVIDIA driver and Vulkan packages. We use NVIDIA driver 570.211.01 and Vulkan 1.3.275. If it still does not work, switch to CPU rendering:
119
+
120
+ ```python
121
+ os.environ['SAPIEN_RENDER_DEVICE'] = 'cpu'
122
+ os.environ['MUJOCO_GL'] = 'osmesa'
123
+ ```
124
+
125
+
126
+ ## 🙏 Acknowledgements
127
+
128
+ This work was supported in part by NSF SES-2128623, NSF CAREER #2337870, NSF NRI #2220876, NSF NAIRR250085. We would also like to thank the wonderful [OpenPi](https://github.com/Physical-Intelligence/openpi/tree/main) codebase from Physical-Intelligence.
129
+
130
+
131
+ ## 📄 Citation
132
+
133
+ ```
134
+ ...
135
+ ```
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ Pillow
4
+ opencv-python>=4.11.0.86
5
+ gymnasium
6
+ h5py
7
+ imageio
8
+ setuptools==80.9.0
9
+ torch==2.9.1
10
+ torchvision==0.24.1
11
+ mani-skill @ git+https://github.com/YinpeiDai/ManiSkill.git@07be6fbc66350ddca200abfb0a11b692f078f7fd
scripts/dataset_replay.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Replay episodes from HDF5 datasets and save rollout videos.
3
+ Loads recorded actions from record_dataset_<Task>.h5, steps the environment
4
+ """
5
+
6
+ import os
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
8
+
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Literal, Union
12
+
13
+ import cv2
14
+ import h5py
15
+ import imageio
16
+ import numpy as np
17
+ import torch
18
+
19
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
20
+
21
+ GUI_RENDER = False
22
+ REPLAY_VIDEO_DIR = "replay_videos"
23
+ VIDEO_FPS = 30
24
+ VIDEO_BORDER_COLOR = (255, 0, 0)
25
+ VIDEO_BORDER_THICKNESS = 10
26
+
27
+ TaskID = Literal[
28
+ # "BinFill",
29
+ # "PickXtimes",
30
+ # "SwingXtimes",
31
+ # "StopCube",
32
+ # "VideoUnmask",
33
+ "VideoUnmaskSwap",
34
+ # "ButtonUnmask",
35
+ # "ButtonUnmaskSwap",
36
+ # "PickHighlight",
37
+ # "VideoRepick",
38
+ # "VideoPlaceButton",
39
+ # "VideoPlaceOrder",
40
+ # "MoveCube",
41
+ # "InsertPeg",
42
+ # "PatternLock",
43
+ # "RouteStick",
44
+ ]
45
+
46
+
47
+ ActionSpaceType = Literal["joint_angle", "ee_pose", "waypoint", "multi_choice"]
48
+
49
+ def _to_numpy(t) -> np.ndarray:
50
+ return t.cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t)
51
+
52
+
53
+ def _frame_from_obs(
54
+ front: np.ndarray | torch.Tensor,
55
+ wrist: np.ndarray | torch.Tensor,
56
+ is_video_demo: bool = False,
57
+ ) -> np.ndarray:
58
+ frame = np.hstack([_to_numpy(front), _to_numpy(wrist)]).astype(np.uint8)
59
+ if is_video_demo:
60
+ h, w = frame.shape[:2]
61
+ cv2.rectangle(frame, (0, 0), (w, h),
62
+ VIDEO_BORDER_COLOR, VIDEO_BORDER_THICKNESS)
63
+ return frame
64
+
65
+
66
+ def _extract_frames(obs: dict, is_video_demo_fn=None) -> list[np.ndarray]:
67
+ n = len(obs["front_rgb_list"])
68
+ return [
69
+ _frame_from_obs(
70
+ obs["front_rgb_list"][i],
71
+ obs["wrist_rgb_list"][i],
72
+ is_video_demo=(is_video_demo_fn(i) if is_video_demo_fn else False),
73
+ )
74
+ for i in range(n)
75
+ ]
76
+
77
+
78
def _is_video_demo(ts: h5py.Group) -> bool:
    """True when this timestep's info group flags it as a video-demo step."""
    info = ts.get("info")
    if info is None or "is_video_demo" not in info:
        return False
    flag = np.asarray(info["is_video_demo"][()]).reshape(-1)[0]
    return bool(flag)
83
+
84
+
85
def _is_subgoal_boundary(ts: h5py.Group) -> bool:
    """True when this timestep's info group marks a subgoal boundary."""
    info = ts.get("info")
    if info is None or "is_subgoal_boundary" not in info:
        return False
    flag = np.asarray(info["is_subgoal_boundary"][()]).reshape(-1)[0]
    return bool(flag)
90
+
91
+
92
+ def _decode_h5_str(raw) -> str:
93
+ """Uniformly decode bytes / numpy bytes / str from HDF5 to str."""
94
+ if isinstance(raw, np.ndarray):
95
+ raw = raw.flatten()[0]
96
+ if isinstance(raw, (bytes, np.bytes_)):
97
+ raw = raw.decode("utf-8")
98
+ return raw
99
+
100
+
101
def _build_action_sequence(
    episode_data: h5py.Group, action_space_type: str
) -> list[Union[np.ndarray, Dict[str, Any]]]:
    """
    Scan the entire episode and return the deduplicated action sequence:
    - joint_angle / ee_pose: actions of all non-video-demo steps (sequential, not deduplicated)
    - waypoint: remove adjacent duplicate waypoint_action (like EpisodeDatasetResolver)
    - multi_choice: choice_action (JSON dict) only for steps where is_subgoal_boundary=True

    Raises:
        ValueError: if action_space_type is not one of the four known types.
    """
    # Numeric sort so "timestep_10" comes after "timestep_9" (lexical sort would not).
    timestep_keys = sorted(
        (k for k in episode_data.keys() if k.startswith("timestep_")),
        key=lambda k: int(k.split("_")[1]),
    )

    actions: list[Union[np.ndarray, Dict[str, Any]]] = []
    # Last accepted waypoint; used to drop adjacent duplicates.
    prev_waypoint: np.ndarray | None = None

    for key in timestep_keys:
        ts = episode_data[key]
        # Demo-video frames carry no executable action.
        if _is_video_demo(ts):
            continue

        action_grp = ts.get("action")
        if action_grp is None:
            continue

        if action_space_type == "joint_angle":
            if "joint_action" not in action_grp:
                continue
            actions.append(np.asarray(action_grp["joint_action"][()], dtype=np.float32))

        elif action_space_type == "ee_pose":
            if "eef_action" not in action_grp:
                continue
            actions.append(np.asarray(action_grp["eef_action"][()], dtype=np.float32))

        elif action_space_type == "waypoint":
            if "waypoint_action" not in action_grp:
                continue
            wa = np.asarray(action_grp["waypoint_action"][()], dtype=np.float32).flatten()
            # Only well-formed 7-element, all-finite waypoints are kept.
            if wa.shape != (7,) or not np.all(np.isfinite(wa)):
                continue
            # Remove adjacent duplicates
            if prev_waypoint is None or not np.array_equal(wa, prev_waypoint):
                actions.append(wa)
                prev_waypoint = wa.copy()

        elif action_space_type == "multi_choice":
            if not _is_subgoal_boundary(ts):
                continue
            if "choice_action" not in action_grp:
                continue
            raw = _decode_h5_str(action_grp["choice_action"][()])
            try:
                payload = json.loads(raw)
            except (TypeError, ValueError, json.JSONDecodeError):
                # Malformed JSON payloads are silently skipped.
                continue
            if not isinstance(payload, dict):
                continue
            choice = payload.get("choice")
            if not isinstance(choice, str) or not choice.strip():
                continue
            # "point" must be present, but its value (even None) is passed through.
            if "point" not in payload:
                continue
            actions.append({"choice": choice, "point": payload.get("point")})

        else:
            raise ValueError(f"Unknown action space type: {action_space_type}")

    return actions
171
+
172
+
173
def _save_video(
    frames: list[np.ndarray],
    task_id: str,
    episode_idx: int,
    task_goal: str,
    outcome: str,
    action_space_type: str,
) -> Path:
    """Write frames as an MP4 under REPLAY_VIDEO_DIR/<action_space_type>.

    Args:
        frames: Sequence of uint8 RGB frames.
        task_id: Benchmark task id, embedded in the file name.
        episode_idx: Episode number, embedded in the file name.
        task_goal: Free-text goal string, embedded in the file name.
        outcome: Replay outcome label (e.g. "success"/"fail"/"unknown").
        action_space_type: Subdirectory name grouping videos by action space.

    Returns:
        The path of the written video file.
    """
    video_dir = Path(REPLAY_VIDEO_DIR) / action_space_type
    video_dir.mkdir(parents=True, exist_ok=True)
    # task_goal is free text; strip path separators so the name stays a single
    # file inside video_dir instead of pointing into (usually non-existent)
    # nested directories, which made imageio fail.
    safe_goal = task_goal.replace("/", "_").replace("\\", "_")
    name = f"{outcome}_{task_id}_ep{episode_idx}_{safe_goal}.mp4"
    path = video_dir / name
    imageio.mimsave(str(path), frames, fps=VIDEO_FPS)
    return path
187
+
188
+
189
def _get_episode_indices(data: h5py.File) -> list[int]:
    """Episode numbers present in the file, in ascending order."""
    indices = [
        int(key.split("_")[1])
        for key in data.keys()
        if key.startswith("episode_")
    ]
    indices.sort()
    return indices
195
+
196
+
197
def process_episode(
    env_data: h5py.File,
    episode_idx: int,
    task_id: str,
    action_space_type: ActionSpaceType,
) -> None:
    """Replay one episode from HDF5 data, record frames, and save a video.

    Args:
        env_data: Open HDF5 file containing episode_<n> groups.
        episode_idx: Episode number to replay.
        task_id: Benchmark task id (also used in the saved video name).
        action_space_type: Which recorded action stream to replay.
    """
    episode_data = env_data[f"episode_{episode_idx}"]
    task_goal = episode_data["setup"]["task_goal"][()][0].decode()
    action_sequence = _build_action_sequence(episode_data, action_space_type)

    env = BenchmarkEnvBuilder(
        env_id=task_id,
        dataset="train",
        action_space=action_space_type,
        gui_render=GUI_RENDER,
    ).make_env_for_episode(episode_idx)

    print(f"\nTask: {task_id}, Episode: {episode_idx}, ",
          f"Seed: {env.unwrapped.seed}, Difficulty: {env.unwrapped.difficulty}")
    print(f"Task goal: {task_goal}")
    print(f"Total actions after dedup: {len(action_sequence)}")

    obs, _ = env.reset()
    # Reset frames: all but the last are demo-video frames (get a border);
    # the last one is the live observation.
    frames = _extract_frames(
        obs, is_video_demo_fn=lambda i, n=len(obs["front_rgb_list"]): i < n - 1
    )

    outcome = "unknown"
    for seq_idx, action in enumerate(action_sequence):
        try:
            obs, _, terminated, truncated, info = env.step(action)
            frames.extend(_extract_frames(obs))
        except Exception as e:
            # Best-effort replay: keep the frames collected so far and stop.
            print(f"Error at seq_idx {seq_idx}: {e}")
            break

        if GUI_RENDER:
            env.render()
        # Only reachable when env.step succeeded, so terminated/truncated/info
        # are always bound here.
        if terminated or truncated:
            outcome = info.get("status", "unknown")
            print(f"Outcome: {outcome}")
            break

    env.close()
    path = _save_video(frames, task_id, episode_idx, task_goal, outcome, action_space_type)
    print(f"Saved video to {path}\n")
244
+
245
+
246
def replay(
    h5_data_dir: str = "/data/hongzefu/data_0226",
    action_space_type: ActionSpaceType = "ee_pose",
    replay_number: int = 10,
) -> None:
    """Replay episodes from HDF5 dataset files and save rollout videos."""
    #for task_id in BenchmarkEnvBuilder.get_task_list():
    for task_id in ["VideoUnmaskSwap"]:
        file_path = Path(h5_data_dir) / f"record_dataset_{task_id}.h5"
        if not file_path.exists():
            print(f"Skipping {task_id}: file not found: {file_path}")
            continue

        with h5py.File(file_path, "r") as data:
            # Slicing already caps at the list length, so no explicit min().
            for episode_idx in _get_episode_indices(data)[:replay_number]:
                process_episode(data, episode_idx, task_id, action_space_type)
264
+
265
+
266
if __name__ == "__main__":
    import tyro
    # tyro turns replay()'s keyword arguments into command-line flags.
    tyro.cli(replay)
scripts/dev/compare_multi_choice_readers.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Compare how v3 and v4 replay pipelines read multi_choice actions.
3
+
4
+ v3 source:
5
+ - EpisodeDatasetResolver.get_step("multi_choice", step)
6
+
7
+ v4-noresolver source:
8
+ - scripts.dataset_replay._build_action_sequence(..., "multi_choice")
9
+ - then _parse_oracle_command() in replay loop
10
+ """
11
+
12
+ import argparse
13
+ import importlib.util
14
+ import json
15
+ import re
16
+ import sys
17
+ from pathlib import Path
18
+ from typing import Any, Optional
19
+
20
+ import h5py
21
+ import numpy as np
22
+
23
# Make both the repository root and its src/ layout importable when this
# script is run directly (i.e. not installed as a package).
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
SRC_ROOT = REPO_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
29
+
30
+
31
def _load_episode_dataset_resolver_cls():
    """Load EpisodeDatasetResolver directly from its source file.

    Importing by file path keeps this comparison script from pulling in the
    full robomme package import chain.
    """
    resolver_path = SRC_ROOT / "robomme" / "env_record_wrapper" / "episode_dataset_resolver.py"
    spec = importlib.util.spec_from_file_location(
        "episode_dataset_resolver_direct",
        resolver_path,
    )
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load resolver module from {resolver_path}")

    loaded_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded_module)

    loaded_cls = getattr(loaded_module, "EpisodeDatasetResolver", None)
    if loaded_cls is None:
        raise RuntimeError(f"EpisodeDatasetResolver not found in {resolver_path}")
    return loaded_cls
45
+
46
+
47
# Resolver class is loaded once at import time; used by _read_v3_commands.
EpisodeDatasetResolver = _load_episode_dataset_resolver_cls()

# Defaults for the CLI flags defined in main().
DEFAULT_ENV_ID = "PatternLock"
DEFAULT_DATASET_ROOT = "/data/hongzefu/data_0226-test"
51
+
52
+
53
+ def _parse_oracle_command_v4(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
54
+ """Exact validation logic used in evaluate_dataset_replay-parallelv4-noresolver.py."""
55
+ if not isinstance(choice_action, dict):
56
+ return None
57
+ choice = choice_action.get("choice")
58
+ if not isinstance(choice, str) or not choice.strip():
59
+ return None
60
+ point = choice_action.get("point")
61
+ if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2:
62
+ return None
63
+ return choice_action
64
+
65
+
66
def _is_video_demo_v4(ts: h5py.Group) -> bool:
    """True when this timestep's info flags it as part of the demo video."""
    info = ts.get("info")
    if info is not None and "is_video_demo" in info:
        flat = np.asarray(info["is_video_demo"][()]).reshape(-1)
        return bool(flat[0])
    return False
71
+
72
+
73
def _is_subgoal_boundary_v4(ts: h5py.Group) -> bool:
    """True when this timestep's info flags it as a subgoal boundary."""
    info = ts.get("info")
    if info is not None and "is_subgoal_boundary" in info:
        flat = np.asarray(info["is_subgoal_boundary"][()]).reshape(-1)
        return bool(flat[0])
    return False
78
+
79
+
80
+ def _decode_h5_str_v4(raw: Any) -> str:
81
+ if isinstance(raw, np.ndarray):
82
+ raw = raw.flatten()[0]
83
+ if isinstance(raw, (bytes, np.bytes_)):
84
+ raw = raw.decode("utf-8")
85
+ return raw
86
+
87
+
88
def _build_multi_choice_sequence_v4(episode_data: h5py.Group) -> list[Any]:
    """
    Re-implementation of dataset_replay._build_action_sequence(..., "multi_choice")
    without importing cv2/imageio/torch dependencies.
    """
    # Numeric sort so "timestep_10" follows "timestep_9".
    step_keys = [k for k in episode_data.keys() if k.startswith("timestep_")]
    step_keys.sort(key=lambda k: int(k.split("_")[1]))

    commands: list[Any] = []
    for key in step_keys:
        ts = episode_data[key]
        if _is_video_demo_v4(ts):
            continue

        action_grp = ts.get("action")
        if action_grp is None:
            continue
        if not _is_subgoal_boundary_v4(ts):
            continue
        if "choice_action" not in action_grp:
            continue

        raw = _decode_h5_str_v4(action_grp["choice_action"][()])
        try:
            commands.append(json.loads(raw))
        except (TypeError, ValueError, json.JSONDecodeError):
            # Malformed payloads are skipped, matching the v4 replay script.
            continue
    return commands
118
+
119
+
120
+ def _resolve_h5_path(env_id: str, dataset_root: Optional[str], h5_path: Optional[str]) -> Path:
121
+ if h5_path:
122
+ return Path(h5_path)
123
+ if not dataset_root:
124
+ raise ValueError("Either --h5_path or --dataset_root must be provided")
125
+ return Path(dataset_root) / f"record_dataset_{env_id}.h5"
126
+
127
+
128
def _episode_indices(data: h5py.File) -> list[int]:
    """Numeric ids of keys shaped exactly like episode_<n>, ascending."""
    pattern = re.compile(r"episode_(\d+)$")
    found = []
    for key in data.keys():
        match = pattern.match(key)
        if match:
            found.append(int(match.group(1)))
    return sorted(found)
135
+
136
+
137
+ def _parse_episode_filter(raw: Optional[str], all_eps: list[int]) -> list[int]:
138
+ if not raw:
139
+ return all_eps
140
+
141
+ selected: set[int] = set()
142
+ for token in [x.strip() for x in raw.split(",") if x.strip()]:
143
+ if "-" in token:
144
+ lo_s, hi_s = token.split("-", 1)
145
+ lo = int(lo_s)
146
+ hi = int(hi_s)
147
+ if lo > hi:
148
+ lo, hi = hi, lo
149
+ selected.update(range(lo, hi + 1))
150
+ else:
151
+ selected.add(int(token))
152
+
153
+ return [ep for ep in all_eps if ep in selected]
154
+
155
+
156
+ def _canonical_command(cmd: Any) -> str:
157
+ """Stable string form for diffing and readable output."""
158
+ try:
159
+ return json.dumps(cmd, ensure_ascii=False, sort_keys=True)
160
+ except TypeError:
161
+ if isinstance(cmd, dict):
162
+ safe = {
163
+ str(k): (v.tolist() if isinstance(v, np.ndarray) else v)
164
+ for k, v in cmd.items()
165
+ }
166
+ return json.dumps(safe, ensure_ascii=False, sort_keys=True)
167
+ return repr(cmd)
168
+
169
+
170
def _read_v4_commands(episode_group: h5py.Group) -> tuple[list[Any], list[dict[str, Any]], int]:
    """Return (raw commands, commands surviving v4 validation, skip count)."""
    raw_list = _build_multi_choice_sequence_v4(episode_group)
    parsed_list: list[dict[str, Any]] = [
        parsed
        for parsed in (_parse_oracle_command_v4(item) for item in raw_list)
        if parsed is not None
    ]
    # Every raw item either parses or is skipped, so the skip count is the gap.
    skipped = len(raw_list) - len(parsed_list)
    return raw_list, parsed_list, skipped
183
+
184
+
185
def _read_v3_commands(env_id: str, episode: int, dataset_ref: str) -> list[dict[str, Any]]:
    """Pull multi_choice commands step by step via the v3 resolver."""
    commands: list[dict[str, Any]] = []
    with EpisodeDatasetResolver(
        env_id=env_id,
        episode=episode,
        dataset_directory=dataset_ref,
    ) as resolver:
        step = 0
        cmd = resolver.get_step("multi_choice", step)
        # A None step marks the end of the recorded sequence.
        while cmd is not None:
            if isinstance(cmd, dict):
                commands.append(cmd)
            step += 1
            cmd = resolver.get_step("multi_choice", step)
    return commands
201
+
202
+
203
def compare_episode(
    env_id: str,
    episode: int,
    episode_group: h5py.Group,
    dataset_ref: str,
    max_show: int,
) -> None:
    """Print a per-episode diff between v4 and v3 multi_choice command reads.

    Args:
        env_id: Task id, forwarded to the v3 resolver.
        episode: Episode number (for the resolver and the report header).
        episode_group: The episode_<n> HDF5 group read by the v4 path.
        dataset_ref: .h5 path or directory accepted by EpisodeDatasetResolver.
        max_show: Cap on printed diff rows and on each sample listing.
    """
    v4_raw, v4_effective, v4_skipped = _read_v4_commands(episode_group)
    v3_resolver = _read_v3_commands(env_id=env_id, episode=episode, dataset_ref=dataset_ref)

    print(f"\n=== episode_{episode} ===")
    print(
        "counts: "
        f"v4_raw={len(v4_raw)}, "
        f"v4_effective={len(v4_effective)} (skipped_by_parse={v4_skipped}), "
        f"v3_resolver={len(v3_resolver)}"
    )

    # Compare canonical JSON strings so dict key order / ndarray values
    # cannot cause spurious mismatches.
    v4_effective_c = [_canonical_command(x) for x in v4_effective]
    v3_c = [_canonical_command(x) for x in v3_resolver]

    if v4_effective_c == v3_c:
        print("effective sequence compare: SAME")
    else:
        print("effective sequence compare: DIFFERENT")
        max_len = max(len(v4_effective_c), len(v3_c))
        shown = 0
        for idx in range(max_len):
            # Positions past one side's length show as <MISSING>.
            left = v4_effective_c[idx] if idx < len(v4_effective_c) else "<MISSING>"
            right = v3_c[idx] if idx < len(v3_c) else "<MISSING>"
            if left == right:
                continue
            print(f" idx={idx}")
            print(f" v4_effective: {left}")
            print(f" v3_resolver : {right}")
            shown += 1
            if shown >= max_show:
                remaining = max_len - idx - 1
                if remaining > 0:
                    print(f" ... more differences omitted ({remaining} remaining positions)")
                break

    print(f"sample v4_raw (first {max_show}):")
    for i, item in enumerate(v4_raw[:max_show]):
        print(f" [{i}] {_canonical_command(item)}")

    print(f"sample v4_effective (first {max_show}):")
    for i, item in enumerate(v4_effective[:max_show]):
        print(f" [{i}] {_canonical_command(item)}")

    print(f"sample v3_resolver (first {max_show}):")
    for i, item in enumerate(v3_resolver[:max_show]):
        print(f" [{i}] {_canonical_command(item)}")
256
+
257
+
258
def main() -> None:
    """CLI entry: diff v3-resolver vs v4-noresolver multi_choice reads.

    Parses flags, resolves the .h5 file, selects episodes, and prints a
    per-episode comparison report via compare_episode().
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compare multi_choice read results between "
            "evaluate_dataset_replay-parallelv3 and parallelv4-noresolver."
        )
    )
    parser.add_argument(
        "--env_id",
        type=str,
        default=DEFAULT_ENV_ID,
        help=f"Task/env id. Default: {DEFAULT_ENV_ID}",
    )
    parser.add_argument(
        "--dataset_root",
        type=str,
        default=DEFAULT_DATASET_ROOT,
        help=(
            "Directory that contains record_dataset_<env_id>.h5. "
            f"Default: {DEFAULT_DATASET_ROOT}"
        ),
    )
    parser.add_argument(
        "--h5_path",
        type=str,
        default=None,
        help="Direct path to .h5 file (overrides --dataset_root)",
    )
    parser.add_argument(
        "--episodes",
        type=str,
        # Fix: the default used to be the int 0 on a type=str option
        # (argparse does not convert defaults). It only behaved as
        # "all episodes" because 0 is falsy in _parse_episode_filter.
        # None is equally falsy and matches the declared type.
        default=None,
        help="Episode filter, e.g. '0,3,8-10'. Default: all episodes in h5",
    )
    parser.add_argument(
        "--max_show",
        type=int,
        default=50,
        help="Max number of diff/sample rows per episode",
    )
    args = parser.parse_args()

    h5_file = _resolve_h5_path(args.env_id, args.dataset_root, args.h5_path)
    if not h5_file.exists():
        raise FileNotFoundError(f"h5 file not found: {h5_file}")

    # The v3 resolver accepts either a direct .h5 path or its parent directory.
    dataset_ref = str(h5_file) if h5_file.suffix == ".h5" else str(h5_file.parent)

    print(f"env_id={args.env_id}")
    print(f"h5={h5_file}")

    with h5py.File(h5_file, "r") as data:
        all_eps = _episode_indices(data)
        selected_eps = _parse_episode_filter(args.episodes, all_eps)

        if not selected_eps:
            print("No episodes selected.")
            return

        print(f"episodes={selected_eps}")
        for ep in selected_eps:
            key = f"episode_{ep}"
            if key not in data:
                print(f"\n=== episode_{ep} ===")
                print("missing in h5, skip")
                continue
            compare_episode(
                env_id=args.env_id,
                episode=ep,
                episode_group=data[key],
                dataset_ref=dataset_ref,
                max_show=args.max_show,
            )
331
+
332
+
333
# Script entry point.
if __name__ == "__main__":
    main()
scripts/dev/dataset_replay_printType.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Script function: Unified dataset replay entry point, supporting 4 action spaces: joint_angle / ee_pose / waypoint / multi_choice.
3
+ # Consistent with subgoal_evaluate_func.py main loop; difference is actions come from EpisodeDatasetResolver.
4
+
5
+ import os
6
+ from typing import Any, Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from robomme.robomme_env import *
12
+ from robomme.robomme_env.utils import *
13
+ from robomme.env_record_wrapper import (
14
+ BenchmarkEnvBuilder,
15
+ EpisodeDatasetResolver,
16
+ )
17
+
18
+ # Only enable one ACTION_SPACE; others are commented out for manual switching
19
+ #ACTION_SPACE = "joint_angle"
20
+ ACTION_SPACE = "waypoint"
21
+
22
+
23
+ GUI_RENDER = False
24
+
25
+ DATASET_ROOT = "/data/hongzefu/data_0225"
26
+
27
+ DEFAULT_ENV_IDS = [
28
+ # "PickXtimes",
29
+ # "StopCube",
30
+ # "SwingXtimes",
31
+ #"BinFill",
32
+ "VideoUnmaskSwap",
33
+ # "VideoUnmask",
34
+ # "ButtonUnmaskSwap",
35
+ # "ButtonUnmask",
36
+ #"VideoRepick",
37
+ # "VideoPlaceButton",
38
+ # "VideoPlaceOrder",
39
+ # "PickHighlight",
40
+ # "InsertPeg",
41
+ # "MoveCube",
42
+ # "PatternLock",
43
+ # "RouteStick",
44
+ ]
45
+
46
+ MAX_STEPS = 1000
47
+
48
+
49
+ def _describe(value: Any, indent: int = 0) -> str:
50
+ """Recursively describe a value's type, shape, and content summary."""
51
+ prefix = " " * indent
52
+ if isinstance(value, torch.Tensor):
53
+ return f"{prefix}Tensor dtype={value.dtype} shape={tuple(value.shape)} device={value.device}"
54
+ elif isinstance(value, np.ndarray):
55
+ return f"{prefix}ndarray dtype={value.dtype} shape={value.shape}"
56
+ elif isinstance(value, list):
57
+ if len(value) == 0:
58
+ return f"{prefix}list[] (empty)"
59
+ lines = [f"{prefix}list[{len(value)}]"]
60
+ for i, item in enumerate(value):
61
+ lines.append(f"{prefix} [{i}]: {_describe(item, 0)}")
62
+ if i >= 2:
63
+ lines.append(f"{prefix} ... (only first 3 shown)")
64
+ break
65
+ return "\n".join(lines)
66
+ elif isinstance(value, dict):
67
+ lines = [f"{prefix}dict keys={list(value.keys())}"]
68
+ for k, v in value.items():
69
+ lines.append(f"{prefix} '{k}': {_describe(v, 0)}")
70
+ return "\n".join(lines)
71
+ elif isinstance(value, (int, float, bool, str)):
72
+ return f"{prefix}{type(value).__name__} value={repr(value)}"
73
+ elif value is None:
74
+ return f"{prefix}None"
75
+ else:
76
+ return f"{prefix}{type(value).__name__} repr={repr(value)[:80]}"
77
+
78
+
79
def _print_obs(obs: dict, tag: str):
    """Print data formats of all fields in the obs dict.

    Args:
        obs: Columnar observation dict produced by env.reset()/env.step().
        tag: Label printed in the section header (env/episode/step id).

    Returns:
        The dict of fields that were printed (name -> value), excluding
        maniskill_obs.
    """
    print(f"\n{'='*60}")
    print(f"[{tag}] obs fields:")
    print(f"{'='*60}")
    # maniskill_obs not printed (data volume is large)
    _ = obs["maniskill_obs"]
    # Each obs[...] access below doubles as a presence check: a missing key
    # raises KeyError here rather than later.
    front_rgb_list = obs["front_rgb_list"]
    wrist_rgb_list = obs["wrist_rgb_list"]
    front_depth_list = obs["front_depth_list"]
    wrist_depth_list = obs["wrist_depth_list"]
    end_effector_pose_raw = obs["end_effector_pose_raw"]
    eef_state_list = obs["eef_state_list"]
    joint_state_list = obs["joint_state_list"]

    gripper_state_list = obs["gripper_state_list"]
    front_camera_extrinsic_list = obs["front_camera_extrinsic_list"]
    wrist_camera_extrinsic_list = obs["wrist_camera_extrinsic_list"]

    fields = {
        "front_rgb_list": front_rgb_list,
        "wrist_rgb_list": wrist_rgb_list,
        "front_depth_list": front_depth_list,
        "wrist_depth_list": wrist_depth_list,
        "end_effector_pose_raw": end_effector_pose_raw,
        "eef_state_list": eef_state_list,
        "joint_state_list": joint_state_list,

        "gripper_state_list": gripper_state_list,
        "front_camera_extrinsic_list": front_camera_extrinsic_list,
        "wrist_camera_extrinsic_list": wrist_camera_extrinsic_list,
    }
    for name, val in fields.items():
        print(f" obs['{name}']:")
        print(_describe(val, indent=2))
    return fields
115
+
116
+
117
def _print_info(info: dict, tag: str):
    """Print data formats of all fields in the info dict.

    Args:
        info: Flat info dict produced by env.reset()/env.step().
        tag: Label printed in the section header (env/episode/step id).

    Returns:
        The dict of fields that were printed (name -> value).
    """
    print(f"\n[{tag}] info fields:")
    print(f"{'-'*60}")
    # Direct [] access asserts these keys exist; .get() is used for the two
    # fields that may legitimately be absent.
    task_goal = info["task_goal"]
    simple_subgoal_online = info["simple_subgoal_online"]
    grounded_subgoal_online = info["grounded_subgoal_online"]
    available_multi_choices = info.get("available_multi_choices")
    front_camera_intrinsic = info["front_camera_intrinsic"]
    wrist_camera_intrinsic = info["wrist_camera_intrinsic"]
    status = info.get("status")

    fields = {
        "task_goal": task_goal,
        "simple_subgoal_online": simple_subgoal_online,
        "grounded_subgoal_online": grounded_subgoal_online,
        "available_multi_choices": available_multi_choices,
        "front_camera_intrinsic": front_camera_intrinsic,
        "wrist_camera_intrinsic": wrist_camera_intrinsic,
        "status": status,
    }
    for name, val in fields.items():
        print(f" info['{name}']:")
        print(_describe(val, indent=2))
    return fields
143
+
144
def _print_step_extras(reward, terminated, truncated, tag: str):
    """Print data formats of reward / terminated / truncated."""
    print(f"\n[{tag}] reward / terminated / truncated:")
    print(f"{'-'*60}")
    extras = (
        ("reward", reward),
        ("terminated", terminated),
        ("truncated", truncated),
    )
    for label, val in extras:
        print(f" {label}: {_describe(val, 0)}")
152
+
153
+ def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
154
+ if not isinstance(choice_action, dict):
155
+ return None
156
+ choice = choice_action.get("choice")
157
+ if not isinstance(choice, str) or not choice.strip():
158
+ return None
159
+ point = choice_action.get("point")
160
+ if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2:
161
+ return None
162
+ return choice_action
163
+
164
+
165
def main():
    """Replay every episode of each selected env and print the data formats
    (dtype/shape/device) of obs, info, reward, terminated, truncated at reset
    and at every step. Actions come from EpisodeDatasetResolver.
    """
    env_id_list = BenchmarkEnvBuilder.get_task_list()
    print(f"Running envs: {env_id_list}")
    print(f"Using action_space: {ACTION_SPACE}")

    #for env_id in env_id_list:
    for env_id in DEFAULT_ENV_IDS:
        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=ACTION_SPACE,
            gui_render=GUI_RENDER,
        )
        episode_count = env_builder.get_episode_num()
        print(f"[{env_id}] episode_count from metadata: {episode_count}")

        env = None
        for episode in range(episode_count):

            # Enable every optional observation/info field so _print_obs /
            # _print_info can report the full schema.
            env = env_builder.make_env_for_episode(
                episode,
                max_steps=MAX_STEPS,
                include_maniskill_obs=True,
                include_front_depth=True,
                include_wrist_depth=True,
                include_front_camera_extrinsic=True,
                include_wrist_camera_extrinsic=True,
                include_available_multi_choices=True,
                include_front_camera_intrinsic=True,
                include_wrist_camera_intrinsic=True,
            )
            dataset_resolver = EpisodeDatasetResolver(
                env_id=env_id,
                episode=episode,
                dataset_directory=DATASET_ROOT,
            )

            # obs: dict-of-lists (columnar batch, list length = number of demo frames)
            # info: flat dict (last frame values only)
            obs, info = env.reset()

            # --- Print all obs / info field types (reset) ---
            _print_obs(obs, tag=f"{env_id} ep{episode} RESET")
            _print_info(info, tag=f"{env_id} ep{episode} RESET")

            step = 0
            # NOTE(review): episode_success is set below but never read after
            # the loop — looks like leftover bookkeeping; confirm before removing.
            episode_success = False

            # ======== Step loop ========
            while True:
                replay_key = ACTION_SPACE
                # None marks the end of the recorded action sequence.
                action = dataset_resolver.get_step(replay_key, step)
                if ACTION_SPACE == "multi_choice":
                    action = _parse_oracle_command(action)
                if action is None:
                    break

                # step returns: obs (dict-of-lists), reward (scalar tensor),
                # terminated (scalar tensor), truncated (scalar tensor), info (flat dict)
                obs, reward, terminated, truncated, info = env.step(action)

                # --- Print all obs / info / reward / terminated / truncated field types (step) ---
                _print_obs(obs, tag=f"{env_id} ep{episode} STEP{step}")
                _print_info(info, tag=f"{env_id} ep{episode} STEP{step}")
                _print_step_extras(reward, terminated, truncated, tag=f"{env_id} ep{episode} STEP{step}")

                terminated_flag = bool(terminated.item())
                truncated_flag = bool(truncated.item())

                step += 1
                if GUI_RENDER:
                    env.render()
                if truncated_flag:
                    print(f"[{env_id}] episode {episode} steps exceeded, step {step}.")
                    break
                if terminated_flag:
                    status = info.get("status")
                    if status == "success":
                        print(f"[{env_id}] episode {episode} success.")
                        episode_success = True
                    elif status == "fail":
                        print(f"[{env_id}] episode {episode} failed.")
                    break

        # Only the last created env is closed; earlier per-episode envs are
        # assumed to be reusable/replaced by make_env_for_episode.
        if env is not None:
            env.close()
251
+
252
+
253
# Script entry point.
if __name__ == "__main__":
    main()
scripts/dev/deprecated/dataset_replay-FK-parallel.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Replay episodes from an HDF5 dataset and save videos.
3
+
4
+ Read recorded joint actions (joint_action) from record_dataset_<Task>.h5,
5
+ convert them to end-effector pose actions (EE pose actions) via forward kinematics (FK),
6
+ replay them in an environment wrapped by EE_POSE_ACTION_SPACE,
7
+ and finally save side-by-side front/wrist camera videos to disk.
8
+ """
9
+
10
+ import os
11
+ from typing import Optional, Tuple
12
+
13
+ import cv2
14
+ import h5py
15
+ import imageio
16
+ import numpy as np
17
+ import sapien
18
+ import torch
19
+
20
+ from mani_skill.examples.motionplanning.panda.motionplanner import (
21
+ PandaArmMotionPlanningSolver,
22
+ )
23
+
24
+ from robomme.robomme_env import *
25
+ from robomme.robomme_env.utils import *
26
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
27
+ from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE
28
+ from robomme.robomme_env.utils.rpy_util import build_endeffector_pose_dict
29
+
30
+ # --- Configuration ---
31
+ GUI_RENDER = False
32
+ REPLAY_VIDEO_DIR = "replay_videos"
33
+ VIDEO_FPS = 30
34
+ MAX_STEPS = 1000
35
+
36
+
37
def _init_fk_planner(env) -> Tuple:
    """Create PandaArmMotionPlanningSolver and return helper objects needed for FK.

    Must be called after env.reset() (see the call site in process_episode),
    since it reads the robot pose from the live environment.

    Returns:
        (mplib_planner, ee_link_idx, robot_base_pose)
        - mplib_planner: mplib.Planner instance used for FK computation
        - ee_link_idx: end-effector link index in the pinocchio model
        - robot_base_pose: robot base pose in world coordinates
    """
    solver = PandaArmMotionPlanningSolver(
        env,
        debug=False,
        vis=False,
        base_pose=env.unwrapped.agent.robot.pose,
        visualize_target_grasp_pose=False,
        print_env_info=False,
    )
    mplib_planner = solver.planner
    # The move_group link is the FK target (the end-effector frame).
    ee_link_idx = mplib_planner.link_name_2_idx[mplib_planner.move_group]
    robot_base_pose = env.unwrapped.agent.robot.pose

    print(f"[FK] move_group: {mplib_planner.move_group}, "
          f"ee_link_idx: {ee_link_idx}, "
          f"link_names: {mplib_planner.user_link_names}")
    return mplib_planner, ee_link_idx, robot_base_pose
62
+
63
+
64
def _joint_action_to_ee_pose(
    mplib_planner,
    joint_action: np.ndarray,
    robot_base_pose: sapien.Pose,
    ee_link_idx: int,
    prev_ee_quat_wxyz: Optional[torch.Tensor] = None,
    prev_ee_rpy_xyz: Optional[torch.Tensor] = None,
) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
    """Convert 8D joint action to 7D end-effector pose action via forward kinematics (FK).

    Args:
        mplib_planner: mplib.Planner instance (from PandaArmMotionPlanningSolver).
        joint_action: 8D array [q1..q7, gripper].
        robot_base_pose: robot base pose as a Sapien Pose.
        ee_link_idx: end-effector link index in the pinocchio model.
        prev_ee_quat_wxyz: previous-frame quaternion cache (for sign alignment).
        prev_ee_rpy_xyz: previous-frame RPY cache (for continuity unwrapping).

    Returns:
        ee_action: 7D [x, y, z, roll, pitch, yaw, gripper].
        new_prev_quat: updated quaternion cache.
        new_prev_rpy: updated RPY cache.
    """
    action = np.asarray(joint_action, dtype=np.float64).flatten()
    arm_qpos = action[:7]
    # Missing gripper entry defaults to -1.0 (treated as "open" below).
    gripper = float(action[7]) if action.size > 7 else -1.0

    # Build full qpos: 7 arm joints + 2 gripper finger joints
    # NOTE(review): assumes a 9-DoF Panda layout (7 arm + 2 mirrored fingers);
    # 0.04 appears to be the fully-open finger position — confirm against the
    # robot model.
    finger_pos = max(gripper, 0.0) if gripper >= 0 else 0.04
    full_qpos = np.concatenate([arm_qpos, [finger_pos, finger_pos]])

    # Compute forward kinematics in the robot-base coordinate frame
    pmodel = mplib_planner.pinocchio_model
    pmodel.compute_forward_kinematics(full_qpos)
    fk_result = pmodel.get_link_pose(ee_link_idx)  # 7D [x,y,z, qw,qx,qy,qz]

    p_base = fk_result[:3]
    q_base_wxyz = fk_result[3:]  # wxyz quaternion format

    # base frame -> world frame transform
    pose_in_base = sapien.Pose(p_base, q_base_wxyz)
    world_pose = robot_base_pose * pose_in_base

    # Use shared utilities to build continuous RPY (quaternion normalization, sign alignment, RPY unwrapping)
    position_t = torch.as_tensor(
        np.asarray(world_pose.p, dtype=np.float64), dtype=torch.float64
    )
    quat_wxyz_t = torch.as_tensor(
        np.asarray(world_pose.q, dtype=np.float64), dtype=torch.float64
    )
    pose_dict, new_prev_quat, new_prev_rpy = build_endeffector_pose_dict(
        position_t, quat_wxyz_t,
        prev_ee_quat_wxyz, prev_ee_rpy_xyz,
    )

    # Concatenate into 7D EE pose action: [position(3), RPY(3), gripper(1)]
    pos_np = pose_dict["pose"].detach().cpu().numpy().flatten()[:3]
    rpy_np = pose_dict["rpy"].detach().cpu().numpy().flatten()[:3]
    ee_action = np.concatenate([pos_np, rpy_np, [gripper]]).astype(np.float64)

    return ee_action, new_prev_quat, new_prev_rpy
125
+
126
+
127
def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray:
    """Build one side-by-side frame from front and wrist camera observations."""
    views = [obs[name][0].cpu().numpy() for name in ("front_camera", "wrist_camera")]
    frame = np.concatenate(views, axis=1).astype(np.uint8)
    if not is_video_frame:
        return frame
    # Mark video-demo frames with a red border
    return cv2.rectangle(
        frame, (0, 0), (frame.shape[1], frame.shape[0]), (255, 0, 0), 10
    )
138
+
139
+
140
+ def _first_execution_step(episode_data) -> int:
141
+ """Return the first non-video-demo step index (actual execution start step)."""
142
+ step_idx = 0
143
+ while episode_data[f"timestep_{step_idx}"]["info"]["is_video_demo"][()]:
144
+ step_idx += 1
145
+ return step_idx
146
+
147
+
148
def process_episode(
    h5_file_path: str, episode_idx: int, env_id: str, gui_render: bool = False,
) -> None:
    """Replay one episode in HDF5: read joint actions, run FK conversion, execute the environment, and save video.

    Each worker process opens the HDF5 file independently to avoid cross-process shared file handles.

    Args:
        h5_file_path: Path to the record_dataset_<Task>.h5 file.
        episode_idx: Index of the episode_<idx> group to replay.
        env_id: Benchmark task identifier used to build the environment.
        gui_render: Whether to call env.render() every step.
    """

    def _last_flag(value) -> bool:
        """Reduce a possibly nested / batched truth value to one bool.

        Replaces ``info.get(key, False)[-1][-1]``, which raised TypeError
        whenever the key was absent (it indexed the ``False`` default).
        """
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        try:
            flat = np.asarray(value).reshape(-1)
        except (TypeError, ValueError):
            # Ragged nesting or odd types: follow the last elements, as before.
            while isinstance(value, (list, tuple)) and value:
                value = value[-1]
            return bool(value)
        return bool(flat[-1]) if flat.size else False

    with h5py.File(h5_file_path, "r") as env_data:
        episode_data = env_data[f"episode_{episode_idx}"]
        task_goal = episode_data["setup"]["task_goal"][()].decode()
        total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_"))

        step_idx = _first_execution_step(episode_data)
        # Hoisted loop invariant: previously _first_execution_step() was
        # re-scanned on every loop iteration just for the debug print below.
        first_step = step_idx
        print(f"[ep{episode_idx}] execution start step index: {step_idx}")

        # Create environment with EE_POSE_ACTION_SPACE (wrapped by EndeffectorDemonstrationWrapper)
        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=EE_POSE_ACTION_SPACE,
            gui_render=gui_render,
        )
        env = env_builder.make_env_for_episode(
            episode_idx,
            max_steps=MAX_STEPS,
            include_maniskill_obs=True,
            include_front_depth=True,
            include_wrist_depth=True,
            include_front_camera_extrinsic=True,
            include_wrist_camera_extrinsic=True,
            include_available_multi_choices=True,
            include_front_camera_intrinsic=True,
            include_wrist_camera_intrinsic=True,
        )
        print(f"[ep{episode_idx}] task: {env_id}, goal: {task_goal}")

        obs, info = env.reset()

        # Initialize FK planner (must be called after env.reset())
        mplib_planner, ee_link_idx, robot_base_pose = _init_fk_planner(env)

        # Observation list: length 1 means no demo video, length >1 means
        # includes demo video; last element is the current frame
        frames = []
        n_obs = len(obs["front_camera"])
        for i in range(n_obs):
            single_obs = {k: [v[i]] for k, v in obs.items()}
            frames.append(_frame_from_obs(single_obs, is_video_frame=(i < n_obs - 1)))
        print(f"[ep{episode_idx}] initial frame count (demo video + current frame): {len(frames)}")

        outcome = "unknown"
        prev_quat: Optional[torch.Tensor] = None
        prev_rpy: Optional[torch.Tensor] = None
        try:
            while step_idx < total_steps:
                # Read joint action from HDF5
                joint_action = np.asarray(
                    episode_data[f"timestep_{step_idx}"]["action"]["joint_action"][()],
                    dtype=np.float64,
                )

                # Forward kinematics: joint_action -> ee_pose action
                ee_action, prev_quat, prev_rpy = _joint_action_to_ee_pose(
                    mplib_planner, joint_action, robot_base_pose, ee_link_idx,
                    prev_ee_quat_wxyz=prev_quat,
                    prev_ee_rpy_xyz=prev_rpy,
                )

                # Print debug info on the first step to verify FK conversion
                if step_idx == first_step:
                    print(f"[ep{episode_idx}][FK] first step joint_action: {joint_action}")
                    print(f"[ep{episode_idx}][FK] first step ee_action: {ee_action}")

                # Execute EE pose action in the environment
                obs, _, terminated, _, info = env.step(ee_action)
                frames.append(_frame_from_obs(obs))

                if gui_render:
                    env.render()

                if terminated:
                    # Robust success/fail extraction — fixes the old
                    # info.get(...)[-1][-1] pattern (see _last_flag above).
                    if _last_flag(info.get("success", False)):
                        outcome = "success"
                    if _last_flag(info.get("fail", False)):
                        outcome = "fail"
                    break
                step_idx += 1
        finally:
            # Always release simulator resources, even on mid-episode failure.
            env.close()

    # Save replay video
    safe_goal = task_goal.replace(" ", "_").replace("/", "_")
    os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True)
    video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4"
    video_path = os.path.join(REPLAY_VIDEO_DIR, video_name)
    imageio.mimsave(video_path, frames, fps=VIDEO_FPS)
    print(f"[ep{episode_idx}] Video saved to {video_path}")
245
+
246
+
247
+ def _worker_init(gpu_id_queue) -> None:
248
+ """Pool worker initializer that binds a GPU before CUDA initialization.
249
+
250
+ When each worker starts, it takes one GPU ID from the queue and sets env vars,
251
+ ensuring all later CUDA ops in that process run on the assigned GPU.
252
+ """
253
+ gpu_id = gpu_id_queue.get()
254
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
255
+ print(f"[Worker PID {os.getpid()}] bind GPU {gpu_id}")
256
+
257
+
258
def _process_episode_worker(args: Tuple[str, int, str, bool]) -> str:
    """multiprocessing worker entrypoint: unpack one job tuple and replay it.

    Returns a one-line status string ("OK: ..." / "FAIL: ...") so the parent
    pool can summarize results without exceptions crossing process boundaries.
    """
    h5_file_path, episode_idx, env_id, gui_render = args
    gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
    try:
        process_episode(h5_file_path, episode_idx, env_id, gui_render=gui_render)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"FAIL: {env_id} ep{episode_idx} (GPU {gpu_id}): {e}"
    return f"OK: {env_id} ep{episode_idx} (GPU {gpu_id})"
269
+
270
+
271
def replay(
    h5_data_dir: str = "/data/hongzefu/data_0214",
    num_workers: int = 20,
    gui_render: bool = False,
    gpu_ids: str = "0,1",
) -> None:
    """Iterate through all task HDF5 files in the given directory and replay multiple episodes per env in parallel.

    Args:
        h5_data_dir: Directory containing HDF5 datasets.
        num_workers: Number of parallel workers per env.
        gui_render: Whether to enable GUI rendering (recommended off in multiprocessing).
        gpu_ids: Comma-separated GPU ID list; workers use them in round-robin order.
            For example, "0,1" alternates assignment between GPU 0 and GPU 1.
    """
    # "spawn" gives each worker a fresh interpreter so CUDA is initialized
    # only after _worker_init has set CUDA_VISIBLE_DEVICES.
    import multiprocessing as mp
    ctx = mp.get_context("spawn")

    gpu_id_list = [int(g.strip()) for g in gpu_ids.split(",")]
    print(f"Using GPUs: {gpu_id_list}, workers: {num_workers}")

    env_id_list = BenchmarkEnvBuilder.get_task_list()
    for env_id in env_id_list:
        file_name = f"record_dataset_{env_id}.h5"
        file_path = os.path.join(h5_data_dir, file_name)
        if not os.path.exists(file_path):
            print(f"Skip {env_id}: file does not exist: {file_path}")
            continue

        # Quickly read episode list and close file
        # (workers re-open the file themselves; the handle is not shared).
        with h5py.File(file_path, "r") as data:
            episode_indices = sorted(
                int(k.split("_")[1])
                for k in data.keys()
                if k.startswith("episode_")
            )
        print(f"task: {env_id}, total {len(episode_indices)} episodes, "
              f"workers: {num_workers}, GPUs: {gpu_id_list}")

        # Build worker argument list (one tuple per episode)
        worker_args = [
            (file_path, ep_idx, env_id, gui_render)
            for ep_idx in episode_indices
        ]

        # Create a new GPU assignment queue for each round; each worker grabs
        # one GPU ID at startup (round-robin over gpu_id_list).
        gpu_id_queue = ctx.Queue()
        for i in range(num_workers):
            gpu_id_queue.put(gpu_id_list[i % len(gpu_id_list)])

        # Parallel replay (initializer binds GPU when each worker starts)
        with ctx.Pool(
            processes=num_workers,
            initializer=_worker_init,
            initargs=(gpu_id_queue,),
        ) as pool:
            results = pool.map(_process_episode_worker, worker_args)

        # Print the per-episode status strings returned by the workers.
        for r in results:
            print(r)
331
+
332
+
333
if __name__ == "__main__":
    import tyro
    # Expose replay()'s keyword arguments as CLI flags via tyro.
    tyro.cli(replay)
scripts/dev/deprecated/dataset_replay-FK.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Replay episodes from an HDF5 dataset and save videos.
3
+
4
+ Read recorded joint actions (joint_action) from record_dataset_<Task>.h5,
5
+ convert them to end-effector pose actions (EE pose actions) via forward kinematics (FK),
6
+ replay them in an environment wrapped by EE_POSE_ACTION_SPACE,
7
+ and finally save side-by-side front/wrist camera videos to disk.
8
+ """
9
+
10
+ import os
11
+ from typing import Optional, Tuple
12
+
13
+ import cv2
14
+ import h5py
15
+ import imageio
16
+ import numpy as np
17
+ import sapien
18
+ import torch
19
+
20
+ from mani_skill.examples.motionplanning.panda.motionplanner import (
21
+ PandaArmMotionPlanningSolver,
22
+ )
23
+
24
+ from robomme.robomme_env import *
25
+ from robomme.robomme_env.utils import *
26
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
27
+ from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE
28
+ from robomme.robomme_env.utils.rpy_util import build_endeffector_pose_dict
29
+
30
# --- Configuration ---
GUI_RENDER = True  # open the interactive viewer while replaying
REPLAY_VIDEO_DIR = "replay_videos"  # output directory for saved replay videos
VIDEO_FPS = 30  # frame rate of the saved mp4
MAX_STEPS = 1000  # per-episode environment step budget
35
+
36
+
37
def _init_fk_planner(env) -> Tuple:
    """Create PandaArmMotionPlanningSolver and extract the pieces needed for FK.

    The solver is instantiated only to borrow its mplib planner; no motion
    planning is performed here.

    Returns:
        (mplib_planner, ee_link_idx, robot_base_pose)
        - mplib_planner: mplib.Planner instance used for FK computation
        - ee_link_idx: end-effector link index in the pinocchio model
        - robot_base_pose: robot base pose in world coordinates
    """
    solver = PandaArmMotionPlanningSolver(
        env,
        debug=False,
        vis=False,
        base_pose=env.unwrapped.agent.robot.pose,
        visualize_target_grasp_pose=False,
        print_env_info=False,
    )
    planner = solver.planner
    ee_idx = planner.link_name_2_idx[planner.move_group]
    base_pose = env.unwrapped.agent.robot.pose

    print(f"[FK] move_group: {planner.move_group}, "
          f"ee_link_idx: {ee_idx}, "
          f"link_names: {planner.user_link_names}")
    return planner, ee_idx, base_pose
62
+
63
+
64
def _joint_action_to_ee_pose(
    mplib_planner,
    joint_action: np.ndarray,
    robot_base_pose: sapien.Pose,
    ee_link_idx: int,
    prev_ee_quat_wxyz: Optional[torch.Tensor] = None,
    prev_ee_rpy_xyz: Optional[torch.Tensor] = None,
) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
    """Convert an 8D joint action into a 7D end-effector pose action via FK.

    Args:
        mplib_planner: mplib.Planner instance (from PandaArmMotionPlanningSolver).
        joint_action: 8D array [q1..q7, gripper].
        robot_base_pose: robot base pose as a Sapien Pose.
        ee_link_idx: end-effector link index in the pinocchio model.
        prev_ee_quat_wxyz: previous-frame quaternion cache (for sign alignment).
        prev_ee_rpy_xyz: previous-frame RPY cache (for continuity unwrapping).

    Returns:
        ee_action: 7D [x, y, z, roll, pitch, yaw, gripper].
        new_prev_quat: updated quaternion cache.
        new_prev_rpy: updated RPY cache.
    """
    flat = np.asarray(joint_action, dtype=np.float64).flatten()
    arm_qpos = flat[:7]
    gripper = float(flat[7]) if flat.size > 7 else -1.0

    # Full qpos = 7 arm joints + 2 mirrored finger joints. A negative gripper
    # command (the "missing" default) maps to finger position 0.04.
    finger = gripper if gripper >= 0 else 0.04
    full_qpos = np.concatenate([arm_qpos, [finger, finger]])

    # FK in the robot-base frame via the planner's pinocchio model.
    pin_model = mplib_planner.pinocchio_model
    pin_model.compute_forward_kinematics(full_qpos)
    link_pose = pin_model.get_link_pose(ee_link_idx)  # 7D [x,y,z, qw,qx,qy,qz]

    # Lift the base-frame pose into the world frame.
    world_pose = robot_base_pose * sapien.Pose(link_pose[:3], link_pose[3:])

    # Shared utility handles quaternion normalization, sign alignment and RPY
    # unwrapping so consecutive actions stay continuous.
    pos_tensor = torch.as_tensor(
        np.asarray(world_pose.p, dtype=np.float64), dtype=torch.float64
    )
    quat_tensor = torch.as_tensor(
        np.asarray(world_pose.q, dtype=np.float64), dtype=torch.float64
    )
    pose_dict, new_prev_quat, new_prev_rpy = build_endeffector_pose_dict(
        pos_tensor, quat_tensor,
        prev_ee_quat_wxyz, prev_ee_rpy_xyz,
    )

    # Assemble the 7D action: position(3) + RPY(3) + gripper(1).
    xyz = pose_dict["pose"].detach().cpu().numpy().flatten()[:3]
    rpy = pose_dict["rpy"].detach().cpu().numpy().flatten()[:3]
    ee_action = np.concatenate([xyz, rpy, [gripper]]).astype(np.float64)

    return ee_action, new_prev_quat, new_prev_rpy
125
+
126
+
127
+ def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray:
128
+ """Build one side-by-side frame from front and wrist camera observations."""
129
+ front = obs["front_camera"][0].cpu().numpy()
130
+ wrist = obs["wrist_camera"][0].cpu().numpy()
131
+ frame = np.concatenate([front, wrist], axis=1).astype(np.uint8)
132
+ if is_video_frame:
133
+ # Mark video-demo frames with a red border
134
+ frame = cv2.rectangle(
135
+ frame, (0, 0), (frame.shape[1], frame.shape[0]), (255, 0, 0), 10
136
+ )
137
+ return frame
138
+
139
+
140
+ def _first_execution_step(episode_data) -> int:
141
+ """Return the first non-video-demo step index (actual execution start step)."""
142
+ step_idx = 0
143
+ while episode_data[f"timestep_{step_idx}"]["info"]["is_video_demo"][()]:
144
+ step_idx += 1
145
+ return step_idx
146
+
147
+
148
def process_episode(env_data: h5py.File, episode_idx: int, env_id: str) -> None:
    """Replay one episode in HDF5: read joint actions, run FK conversion, execute the environment, and save video.

    Args:
        env_data: Open HDF5 file containing ``episode_<idx>`` groups.
        episode_idx: Index of the episode group to replay.
        env_id: Benchmark task identifier used to build the environment.
    """
    episode_data = env_data[f"episode_{episode_idx}"]
    task_goal = episode_data["setup"]["task_goal"][()].decode()
    total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_"))

    step_idx = _first_execution_step(episode_data)
    print(f"execution start step index: {step_idx}")

    # Create environment with EE_POSE_ACTION_SPACE (wrapped by EndeffectorDemonstrationWrapper)
    env_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="train",
        action_space=EE_POSE_ACTION_SPACE,
        gui_render=GUI_RENDER,
    )
    env = env_builder.make_env_for_episode(
        episode_idx,
        max_steps=MAX_STEPS,
        include_maniskill_obs=True,
        include_front_depth=True,
        include_wrist_depth=True,
        include_front_camera_extrinsic=True,
        include_wrist_camera_extrinsic=True,
        include_available_multi_choices=True,
        include_front_camera_intrinsic=True,
        include_wrist_camera_intrinsic=True,
    )
    print(f"task: {env_id}, episode: {episode_idx}, goal: {task_goal}")

    obs, info = env.reset()

    # Initialize FK planner (must be called after env.reset())
    mplib_planner, ee_link_idx, robot_base_pose = _init_fk_planner(env)

    # Observation list: length 1 means no demo video, length >1 means includes demo video; last element is current frame
    frames = []
    n_obs = len(obs["front_camera"])
    for i in range(n_obs):
        single_obs = {k: [v[i]] for k, v in obs.items()}
        frames.append(_frame_from_obs(single_obs, is_video_frame=(i < n_obs - 1)))
    print(f"initial frame count (demo video + current frame): {len(frames)}")

    outcome = "unknown"
    prev_quat: Optional[torch.Tensor] = None
    prev_rpy: Optional[torch.Tensor] = None
    try:
        while step_idx < total_steps:
            # Read joint action from HDF5
            joint_action = np.asarray(
                episode_data[f"timestep_{step_idx}"]["action"]["joint_action"][()],
                dtype=np.float64,
            )

            # Forward kinematics: joint_action -> ee_pose action
            ee_action, prev_quat, prev_rpy = _joint_action_to_ee_pose(
                mplib_planner, joint_action, robot_base_pose, ee_link_idx,
                prev_ee_quat_wxyz=prev_quat,
                prev_ee_rpy_xyz=prev_rpy,
            )

            # Print debug info on the first step to verify FK conversion
            # NOTE(review): _first_execution_step() is re-scanned on every
            # iteration; its result equals the initial step_idx computed above.
            if step_idx == _first_execution_step(episode_data):
                print(f"[FK] first step joint_action: {joint_action}")
                print(f"[FK] first step ee_action: {ee_action}")

            # Execute EE pose action in the environment
            obs, _, terminated, _, info = env.step(ee_action)
            frames.append(_frame_from_obs(obs))

            if GUI_RENDER:
                env.render()

            # TODO: hongze fix nested-list handling
            # NOTE(review): indexing [-1][-1] on the False default raises
            # TypeError when the key is absent; assumes nested [[...]] values
            # when present — confirm against env.step()'s info contract.
            if terminated:
                if info.get("success", False)[-1][-1]:
                    outcome = "success"
                if info.get("fail", False)[-1][-1]:
                    outcome = "fail"
                break
            step_idx += 1
    finally:
        # Always release simulator resources, even on mid-episode failure.
        env.close()

    # Save replay video
    safe_goal = task_goal.replace(" ", "_").replace("/", "_")
    os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True)
    video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4"
    video_path = os.path.join(REPLAY_VIDEO_DIR, video_name)
    imageio.mimsave(video_path, frames, fps=VIDEO_FPS)
    print(f"Video saved to {video_path}")
239
+
240
+
241
def replay(h5_data_dir: str = "/data/hongzefu/data_0214") -> None:
    """Iterate through all task HDF5 files in the given directory and replay episodes one by one."""
    for env_id in BenchmarkEnvBuilder.get_task_list():
        file_path = os.path.join(h5_data_dir, f"record_dataset_{env_id}.h5")
        if not os.path.exists(file_path):
            print(f"Skip {env_id}: file does not exist: {file_path}")
            continue

        with h5py.File(file_path, "r") as data:
            # Collect numeric episode indices from the episode_<n> group names.
            episode_keys = [k for k in data.keys() if k.startswith("episode_")]
            episode_indices = sorted(int(k.split("_")[1]) for k in episode_keys)
            print(f"task: {env_id}, total {len(episode_indices)} episodes")
            for episode_idx in episode_indices:
                process_episode(data, episode_idx, env_id)
260
+
261
+
262
if __name__ == "__main__":
    import tyro
    # Expose replay()'s keyword arguments as CLI flags via tyro.
    tyro.cli(replay)
scripts/dev/deprecated/dataset_replay-ee-parallel.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Replay episodes from an HDF5 dataset and save videos.
3
+
4
+ Read recorded end-effector pose actions (eef_action) from record_dataset_<Task>.h5,
5
+ replay them in an environment wrapped by EE_POSE_ACTION_SPACE,
6
+ and finally save side-by-side front/wrist camera videos to disk.
7
+ """
8
+
9
+ import os
10
+ from typing import Tuple
11
+
12
+ import cv2
13
+ import h5py
14
+ import imageio
15
+ import numpy as np
16
+
17
+ from robomme.robomme_env import *
18
+ from robomme.robomme_env.utils import *
19
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
20
+ from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE
21
+
22
# --- Config ---
GUI_RENDER = False  # GUI viewer disabled by default (parallel replay)
REPLAY_VIDEO_DIR = "replay_videos"  # output directory for saved replay videos
VIDEO_FPS = 30  # frame rate of the saved mp4
MAX_STEPS = 1000  # per-episode environment step budget
27
+
28
+
29
+ def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray:
30
+ """Build a single side-by-side frame from front and wrist camera obs."""
31
+ front = obs["front_camera"][0].cpu().numpy()
32
+ wrist = obs["wrist_camera"][0].cpu().numpy()
33
+ frame = np.concatenate([front, wrist], axis=1).astype(np.uint8)
34
+ if is_video_frame:
35
+ frame = cv2.rectangle(
36
+ frame, (0, 0), (frame.shape[1], frame.shape[0]), (255, 0, 0), 10
37
+ )
38
+ return frame
39
+
40
+
41
+ def _first_execution_step(episode_data) -> int:
42
+ """Return the first step index that is not a video-demo step."""
43
+ step_idx = 0
44
+ while episode_data[f"timestep_{step_idx}"]["info"]["is_video_demo"][()]:
45
+ step_idx += 1
46
+ return step_idx
47
+
48
+
49
def process_episode(
    h5_file_path: str, episode_idx: int, env_id: str, gui_render: bool = False,
) -> None:
    """Replay one episode in HDF5: read EE pose actions, run the environment, and save video.

    Each worker process opens the HDF5 file independently to avoid cross-process shared file handles.

    Args:
        h5_file_path: Path to the record_dataset_<Task>.h5 file.
        episode_idx: Index of the episode_<idx> group to replay.
        env_id: Benchmark task identifier used to build the environment.
        gui_render: Whether to call env.render() every step.
    """
    with h5py.File(h5_file_path, "r") as env_data:
        episode_data = env_data[f"episode_{episode_idx}"]
        task_goal = episode_data["setup"]["task_goal"][()].decode()
        total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_"))

        step_idx = _first_execution_step(episode_data)
        print(f"[ep{episode_idx}] execution start step index: {step_idx}")

        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=EE_POSE_ACTION_SPACE,
            gui_render=gui_render,
        )
        env = env_builder.make_env_for_episode(
            episode_idx,
            max_steps=MAX_STEPS,
            include_maniskill_obs=True,
            include_front_depth=True,
            include_wrist_depth=True,
            include_front_camera_extrinsic=True,
            include_wrist_camera_extrinsic=True,
            include_available_multi_choices=True,
            include_front_camera_intrinsic=True,
            include_wrist_camera_intrinsic=True,
        )
        print(f"[ep{episode_idx}] task: {env_id}, goal: {task_goal}")

        obs, info = env.reset()
        # Observation lists: length 1 means no demo video; otherwise the last
        # element is the current frame and earlier ones are demo-video frames.
        frames = []
        n_obs = len(obs["front_camera"])
        for i in range(n_obs):
            single_obs = {k: [v[i]] for k, v in obs.items()}
            frames.append(_frame_from_obs(single_obs, is_video_frame=(i < n_obs - 1)))
        print(f"[ep{episode_idx}] initial frame count (demo video + current frame): {len(frames)}")

        outcome = "unknown"
        try:
            while step_idx < total_steps:
                # Read the recorded EE pose action for this step from HDF5.
                action = np.asarray(
                    episode_data[f"timestep_{step_idx}"]["action"]["eef_action"][()],
                    dtype=np.float32,
                )
                obs, _, terminated, _, info = env.step(action)
                frames.append(_frame_from_obs(obs))

                if gui_render:
                    env.render()

                # TODO: hongze makes this correct
                # there are too many nested lists here, need to flatten them
                # NOTE(review): indexing [-1][-1] on the False default raises
                # TypeError when the key is absent; assumes nested [[...]]
                # values when present — confirm against env.step()'s info.
                if terminated:
                    if info.get("success", False)[-1][-1]:
                        outcome = "success"
                    if info.get("fail", False)[-1][-1]:
                        outcome = "fail"
                    break
                step_idx += 1
        finally:
            # Always release simulator resources, even on mid-episode failure.
            env.close()

    # Save replay video
    safe_goal = task_goal.replace(" ", "_").replace("/", "_")
    os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True)
    video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4"
    video_path = os.path.join(REPLAY_VIDEO_DIR, video_name)
    imageio.mimsave(video_path, frames, fps=VIDEO_FPS)
    print(f"[ep{episode_idx}] Video saved to {video_path}")
+
125
+
126
+ def _worker_init(gpu_id_queue) -> None:
127
+ """Pool worker initializer that binds a GPU before CUDA initialization.
128
+
129
+ When each worker starts, it takes one GPU ID from the queue and sets env vars,
130
+ ensuring all later CUDA ops in that process run on the assigned GPU.
131
+ """
132
+ gpu_id = gpu_id_queue.get()
133
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
134
+ print(f"[Worker PID {os.getpid()}] bind GPU {gpu_id}")
135
+
136
+
137
def _process_episode_worker(args: Tuple[str, int, str, bool]) -> str:
    """multiprocessing worker entrypoint: unpack one job tuple and replay it.

    Returns a one-line status string ("OK: ..." / "FAIL: ...") so the parent
    pool can summarize results without exceptions crossing process boundaries.
    """
    h5_file_path, episode_idx, env_id, gui_render = args
    gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
    try:
        process_episode(h5_file_path, episode_idx, env_id, gui_render=gui_render)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"FAIL: {env_id} ep{episode_idx} (GPU {gpu_id}): {e}"
    return f"OK: {env_id} ep{episode_idx} (GPU {gpu_id})"
148
+
149
+
150
def replay(
    h5_data_dir: str = "/data/hongzefu/dataset_generate",
    num_workers: int = 20,
    gui_render: bool = False,
    gpu_ids: str = "0,1",
) -> None:
    """Iterate through all task HDF5 files in the given directory and replay multiple episodes per env in parallel.

    Args:
        h5_data_dir: Directory containing HDF5 datasets.
        num_workers: Number of parallel workers per env.
        gui_render: Whether to enable GUI rendering (recommended off in multiprocessing).
        gpu_ids: Comma-separated GPU ID list; workers use them in round-robin order.
            For example, "0,1" alternates assignment between GPU 0 and GPU 1.
    """
    # "spawn" gives each worker a fresh interpreter so CUDA is initialized
    # only after _worker_init has set CUDA_VISIBLE_DEVICES.
    import multiprocessing as mp
    ctx = mp.get_context("spawn")

    gpu_id_list = [int(g.strip()) for g in gpu_ids.split(",")]
    print(f"Using GPUs: {gpu_id_list}, workers: {num_workers}")

    env_id_list = BenchmarkEnvBuilder.get_task_list()
    for env_id in env_id_list:
        file_name = f"record_dataset_{env_id}.h5"
        file_path = os.path.join(h5_data_dir, file_name)
        if not os.path.exists(file_path):
            print(f"Skip {env_id}: file does not exist: {file_path}")
            continue

        # Quickly read episode list and close file
        # (workers re-open the file themselves; the handle is not shared).
        with h5py.File(file_path, "r") as data:
            episode_indices = sorted(
                int(k.split("_")[1])
                for k in data.keys()
                if k.startswith("episode_")
            )
        print(f"task: {env_id}, total {len(episode_indices)} episodes, "
              f"workers: {num_workers}, GPUs: {gpu_id_list}")

        # Build worker argument list (one tuple per episode)
        worker_args = [
            (file_path, ep_idx, env_id, gui_render)
            for ep_idx in episode_indices
        ]

        # Create a new GPU assignment queue for each round; each worker grabs
        # one GPU ID at startup (round-robin over gpu_id_list).
        gpu_id_queue = ctx.Queue()
        for i in range(num_workers):
            gpu_id_queue.put(gpu_id_list[i % len(gpu_id_list)])

        # Parallel replay (initializer binds GPU when each worker starts)
        with ctx.Pool(
            processes=num_workers,
            initializer=_worker_init,
            initargs=(gpu_id_queue,),
        ) as pool:
            results = pool.map(_process_episode_worker, worker_args)

        # Print the per-episode status strings returned by the workers.
        for r in results:
            print(r)
210
+
211
+
212
if __name__ == "__main__":
    import tyro
    # Expose replay()'s keyword arguments as CLI flags via tyro.
    tyro.cli(replay)
scripts/dev/deprecated/dataset_replay-ee.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Replay episodes from HDF5 datasets and save rollout videos.
3
+
4
+ Loads recorded joint actions from record_dataset_<Task>.h5, steps the environment,
5
+ and writes side-by-side front/wrist camera videos to disk.
6
+ """
7
+
8
+ import os
9
+
10
+ import cv2
11
+ import h5py
12
+ import imageio
13
+ import numpy as np
14
+
15
+ from robomme.robomme_env import *
16
+ from robomme.robomme_env.utils import *
17
+ from robomme.env_record_wrapper import BenchmarkEnvBuilder
18
+ from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE
19
+
20
# --- Config ---
GUI_RENDER = False  # GUI viewer disabled by default
REPLAY_VIDEO_DIR = "replay_videos"  # output directory for saved replay videos
VIDEO_FPS = 30  # frame rate of the saved mp4
MAX_STEPS = 1000  # per-episode environment step budget
25
+
26
+
27
+ def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray:
28
+ """Build a single side-by-side frame from front and wrist camera obs."""
29
+ front = obs["front_camera"][0].cpu().numpy()
30
+ wrist = obs["wrist_camera"][0].cpu().numpy()
31
+ frame = np.concatenate([front, wrist], axis=1).astype(np.uint8)
32
+ if is_video_frame:
33
+ frame = cv2.rectangle(
34
+ frame, (0, 0), (frame.shape[1], frame.shape[0]), (255, 0, 0), 10
35
+ )
36
+ return frame
37
+
38
+
39
+ def _first_execution_step(episode_data) -> int:
40
+ """Return the first step index that is not a video-demo step."""
41
+ step_idx = 0
42
+ while episode_data[f"timestep_{step_idx}"]["info"]["is_video_demo"][()]:
43
+ step_idx += 1
44
+ return step_idx
45
+
46
+
47
def process_episode(env_data: h5py.File, episode_idx: int, env_id: str) -> None:
    """Replay one episode from HDF5 data, record frames, and save a video.

    Args:
        env_data: Open HDF5 file containing ``episode_<idx>`` groups.
        episode_idx: Index of the episode group to replay.
        env_id: Benchmark task identifier used to build the environment.
    """
    episode_data = env_data[f"episode_{episode_idx}"]
    task_goal = episode_data["setup"]["task_goal"][()].decode()
    total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_"))

    step_idx = _first_execution_step(episode_data)
    print(f"Execution start step index: {step_idx}")

    env_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="test",
        action_space=EE_POSE_ACTION_SPACE,
        gui_render=GUI_RENDER,
    )
    env = env_builder.make_env_for_episode(
        episode_idx,
        max_steps=MAX_STEPS,
        include_maniskill_obs=True,
        include_front_depth=True,
        include_wrist_depth=True,
        include_front_camera_extrinsic=True,
        include_wrist_camera_extrinsic=True,
        include_available_multi_choices=True,
        include_front_camera_intrinsic=True,
        include_wrist_camera_intrinsic=True,
    )
    print(f"task_name: {env_id}, episode_idx: {episode_idx}, task_goal: {task_goal}")

    obs, info = env.reset()
    # Obs lists: length 1 = no video, length > 1 = video; last element is current.
    frames = []
    n_obs = len(obs["front_camera"])
    for i in range(n_obs):
        single_obs = {k: [v[i]] for k, v in obs.items()}
        frames.append(_frame_from_obs(single_obs, is_video_frame=(i < n_obs - 1)))
    print(f"Initial frames (video + current): {len(frames)}")

    outcome = "unknown"
    try:
        while step_idx < total_steps:
            # Read the recorded EE pose action for this step from HDF5.
            action = np.asarray(
                episode_data[f"timestep_{step_idx}"]["action"]["eef_action"][()],
                dtype=np.float32,
            )
            obs, _, terminated, _, info = env.step(action)
            frames.append(_frame_from_obs(obs))

            if GUI_RENDER:
                env.render()

            # TODO: hongze makes this correct
            # there are too many nested lists here, need to flatten them
            # NOTE(review): indexing [-1][-1] on the False default raises
            # TypeError when the key is absent; assumes nested [[...]] values
            # when present — confirm against env.step()'s info contract.
            if terminated:
                if info.get("success", False)[-1][-1]:
                    outcome = "success"
                if info.get("fail", False)[-1][-1]:
                    outcome = "fail"
                break
            step_idx += 1
    finally:
        # Always release simulator resources, even on mid-episode failure.
        env.close()

    safe_goal = task_goal.replace(" ", "_").replace("/", "_")
    os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True)
    video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4"
    video_path = os.path.join(REPLAY_VIDEO_DIR, video_name)
    imageio.mimsave(video_path, frames, fps=VIDEO_FPS)
    print(f"Saved video to {video_path}")
+
117
+
118
def replay(h5_data_dir: str = "/data/hongzefu/dataset_generate") -> None:
    """Replay all episodes from all task HDF5 files in the given directory.

    Args:
        h5_data_dir: Directory containing record_dataset_<Task>.h5 files.
    """
    # NOTE(review): the get_task_list() result is immediately overwritten by
    # the hard-coded debug list below — the call is effectively dead code.
    env_id_list = BenchmarkEnvBuilder.get_task_list()
    # Debug filter: uncomment entries to replay additional tasks.
    env_id_list =[
        "PickXtimes",
        # "StopCube",
        # "SwingXtimes",
        # "BinFill",

        # "VideoUnmaskSwap",
        # "VideoUnmask",
        # "ButtonUnmaskSwap",
        # "ButtonUnmask",

        # "VideoRepick",
        # "VideoPlaceButton",
        # "VideoPlaceOrder",
        # "PickHighlight",

        # "InsertPeg",
        # 'MoveCube',
        # "PatternLock",
        # "RouteStick"
    ]

    for env_id in env_id_list:
        file_name = f"record_dataset_{env_id}.h5"
        file_path = os.path.join(h5_data_dir, file_name)
        if not os.path.exists(file_path):
            print(f"Skipping {env_id}: file not found: {file_path}")
            continue

        with h5py.File(file_path, "r") as data:
            episode_indices = sorted(
                int(k.split("_")[1])
                for k in data.keys()
                if k.startswith("episode_")
            )
            print(f"Task: {env_id}, has {len(episode_indices)} episodes")
            # NOTE(review): [:1] deliberately limits the run to the first
            # episode of each task (dev/debug script).
            for episode_idx in episode_indices[:1]:
                process_episode(data, episode_idx, env_id)
+
160
+
161
if __name__ == "__main__":
    import tyro
    # Expose replay()'s keyword arguments as CLI flags via tyro.
    tyro.cli(replay)
scripts/dev/eval-dataset-offline-rpy.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import h5py
5
+ import numpy as np
6
+ import argparse
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from robomme.robomme_env.utils.rpy_util import summarize_and_print_rpy_sequence
11
+
12
+ def _write_split_rpy_summaries_json(
13
+ path: str,
14
+ demo_summaries: list[dict[str, Any]],
15
+ non_demo_summaries: list[dict[str, Any]],
16
+ ) -> None:
17
+ """
18
+ Summarize both demo and non-demo parts and write to JSON.
19
+ """
20
+ if os.path.dirname(path):
21
+ os.makedirs(os.path.dirname(path), exist_ok=True)
22
+ payload = {
23
+ "demo_summaries": demo_summaries,
24
+ "non_demo_summaries": non_demo_summaries,
25
+ }
26
+ with open(path, "w", encoding="utf-8") as f:
27
+ json.dump(payload, f, ensure_ascii=False, indent=2)
28
+
29
+
30
def _read_is_video_demo(ts_group: h5py.Group) -> bool:
    """Return the info/is_video_demo flag of a timestep group (False if absent)."""
    info_grp = ts_group.get("info")
    if info_grp is None or "is_video_demo" not in info_grp:
        return False
    val = info_grp["is_video_demo"][()]
    # The flag may be stored as a byte string (e.g. b"True") or a bool/int scalar.
    if isinstance(val, (bytes, np.bytes_)):
        return val in (b"True", b"true", b"1")
    return bool(val)
39
+
40
+
41
def _extract_rpy_from_timestep(ts_group: h5py.Group) -> list[np.ndarray]:
    """Return the timestep's action RPY rows as a list of float64 3-vectors.

    Reads action/eef_action_raw/rpy; returns [] when any level of the path is
    missing or when the trailing dimension is not 3.
    """
    if "action" not in ts_group:
        return []
    action_grp = ts_group["action"]
    if "eef_action_raw" not in action_grp or "rpy" not in action_grp["eef_action_raw"]:
        return []
    raw = np.asarray(action_grp["eef_action_raw"]["rpy"][()], dtype=np.float64)
    # Normalize to shape (n, dim) regardless of stored rank.
    rows = raw.reshape(1, -1) if raw.ndim == 1 else raw.reshape(-1, raw.shape[-1])
    if rows.shape[-1] != 3:
        return []
    return [row.copy() for row in rows]
57
+
58
+
59
def main():
    """Scan generated HDF5 datasets, split each episode's RPY sequence into
    demo / non-demo parts (per the info/is_video_demo flag), print summaries,
    and dump one ``*_rpy_summary.json`` per dataset file."""
    # Hardcoded dataset directory as requested
    DATASET_DIR = Path("/data/hongzefu/dataset_generate")

    parser = argparse.ArgumentParser(description="Read generated HDF5 dataset and verify RPY consistency.")
    parser.add_argument("--dataset_path", type=str, default=str(DATASET_DIR), help="Path to the HDF5 file or directory to verify.")
    args = parser.parse_args()

    input_path = Path(args.dataset_path).resolve()

    if not input_path.exists():
        print(f"Error: Path not found: {input_path}")
        sys.exit(1)

    # Determine files to process: a single .h5/.hdf5 file, or every such file
    # in a directory (non-recursive, sorted for deterministic order).
    files_to_process = []
    if input_path.is_file():
        if input_path.suffix in ['.h5', '.hdf5']:
            files_to_process.append(input_path)
    elif input_path.is_dir():
        files_to_process.extend(sorted(input_path.glob("*.h5")))
        files_to_process.extend(sorted(input_path.glob("*.hdf5")))

    if not files_to_process:
        print(f"No HDF5 files found in {input_path}")
        sys.exit(0)

    print(f"Found {len(files_to_process)} files to process in {input_path}")

    for dataset_path in files_to_process:
        print(f"\n{'='*50}")
        print(f"Processing dataset: {dataset_path}")
        print(f"{'='*50}")

        # Generate output JSON path
        # NOTE(review): output directory is hardcoded and unrelated to
        # --dataset_path; consider parameterizing.
        output_json_path = Path("/data/hongzefu/dataset_replay") / f"{dataset_path.stem}_rpy_summary.json"

        demo_summaries: list[dict[str, Any]] = []
        non_demo_summaries: list[dict[str, Any]] = []

        try:
            with h5py.File(dataset_path, "r") as f:
                # Iterate through environments (e.g., env_PickXtimes...)
                env_groups = [key for key in f.keys() if key.startswith("env_")]
                env_groups.sort()

                if not env_groups:
                    print(f"Warning: No 'env_*' groups found in {dataset_path.name}")

                for env_group_name in env_groups:
                    env_group = f[env_group_name]
                    print(f"Processing environment group: {env_group_name}")

                    # Extract env_id from group name (remove 'env_' prefix)
                    env_id = env_group_name[4:]

                    # Iterate through episodes
                    episode_keys = [key for key in env_group.keys() if key.startswith("episode_")]
                    # Sort numerically by episode ID (falls back to lexical key
                    # for malformed names)
                    episode_keys.sort(key=lambda x: int(x.split('_')[1]) if '_' in x and x.split('_')[1].isdigit() else x)

                    for episode_key in episode_keys:
                        print(f"  Processing {episode_key}...")
                        episode_group = env_group[episode_key]
                        try:
                            episode_idx = int(episode_key.split('_')[1])
                        except (IndexError, ValueError):
                            episode_idx = -1

                        # Iterate through timesteps to reconstruct sequence
                        timestep_keys = [key for key in episode_group.keys() if key.startswith("record_timestep_")]

                        # Sort keys like "record_timestep_<n>" numerically;
                        # malformed keys sort first via -1.
                        def get_timestep_idx(key):
                            parts = key.split('_')
                            try:
                                return int(parts[2])
                            except (IndexError, ValueError):
                                return -1

                        timestep_keys.sort(key=get_timestep_idx)

                        # Separate RPY sequences by is_video_demo flag
                        demo_rpy_seq: list[np.ndarray] = []
                        non_demo_rpy_seq: list[np.ndarray] = []

                        for ts_key in timestep_keys:
                            ts_group = episode_group[ts_key]
                            rpy_rows = _extract_rpy_from_timestep(ts_group)
                            if rpy_rows:
                                if _read_is_video_demo(ts_group):
                                    demo_rpy_seq.extend(rpy_rows)
                                else:
                                    non_demo_rpy_seq.extend(rpy_rows)

                        # Summarize demo portion
                        if demo_rpy_seq:
                            demo_summary = summarize_and_print_rpy_sequence(
                                demo_rpy_seq,
                                label=f"[{env_id}] episode {episode_idx} (demo)",
                            )
                            demo_summaries.append({
                                "order_index": len(demo_summaries),
                                "env_id": env_id,
                                "episode": episode_idx,
                                "action_space": "eef_pose",
                                "summary": demo_summary,
                            })

                        # Summarize non-demo portion
                        if non_demo_rpy_seq:
                            non_demo_summary = summarize_and_print_rpy_sequence(
                                non_demo_rpy_seq,
                                label=f"[{env_id}] episode {episode_idx} (non-demo)",
                            )
                            non_demo_summaries.append({
                                "order_index": len(non_demo_summaries),
                                "env_id": env_id,
                                "episode": episode_idx,
                                "action_space": "eef_pose",
                                "summary": non_demo_summary,
                            })

        # Best-effort per file: log the failure and still write whatever
        # summaries were collected before the error.
        except Exception as e:
            print(f"An error occurred while reading {dataset_path.name}: {e}")
            import traceback
            traceback.print_exc()

        # Write summary to JSON
        if demo_summaries or non_demo_summaries:
            _write_split_rpy_summaries_json(str(output_json_path), demo_summaries, non_demo_summaries)
            print(f"Saved split RPY summaries to: {output_json_path}")
            print(f"  demo entries: {len(demo_summaries)}, non-demo entries: {len(non_demo_summaries)}")
        else:
            print(f"No summaries generated for {dataset_path.name}")

if __name__ == "__main__":
    main()
scripts/dev/eval_dataset_replay.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Script function: Unified dataset replay entry point, supporting 4 action spaces: joint_angle / ee_pose / waypoint / multi_choice.
3
+ # Consistent with subgoal_evaluate_func.py main loop; difference is actions come from EpisodeDatasetResolver.
4
+
5
+ import os
6
+ from typing import Any, Optional
7
+
8
+
9
+
10
+ import os
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
12
+
13
+
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+
19
+ from robomme.robomme_env import *
20
+ from robomme.robomme_env.utils import *
21
+ from robomme.env_record_wrapper import (
22
+ BenchmarkEnvBuilder,
23
+ EpisodeDatasetResolver,
24
+ )
25
+ from robomme.env_record_wrapper.OraclePlannerDemonstrationWrapper import (
26
+ OraclePlannerDemonstrationWrapper,
27
+ )
28
+ from robomme.robomme_env.utils.choice_action_mapping import (
29
+ _unique_candidates,
30
+ extract_actor_position_xyz,
31
+ project_world_to_pixel,
32
+ select_target_with_pixel,
33
+ )
34
+ from robomme.robomme_env.utils.save_reset_video import save_robomme_video
35
+
36
# Only enable one ACTION_SPACE; others are commented out for manual switching
ACTION_SPACE = "joint_angle"


# Whether to open the simulator's GUI viewer each step.
GUI_RENDER = False

# Root directory of the recorded HDF5 dataset consumed by EpisodeDatasetResolver.
DATASET_ROOT = "/data/hongzefu/data_0226"

# Tasks to replay; uncomment the ids you want active.
DEFAULT_ENV_IDS = [
    #"PickXtimes",
    #"StopCube",
    #"SwingXtimes",
    "BinFill",
    # "VideoUnmaskSwap",
    # "VideoUnmask",
    # "ButtonUnmaskSwap",
    # "ButtonUnmask",
    # "VideoRepick",
    # "VideoPlaceButton",
    # "VideoPlaceOrder",
    #"PickHighlight",
    #"InsertPeg",
    #"MoveCube",
    #"PatternLock",
    #"RouteStick",
]

# Output directory for replay videos.
OUT_VIDEO_DIR = "/data/hongzefu/dataset_replay"
# Per-episode step budget passed to make_env_for_episode.
MAX_STEPS = 1000
65
+
66
+
67
+ def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
68
+ if not isinstance(choice_action, dict):
69
+ return None
70
+ choice = choice_action.get("choice")
71
+ if not isinstance(choice, str) or not choice.strip():
72
+ return None
73
+ point = choice_action.get("point")
74
+ if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2:
75
+ return None
76
+ return choice_action
77
+
78
+
79
+ def _to_numpy_copy(value: Any) -> np.ndarray:
80
+ if isinstance(value, torch.Tensor):
81
+ value = value.detach().cpu().numpy()
82
+ else:
83
+ value = np.asarray(value)
84
+ return np.array(value, copy=True)
85
+
86
+
87
+ def _to_frame_list(frames_like: Any) -> list[np.ndarray]:
88
+ if frames_like is None:
89
+ return []
90
+ if isinstance(frames_like, torch.Tensor):
91
+ arr = frames_like.detach().cpu().numpy()
92
+ if arr.ndim == 3:
93
+ return [np.array(arr, copy=True)]
94
+ if arr.ndim == 4:
95
+ return [np.array(x, copy=True) for x in arr]
96
+ return []
97
+ if isinstance(frames_like, np.ndarray):
98
+ if frames_like.ndim == 3:
99
+ return [np.array(frames_like, copy=True)]
100
+ if frames_like.ndim == 4:
101
+ return [np.array(x, copy=True) for x in frames_like]
102
+ return []
103
+ if isinstance(frames_like, (list, tuple)):
104
+ out = []
105
+ for frame in frames_like:
106
+ if frame is None:
107
+ continue
108
+ out.append(_to_numpy_copy(frame))
109
+ return out
110
+ try:
111
+ arr = np.asarray(frames_like)
112
+ except Exception:
113
+ return []
114
+ if arr.ndim == 3:
115
+ return [np.array(arr, copy=True)]
116
+ if arr.ndim == 4:
117
+ return [np.array(x, copy=True) for x in arr]
118
+ return []
119
+
120
+
121
+ def _normalize_pixel_xy(pixel_like: Any) -> Optional[list[int]]:
122
+ if not isinstance(pixel_like, (list, tuple, np.ndarray)):
123
+ return None
124
+ if len(pixel_like) < 2:
125
+ return None
126
+ try:
127
+ x = float(pixel_like[0])
128
+ y = float(pixel_like[1])
129
+ except (TypeError, ValueError):
130
+ return None
131
+ if not np.isfinite(x) or not np.isfinite(y):
132
+ return None
133
+ return [int(np.rint(x)), int(np.rint(y))]
134
+
135
+
136
+ def _normalize_point_yx_to_pixel_xy(point_like: Any) -> Optional[list[int]]:
137
+ if not isinstance(point_like, (list, tuple, np.ndarray)):
138
+ return None
139
+ if len(point_like) < 2:
140
+ return None
141
+ try:
142
+ y = float(point_like[0])
143
+ x = float(point_like[1])
144
+ except (TypeError, ValueError):
145
+ return None
146
+ if not np.isfinite(x) or not np.isfinite(y):
147
+ return None
148
+ return [int(np.rint(x)), int(np.rint(y))]
149
+
150
+
151
def _find_oracle_wrapper(env_like: Any) -> Optional[OraclePlannerDemonstrationWrapper]:
    """Walk the ``.env`` wrapper chain (max 16 hops, cycle-safe) looking for the
    oracle demonstration wrapper; return it, or None if not found."""
    seen: set[int] = set()
    node = env_like
    hops = 0
    while hops < 16:
        if node is None:
            return None
        if isinstance(node, OraclePlannerDemonstrationWrapper):
            return node
        marker = id(node)
        if marker in seen:
            # Cycle in the wrapper chain — give up.
            return None
        seen.add(marker)
        node = getattr(node, "env", None)
        hops += 1
    return None
165
+
166
+
167
def _collect_multi_choice_visualization(
    env_like: Any,
    command: dict[str, Any],
) -> tuple[list[list[int]], Optional[list[int]], Optional[list[int]]]:
    """Gather debug pixels for a multi-choice command.

    Returns ``(candidate_pixels, clicked_pixel, matched_pixel)``:
    - candidate_pixels: [x, y] front-camera projections of every unique
      candidate actor of the option resolved from *command* (empty on failure).
    - clicked_pixel: the command's (y, x) "point" converted to [x, y], or None.
    - matched_pixel: projected pixel of the candidate the click matched, or None.

    Best-effort: any failure inside the oracle wrapper degrades to
    empty/None values instead of raising.
    """
    # The command stores its point in (y, x) order; convert to (x, y) pixels.
    clicked_pixel = _normalize_point_yx_to_pixel_xy(command.get("point"))
    oracle_wrapper = _find_oracle_wrapper(env_like)
    if oracle_wrapper is None:
        return [], clicked_pixel, None

    try:
        # NOTE(review): relies on private wrapper internals
        # (_build_step_options / _resolve_command); fails silently to
        # empty results if that API changes.
        _selected_target, solve_options = oracle_wrapper._build_step_options()
        found_idx, _ = oracle_wrapper._resolve_command(command, solve_options)
    except Exception:
        return [], clicked_pixel, None

    if found_idx is None or found_idx < 0 or found_idx >= len(solve_options):
        return [], clicked_pixel, None

    option = solve_options[found_idx]
    available = option.get("available")
    # Camera parameters cached on the wrapper; the projection helpers
    # tolerate None values.
    intrinsic_cv = getattr(oracle_wrapper, "_front_camera_intrinsic_cv", None)
    extrinsic_cv = getattr(oracle_wrapper, "_front_camera_extrinsic_cv", None)
    image_shape = getattr(oracle_wrapper, "_front_rgb_shape", None)

    # Project every unique candidate actor into front-camera pixel space.
    candidate_pixels: list[list[int]] = []
    if available is not None:
        for actor in _unique_candidates(available):
            actor_pos = extract_actor_position_xyz(actor)
            if actor_pos is None:
                continue
            projected = project_world_to_pixel(
                actor_pos,
                intrinsic_cv=intrinsic_cv,
                extrinsic_cv=extrinsic_cv,
                image_shape=image_shape,
            )
            if projected is None:
                continue
            candidate_pixels.append([int(projected[0]), int(projected[1])])

    # Re-run the pixel-based target selection to find which candidate the
    # clicked pixel actually matched.
    matched_pixel: Optional[list[int]] = None
    if available is not None and clicked_pixel is not None:
        matched = select_target_with_pixel(
            available=available,
            pixel_like=clicked_pixel,
            intrinsic_cv=intrinsic_cv,
            extrinsic_cv=extrinsic_cv,
            image_shape=image_shape,
        )
        if isinstance(matched, dict):
            matched_pixel = _normalize_pixel_xy(matched.get("projected_pixel"))

    return candidate_pixels, clicked_pixel, matched_pixel
220
+
221
+
222
def _make_blackboard(frame_like: Any) -> np.ndarray:
    """Return an all-black uint8 canvas matching *frame_like*'s height/width.

    Falls back to a 1x1 canvas when the frame has fewer than two dimensions
    or a degenerate size.
    """
    frame = _to_numpy_copy(frame_like)
    shape_ok = frame.ndim >= 2 and int(frame.shape[0]) > 0 and int(frame.shape[1]) > 0
    target = (int(frame.shape[0]), int(frame.shape[1]), 3) if shape_ok else (1, 1, 3)
    return np.zeros(target, dtype=np.uint8)
230
+
231
+
232
def _draw_candidate_blackboard(
    frame_like: Any,
    candidate_pixels: list[list[int]],
) -> np.ndarray:
    """Render yellow candidate circles on a black canvas sized like the frame."""
    board = _make_blackboard(frame_like)
    for px in candidate_pixels:
        if len(px) >= 2:
            cv2.circle(board, (int(px[0]), int(px[1])), 4, (0, 255, 255), 1)
    return board
242
+
243
+
244
def _draw_selection_blackboard(
    frame_like: Any,
    clicked_pixel: Optional[list[int]],
    matched_pixel: Optional[list[int]],
) -> np.ndarray:
    """Draw the clicked pixel (cyan cross) and matched candidate (blue circle)
    on a black canvas sized like the frame; either marker may be omitted."""
    board = _make_blackboard(frame_like)
    if clicked_pixel is not None:
        click_xy = (int(clicked_pixel[0]), int(clicked_pixel[1]))
        cv2.drawMarker(
            board,
            click_xy,
            (255, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=10,
            thickness=1,
        )
    if matched_pixel is not None:
        match_xy = (int(matched_pixel[0]), int(matched_pixel[1]))
        cv2.circle(board, match_xy, 5, (255, 0, 0), 2)
    return board
262
+
263
+
264
+
265
+
266
def main():
    """Replay recorded actions for each configured task/episode through the
    benchmark env and save side-by-side debug videos of the rollout."""
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    env_id_list = BenchmarkEnvBuilder.get_task_list()
    print(f"Running envs: {env_id_list}")
    print(f"Using action_space: {ACTION_SPACE}")

    #for env_id in env_id_list:
    for env_id in DEFAULT_ENV_IDS:
        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=ACTION_SPACE,
            gui_render=GUI_RENDER,
        )
        episode_count = env_builder.get_episode_num()
        print(f"[{env_id}] episode_count from metadata: {episode_count}")

        env = None
        for episode in range(episode_count):
            # NOTE(review): debug leftover — only episode 15 is replayed.
            if episode !=15:
                continue

            # NOTE(review): a new env is created per episode but only the last
            # one is closed (see the close at the end of the env_id loop) —
            # earlier envs appear to leak; confirm.
            env = env_builder.make_env_for_episode(
                episode,
                max_steps=MAX_STEPS,
                include_maniskill_obs=True,
                include_front_depth=True,
                include_wrist_depth=True,
                include_front_camera_extrinsic=True,
                include_wrist_camera_extrinsic=True,
                include_available_multi_choices=True,
                include_front_camera_intrinsic=True,
                include_wrist_camera_intrinsic=True,
            )
            try:
                dataset_resolver = EpisodeDatasetResolver(
                    env_id=env_id,
                    episode=episode,
                    dataset_directory=DATASET_ROOT,
                )
            except KeyError as e:
                print(f"[{env_id}] Episode {episode} missing in H5, skipping. ({e})")
                if env is not None:
                    env.close()
                continue

            # ======== Reset ========
            # obs: dict-of-lists (columnar batch, list length = number of demo frames)
            # info: flat dict (last frame values only)
            obs, info = env.reset()

            # --- Explicitly read all obs fields (each is a list) ---
            maniskill_obs = obs["maniskill_obs"]
            front_rgb_list = _to_frame_list(obs["front_rgb_list"])
            wrist_rgb_list = _to_frame_list(obs["wrist_rgb_list"])
            front_depth_list = obs["front_depth_list"]
            wrist_depth_list = obs["wrist_depth_list"]
            end_effector_pose_raw = obs["end_effector_pose_raw"]
            eef_state_list = obs["eef_state_list"]
            joint_state_list = obs["joint_state_list"]
            # velocity = obs["velocity"]
            gripper_state_list = obs["gripper_state_list"]
            front_camera_extrinsic_list = obs["front_camera_extrinsic_list"]
            wrist_camera_extrinsic_list = obs["wrist_camera_extrinsic_list"]

            # --- Explicitly read all info fields (flat dict, last frame values) ---
            task_goal = info["task_goal"]
            simple_subgoal_online = info["simple_subgoal_online"]
            grounded_subgoal_online = info["grounded_subgoal_online"]
            available_multi_choices = info.get("available_multi_choices")
            front_camera_intrinsic = info["front_camera_intrinsic"]
            wrist_camera_intrinsic = info["wrist_camera_intrinsic"]
            status = info.get("status")


            # --- Video saving variable preparation (reset phase) ---
            reset_base_frames = [_to_numpy_copy(f) for f in front_rgb_list]
            reset_wrist_frames = [_to_numpy_copy(f) for f in wrist_rgb_list]
            # Right-hand debug panels only exist for multi_choice replays.
            reset_right_frames = (
                [_make_blackboard(f) for f in reset_base_frames]
                if ACTION_SPACE == "multi_choice"
                else None
            )
            reset_far_right_frames = (
                [_make_blackboard(f) for f in reset_base_frames]
                if ACTION_SPACE == "multi_choice"
                else None
            )
            reset_subgoal_grounded = [grounded_subgoal_online] * len(front_rgb_list)

            step = 0
            episode_success = False
            rollout_base_frames: list[np.ndarray] = []
            rollout_wrist_frames: list[np.ndarray] = []
            rollout_right_frames: list[np.ndarray] = []
            rollout_far_right_frames: list[np.ndarray] = []
            rollout_subgoal_grounded: list[Any] = []

            # ======== Step loop ========
            while True:
                replay_key = ACTION_SPACE
                action = dataset_resolver.get_step(replay_key, step)
                if ACTION_SPACE == "multi_choice":
                    action = _parse_oracle_command(action)
                # Stop when the dataset has no more recorded actions (or the
                # recorded multi-choice command is malformed).
                if action is None:
                    break

                candidate_pixels: list[list[int]] = []
                clicked_pixel: Optional[list[int]] = None
                matched_pixel: Optional[list[int]] = None
                if ACTION_SPACE == "multi_choice":
                    candidate_pixels, clicked_pixel, matched_pixel = _collect_multi_choice_visualization(
                        env, action
                    )

                # step returns: obs (dict-of-lists), reward (scalar tensor),
                # terminated (scalar tensor), truncated (scalar tensor), info (flat dict)
                obs, reward, terminated, truncated, info = env.step(action)

                # --- Explicitly read all obs fields (dict-of-lists, typically 1 element per list) ---
                # NOTE(review): the assignment `maniskill_obs = obs["maniskill_obs"]`
                # appears accidentally fused into this comment in the original,
                # so maniskill_obs is never refreshed after step — confirm.
                front_rgb_list = _to_frame_list(obs["front_rgb_list"])
                wrist_rgb_list = _to_frame_list(obs["wrist_rgb_list"])
                front_depth_list = obs["front_depth_list"]
                wrist_depth_list = obs["wrist_depth_list"]
                end_effector_pose_raw = obs["end_effector_pose_raw"]
                eef_state_list = obs["eef_state_list"]
                joint_state_list = obs["joint_state_list"]
                gripper_state_list = obs["gripper_state_list"]
                front_camera_extrinsic_list = obs["front_camera_extrinsic_list"]
                wrist_camera_extrinsic_list = obs["wrist_camera_extrinsic_list"]

                # --- Explicitly read all info fields (flat dict) ---
                task_goal = info["task_goal"]
                simple_subgoal_online = info["simple_subgoal_online"]
                grounded_subgoal_online = info["grounded_subgoal_online"]
                available_multi_choices = info.get("available_multi_choices")
                front_camera_intrinsic = info["front_camera_intrinsic"]
                wrist_camera_intrinsic = info["wrist_camera_intrinsic"]
                status = info.get("status")

                # --- Video saving variable preparation (replay phase) ---
                rollout_base_frames.extend(
                    _to_numpy_copy(f) for f in front_rgb_list
                )
                rollout_wrist_frames.extend(
                    _to_numpy_copy(f) for f in wrist_rgb_list
                )
                if ACTION_SPACE == "multi_choice":
                    for base_frame in front_rgb_list:
                        rollout_right_frames.append(
                            _draw_candidate_blackboard(
                                base_frame,
                                candidate_pixels=candidate_pixels,
                            )
                        )
                        rollout_far_right_frames.append(
                            _draw_selection_blackboard(
                                base_frame,
                                clicked_pixel=clicked_pixel,
                                matched_pixel=matched_pixel,
                            )
                        )
                rollout_subgoal_grounded.extend([grounded_subgoal_online] * len(front_rgb_list))

                terminated_flag = bool(terminated.item())
                truncated_flag = bool(truncated.item())

                step += 1
                if GUI_RENDER:
                    env.render()
                if truncated_flag:
                    print(f"[{env_id}] episode {episode} steps exceeded, step {step}.")
                    break
                if terminated_flag:
                    if status == "success":
                        print(f"[{env_id}] episode {episode} success.")
                        episode_success = True
                    elif status == "fail":
                        print(f"[{env_id}] episode {episode} failed.")
                    break

            # ======== Video saving ========
            save_robomme_video(
                reset_base_frames=reset_base_frames,
                reset_wrist_frames=reset_wrist_frames,
                rollout_base_frames=rollout_base_frames,
                rollout_wrist_frames=rollout_wrist_frames,
                reset_subgoal_grounded=reset_subgoal_grounded,
                rollout_subgoal_grounded=rollout_subgoal_grounded,
                out_video_dir=OUT_VIDEO_DIR,
                action_space=ACTION_SPACE,
                env_id=env_id,
                episode=episode,
                episode_success=episode_success,
                reset_right_frames=reset_right_frames if ACTION_SPACE == "multi_choice" else None,
                rollout_right_frames=rollout_right_frames if ACTION_SPACE == "multi_choice" else None,
                reset_far_right_frames=(
                    reset_far_right_frames if ACTION_SPACE == "multi_choice" else None
                ),
                rollout_far_right_frames=(
                    rollout_far_right_frames if ACTION_SPACE == "multi_choice" else None
                ),
            )

        if env is not None:
            env.close()


if __name__ == "__main__":
    main()
scripts/dev/evaluate_dataset_replay-parallelv3.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Script function: Unified dataset replay entry point, supports four action_spaces: joint_angle / ee_pose / waypoint / multi_choice.
3
+ # Consistent with subgoal_evaluate_func.py's main loop and debug fields; the difference is that actions come from EpisodeDatasetResolver.
4
+ # [New] Support parallel multi-process replay and alternate task assignment between two GPUs.
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import concurrent.futures
10
+ import multiprocessing as mp
11
+ from typing import Any, Optional
12
+
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+
17
+ from robomme.robomme_env import *
18
+ from robomme.robomme_env.utils import *
19
+ from robomme.env_record_wrapper import (
20
+ BenchmarkEnvBuilder,
21
+ EpisodeDatasetResolver,
22
+ )
23
+ from robomme.env_record_wrapper.OraclePlannerDemonstrationWrapper import (
24
+ OraclePlannerDemonstrationWrapper,
25
+ )
26
+ from robomme.robomme_env.utils.choice_action_mapping import (
27
+ _unique_candidates,
28
+ extract_actor_position_xyz,
29
+ project_world_to_pixel,
30
+ select_target_with_pixel,
31
+ )
32
+ from robomme.robomme_env.utils.save_reset_video import save_robomme_video
33
+
34
# The four action spaces the replay entry point supports.
AVAILABLE_ACTION_SPACES = [
    "joint_angle",
    "ee_pose",
    "waypoint",
    "multi_choice",
]

# Whether to open the simulator's GUI viewer each step.
GUI_RENDER = False

# Root directory of the recorded HDF5 dataset consumed by EpisodeDatasetResolver.
DATASET_ROOT = "/data/hongzefu/data_0226-test"
OVERRIDE_METADATA_PATH = "/data/hongzefu/data_0226-test"

# ######## Video saving variables (output directory) start ########
# Video output directory: Independently hardcoded, not aligned with h5 path or env_id
OUT_VIDEO_DIR = "/data/hongzefu/dataset_replay-0226-test"
# ######## Video saving variables (output directory) end ########
# Per-episode step budget passed to the env builder.
MAX_STEPS = 2000

# Tasks to replay; uncomment the ids you want active.
DEFAULT_ENV_IDS = [
    # "PickXtimes",
    # "StopCube",
    # "SwingXtimes",
    # "BinFill",
    # "VideoUnmaskSwap",
    # "VideoUnmask",
    # "ButtonUnmaskSwap",
    # "ButtonUnmask",
    # "VideoRepick",
    # "VideoPlaceButton",
    # "VideoPlaceOrder",
    # "PickHighlight",
    # "InsertPeg",
    # "MoveCube",
    "PatternLock",
    # "RouteStick",
]
70
+
71
+ def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
72
+ if not isinstance(choice_action, dict):
73
+ return None
74
+ choice = choice_action.get("choice")
75
+ if not isinstance(choice, str) or not choice.strip():
76
+ return None
77
+ if "point" not in choice_action:
78
+ return None
79
+ return {
80
+ "choice": choice_action.get("choice"),
81
+ "point": choice_action.get("point"),
82
+ }
83
+
84
+
85
+ def _to_numpy_copy(value: Any) -> np.ndarray:
86
+ if isinstance(value, torch.Tensor):
87
+ value = value.detach().cpu().numpy()
88
+ else:
89
+ value = np.asarray(value)
90
+ return np.array(value, copy=True)
91
+
92
+
93
+ def _to_frame_list(frames_like: Any) -> list[np.ndarray]:
94
+ if frames_like is None:
95
+ return []
96
+ if isinstance(frames_like, torch.Tensor):
97
+ arr = frames_like.detach().cpu().numpy()
98
+ if arr.ndim == 3:
99
+ return [np.array(arr, copy=True)]
100
+ if arr.ndim == 4:
101
+ return [np.array(x, copy=True) for x in arr]
102
+ return []
103
+ if isinstance(frames_like, np.ndarray):
104
+ if frames_like.ndim == 3:
105
+ return [np.array(frames_like, copy=True)]
106
+ if frames_like.ndim == 4:
107
+ return [np.array(x, copy=True) for x in frames_like]
108
+ return []
109
+ if isinstance(frames_like, (list, tuple)):
110
+ out = []
111
+ for frame in frames_like:
112
+ if frame is None:
113
+ continue
114
+ out.append(_to_numpy_copy(frame))
115
+ return out
116
+ try:
117
+ arr = np.asarray(frames_like)
118
+ except Exception:
119
+ return []
120
+ if arr.ndim == 3:
121
+ return [np.array(arr, copy=True)]
122
+ if arr.ndim == 4:
123
+ return [np.array(x, copy=True) for x in arr]
124
+ return []
125
+
126
+
127
+ def _normalize_pixel_xy(pixel_like: Any) -> Optional[list[int]]:
128
+ if not isinstance(pixel_like, (list, tuple, np.ndarray)):
129
+ return None
130
+ if len(pixel_like) < 2:
131
+ return None
132
+ try:
133
+ x = float(pixel_like[0])
134
+ y = float(pixel_like[1])
135
+ except (TypeError, ValueError):
136
+ return None
137
+ if not np.isfinite(x) or not np.isfinite(y):
138
+ return None
139
+ return [int(np.rint(x)), int(np.rint(y))]
140
+
141
+
142
+ def _normalize_point_yx_to_pixel_xy(point_like: Any) -> Optional[list[int]]:
143
+ if not isinstance(point_like, (list, tuple, np.ndarray)):
144
+ return None
145
+ if len(point_like) < 2:
146
+ return None
147
+ try:
148
+ y = float(point_like[0])
149
+ x = float(point_like[1])
150
+ except (TypeError, ValueError):
151
+ return None
152
+ if not np.isfinite(x) or not np.isfinite(y):
153
+ return None
154
+ return [int(np.rint(x)), int(np.rint(y))]
155
+
156
+
157
def _find_oracle_wrapper(env_like: Any) -> Optional[OraclePlannerDemonstrationWrapper]:
    """Search up to 16 levels of ``.env`` nesting for the oracle planner wrapper.

    Returns the wrapper instance, or None when absent or a cycle is detected.
    """
    seen_ids: set[int] = set()
    candidate = env_like
    for _depth in range(16):
        if candidate is None:
            return None
        if isinstance(candidate, OraclePlannerDemonstrationWrapper):
            return candidate
        marker = id(candidate)
        if marker in seen_ids:
            return None
        seen_ids.add(marker)
        candidate = getattr(candidate, "env", None)
    return None
171
+
172
+
173
def _collect_multi_choice_visualization(
    env_like: Any,
    command: dict[str, Any],
) -> tuple[list[list[int]], Optional[list[int]], Optional[list[int]]]:
    """Collect debug pixels for a multi-choice command (parallel-v3 variant).

    Returns ``(candidate_pixels, clicked_pixel, matched_pixel)`` where
    candidate_pixels are [x, y] front-camera projections of each unique
    candidate actor of the resolved option, clicked_pixel is the command's
    (y, x) "point" converted to [x, y] (or None), and matched_pixel is the
    projected pixel of the candidate matched to the click (or None).
    Degrades to empty/None results on any oracle-wrapper failure.
    """
    # Command points are stored (y, x); convert to (x, y) pixel order.
    clicked_pixel = _normalize_point_yx_to_pixel_xy(command.get("point"))
    oracle_wrapper = _find_oracle_wrapper(env_like)
    if oracle_wrapper is None:
        return [], clicked_pixel, None

    try:
        # NOTE(review): depends on private wrapper methods
        # (_build_step_options / _resolve_command).
        _selected_target, solve_options = oracle_wrapper._build_step_options()
        found_idx, _ = oracle_wrapper._resolve_command(command, solve_options)
    except Exception:
        return [], clicked_pixel, None

    if found_idx is None or found_idx < 0 or found_idx >= len(solve_options):
        return [], clicked_pixel, None

    option = solve_options[found_idx]
    available = option.get("available")
    # Cached camera parameters; projection helpers tolerate None.
    intrinsic_cv = getattr(oracle_wrapper, "_front_camera_intrinsic_cv", None)
    extrinsic_cv = getattr(oracle_wrapper, "_front_camera_extrinsic_cv", None)
    image_shape = getattr(oracle_wrapper, "_front_rgb_shape", None)

    # Project each unique candidate actor into pixel space.
    candidate_pixels: list[list[int]] = []
    if available is not None:
        for actor in _unique_candidates(available):
            actor_pos = extract_actor_position_xyz(actor)
            if actor_pos is None:
                continue
            projected = project_world_to_pixel(
                actor_pos,
                intrinsic_cv=intrinsic_cv,
                extrinsic_cv=extrinsic_cv,
                image_shape=image_shape,
            )
            if projected is None:
                continue
            candidate_pixels.append([int(projected[0]), int(projected[1])])

    # Determine which candidate the clicked pixel matched, if any.
    matched_pixel: Optional[list[int]] = None
    if available is not None and clicked_pixel is not None:
        matched = select_target_with_pixel(
            available=available,
            pixel_like=clicked_pixel,
            intrinsic_cv=intrinsic_cv,
            extrinsic_cv=extrinsic_cv,
            image_shape=image_shape,
        )
        if isinstance(matched, dict):
            matched_pixel = _normalize_pixel_xy(matched.get("projected_pixel"))

    return candidate_pixels, clicked_pixel, matched_pixel
226
+
227
+
228
def _make_blackboard(frame_like: Any) -> np.ndarray:
    """Return an all-black uint8 canvas (H, W, 3) matching the input frame's size."""
    frame = _to_numpy_copy(frame_like)
    if frame.ndim >= 2:
        height, width = int(frame.shape[0]), int(frame.shape[1])
        if height > 0 and width > 0:
            return np.zeros((height, width, 3), dtype=np.uint8)
    # Degenerate input: fall back to a minimal 1x1 canvas.
    return np.zeros((1, 1, 3), dtype=np.uint8)
236
+
237
+
238
def _draw_candidate_blackboard(
    frame_like: Any,
    candidate_pixels: list[list[int]],
) -> np.ndarray:
    """Render every candidate pixel as a small circle on a black canvas sized like the frame."""
    canvas = _make_blackboard(frame_like)
    for px in candidate_pixels:
        if len(px) >= 2:
            cv2.circle(canvas, (int(px[0]), int(px[1])), 4, (0, 255, 255), 1)
    return canvas
248
+
249
+
250
def _draw_selection_blackboard(
    frame_like: Any,
    clicked_pixel: Optional[list[int]],
    matched_pixel: Optional[list[int]],
) -> np.ndarray:
    """Draw the raw click (tilted cross) and the matched target (circle) on a black canvas."""
    canvas = _make_blackboard(frame_like)
    if clicked_pixel is not None:
        click_xy = (int(clicked_pixel[0]), int(clicked_pixel[1]))
        cv2.drawMarker(
            canvas,
            click_xy,
            (255, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=10,
            thickness=1,
        )
    if matched_pixel is not None:
        match_xy = (int(matched_pixel[0]), int(matched_pixel[1]))
        cv2.circle(canvas, match_xy, 5, (255, 0, 0), 2)
    return canvas
268
+
269
+
270
def init_worker(gpu_id: int):
    """
    Worker process initialization function, sets CUDA_VISIBLE_DEVICES.
    """
    # Configure logging inside the freshly spawned worker process.
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    # Pin this worker to a single GPU. NOTE(review): assumes CUDA has not been
    # initialized in this process before this point — confirm under the spawn
    # start method.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # print(f"[Worker] Initialized on GPU {gpu_id} (PID: {os.getpid()})")
278
+
279
def evaluate_episode(
    env_id: str,
    episode: int,
    dataset_root: str,
    override_metadata_path: str,
    action_space: str,
    out_video_dir: str,
    gui_render: bool
) -> str:
    """
    Evaluation logic for a single Episode.

    Rebuilds the env and dataset resolver inside the worker process, replays the
    recorded actions for `action_space` step by step, collects frames for the
    debug video, saves the video, and returns a one-line status string.
    Replay-phase errors are caught and converted into the returned message.
    """
    # Reconstruct Envs and Resolver (avoid passing complex objects across processes)
    env_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="train",
        action_space=action_space,
        gui_render=gui_render,
        override_metadata_path=override_metadata_path,
    )

    env = None
    dataset_resolver = None

    try:
        env = env_builder.make_env_for_episode(
            episode,
            max_steps=MAX_STEPS,
            include_maniskill_obs=True,
            include_front_depth=True,
            include_wrist_depth=True,
            include_front_camera_extrinsic=True,
            include_wrist_camera_extrinsic=True,
            include_available_multi_choices=True,
            include_front_camera_intrinsic=True,
            include_wrist_camera_intrinsic=True,
        )
        dataset_resolver = EpisodeDatasetResolver(
            env_id=env_id,
            episode=episode,
            dataset_directory=dataset_root,
        )

        # obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = env.reset()
        obs_batch, info_batch = env.reset()

        # Maintain debug variable semantics from subgoal_evaluate_func.py
        # Note: These local variables in multi-processing can be simplified if printing is not needed, but unpacking logic is retained for consistency.
        maniskill_obs = obs_batch["maniskill_obs"]
        front_camera = _to_frame_list(obs_batch["front_rgb_list"])
        wrist_camera = _to_frame_list(obs_batch["wrist_rgb_list"])
        # Other variables unpacking skipped unless used downstream

        task_goal_list = info_batch["task_goal"]
        # task_goal = task_goal_list[0] if task_goal_list else None

        # Keep only the latest entry of each (possibly batched) info field.
        info = {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
        # terminated = bool(terminated_batch[-1].item())
        # truncated = bool(truncated_batch[-1].item())

        # ######## Video saving variable preparation (reset phase) start ########
        reset_base_frames = [_to_numpy_copy(f) for f in front_camera]
        reset_wrist_frames = [_to_numpy_copy(f) for f in wrist_camera]
        # For multi_choice the right-hand panels start as empty blackboards
        # sized like the front-camera frames; other action spaces skip them.
        reset_right_frames = (
            [_make_blackboard(f) for f in reset_base_frames]
            if action_space == "multi_choice"
            else None
        )
        reset_far_right_frames = (
            [_make_blackboard(f) for f in reset_base_frames]
            if action_space == "multi_choice"
            else None
        )
        _subgoal = info_batch.get("grounded_subgoal_online", "")
        # One subgoal string per reset frame so captions align with frames.
        reset_subgoal_grounded = _subgoal if isinstance(_subgoal, list) else [_subgoal] * len(reset_base_frames)
        # ######## Video saving variable preparation (reset phase) end ########

        # ######## Video saving variable initialization start ########
        step = 0
        read_step = 0          # index into the recorded action stream
        episode_success = False
        rollout_base_frames: list[np.ndarray] = []
        rollout_wrist_frames: list[np.ndarray] = []
        rollout_right_frames: list[np.ndarray] = []
        rollout_far_right_frames: list[np.ndarray] = []
        rollout_subgoal_grounded: list[Any] = []
        # ######## Video saving variable initialization end ########

        while True:
            replay_key = action_space
            # A None action marks the end of the recorded sequence.
            action = dataset_resolver.get_step(replay_key, read_step)
            read_step += 1
            if action is None:
                break
            if action_space == "multi_choice":
                # Invalid/malformed recorded commands are skipped, not replayed.
                action = _parse_oracle_command(action)
                if action is None:
                    continue

            candidate_pixels: list[list[int]] = []
            clicked_pixel: Optional[list[int]] = None
            matched_pixel: Optional[list[int]] = None
            if action_space == "multi_choice":
                # Visualization must be computed BEFORE env.step mutates state.
                candidate_pixels, clicked_pixel, matched_pixel = _collect_multi_choice_visualization(
                    env, action
                )

            obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = env.step(action)

            # Maintain debug variable semantics from subgoal_evaluate_func.py
            front_camera = _to_frame_list(obs_batch["front_rgb_list"])
            wrist_camera = _to_frame_list(obs_batch["wrist_rgb_list"])

            subgoal_grounded = info_batch["grounded_subgoal_online"]

            # ######## Video saving variable preparation (replay phase) start ########
            rollout_base_frames.extend(_to_numpy_copy(f) for f in front_camera)
            rollout_wrist_frames.extend(_to_numpy_copy(f) for f in wrist_camera)
            if action_space == "multi_choice":
                # One candidate board + one selection board per new front frame,
                # keeping panel counts aligned with the base frames.
                for base_frame in front_camera:
                    rollout_right_frames.append(
                        _draw_candidate_blackboard(
                            base_frame,
                            candidate_pixels=candidate_pixels,
                        )
                    )
                    rollout_far_right_frames.append(
                        _draw_selection_blackboard(
                            base_frame,
                            clicked_pixel=clicked_pixel,
                            matched_pixel=matched_pixel,
                        )
                    )
            if isinstance(subgoal_grounded, list):
                rollout_subgoal_grounded.extend(subgoal_grounded)
            else:
                rollout_subgoal_grounded.extend([subgoal_grounded] * len(front_camera))
            # ######## Video saving variable preparation (replay phase) end ########

            info = {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
            terminated = bool(terminated_batch.item())
            truncated = bool(truncated_batch.item())

            step += 1
            if gui_render:
                env.render()

            if truncated:
                # print(f"[{env_id}] episode {episode} step limit exceeded, step {step}.")
                break
            if terminated:
                succ = info.get("success")
                # NOTE(review): `succ == torch.tensor([True])` is an element-wise
                # tensor comparison; the second clause already covers tensor
                # success values — confirm the first clause is intentional.
                if succ == torch.tensor([True]) or (
                    isinstance(succ, torch.Tensor) and succ.item()
                ):
                    # print(f"[{env_id}] episode {episode} success.")
                    episode_success = True
                elif info.get("fail", False):
                    # print(f"[{env_id}] episode {episode} failed.")
                    pass
                break

        # ######## Video saving section start ########
        save_robomme_video(
            reset_base_frames=reset_base_frames,
            reset_wrist_frames=reset_wrist_frames,
            rollout_base_frames=rollout_base_frames,
            rollout_wrist_frames=rollout_wrist_frames,
            reset_subgoal_grounded=reset_subgoal_grounded,
            rollout_subgoal_grounded=rollout_subgoal_grounded,
            out_video_dir=out_video_dir,
            action_space=action_space,
            env_id=env_id,
            episode=episode,
            episode_success=episode_success,
            reset_right_frames=reset_right_frames if action_space == "multi_choice" else None,
            rollout_right_frames=rollout_right_frames if action_space == "multi_choice" else None,
            reset_far_right_frames=(
                reset_far_right_frames if action_space == "multi_choice" else None
            ),
            rollout_far_right_frames=(
                rollout_far_right_frames if action_space == "multi_choice" else None
            ),
        )
        # ######## Video saving section end ########

        status = "Success" if episode_success else "Ended"
        if not episode_success and info.get("fail", False):
            status = "Failed"
        return f"[{env_id}] episode {episode} {status} (step {step})"

    except (FileNotFoundError, KeyError) as exc:
        return f"[{env_id}] episode {episode} data missing, skip. {exc}"
    except Exception as exc:
        # import traceback
        # traceback.print_exc()
        return f"[{env_id}] episode {episode} replay exception, skip. {exc}"
    finally:
        # Always release the HDF5 resolver and the simulator, even on failure.
        if dataset_resolver is not None:
            dataset_resolver.close()
        if env is not None:
            env.close()
481
+
482
+ def _parse_gpus(s: str) -> list[int]:
483
+ """Parse --gpus: '0' -> [0], '1' -> [1], '0,1' -> [0, 1]."""
484
+ allowed = {"0", "1", "0,1", "1,0"}
485
+ v = s.strip()
486
+ if v not in allowed:
487
+ raise argparse.ArgumentTypeError(
488
+ f"--gpus must be one of: 0, 1, 0,1 (got {s!r})"
489
+ )
490
+ if "," in v:
491
+ return [int(x) for x in v.split(",")]
492
+ return [int(v)]
493
+
494
def _parse_action_spaces(s: str) -> list[str]:
    """Parse a comma-separated --action_spaces value into an ordered, de-duplicated list."""
    tokens = [piece.strip() for piece in s.split(",") if piece.strip()]
    if not tokens:
        raise argparse.ArgumentTypeError(
            "--action_spaces cannot be empty. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )

    invalid = [token for token in tokens if token not in AVAILABLE_ACTION_SPACES]
    if invalid:
        raise argparse.ArgumentTypeError(
            f"Invalid action space(s): {invalid}. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )

    # De-duplicate while preserving first-seen order.
    selected: list[str] = []
    for token in tokens:
        if token not in selected:
            selected.append(token)

    if not selected:
        raise argparse.ArgumentTypeError(
            "--action_spaces has no valid value after parsing. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )
    return selected
526
+
527
def _parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the parallel replay driver.

    Returns a namespace with:
      envid: optional single environment id (falls back to the default env list
          downstream when omitted),
      max_workers: total worker budget, split across GPUs when two are used,
      gpus: list of GPU ids (validated by _parse_gpus),
      action_spaces: ordered, de-duplicated action spaces (validated by
          _parse_action_spaces).
    """
    parser = argparse.ArgumentParser(description="Replay dataset for one env_id in parallel.")
    parser.add_argument(
        "--envid",
        required=False,
        type=str,
        default=None,
        help="Single environment id to replay.",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=20,
        help="Total max workers (split across GPUs when using 2 GPUs).",
    )
    parser.add_argument(
        "--gpus",
        type=_parse_gpus,
        default=[1],
        # Help text kept in sync with the actual default above (was wrongly
        # documented as "Default: 0.").
        help="GPUs to use: '0' (GPU 0 only), '1' (GPU 1 only), '0,1' (both). Default: 1.",
    )
    parser.add_argument(
        "--action_spaces",
        type=_parse_action_spaces,
        default=["multi_choice"],
        # Help text kept in sync with the actual default above (was wrongly
        # documented as defaulting to all four spaces).
        help=(
            "Comma-separated action spaces to replay in order. "
            "Available: joint_angle,ee_pose,waypoint,multi_choice. "
            "Default: multi_choice."
        ),
    )
    return parser.parse_args()
560
+
561
def process_env_id(
    env_id: str,
    max_workers_total: int,
    gpu_ids: list[int],
    action_spaces: list[str],
):
    """Replay every episode of `env_id` for each requested action space.

    Work is fanned out over one process pool per GPU; with two GPUs, episodes
    alternate between the pools (even -> pool 0, odd -> pool 1). Results are
    printed as futures complete.
    """
    # Simple calculation of episode count (do not instantiate env_builder to avoid overhead, or lightweight instantiation)
    # To get episode_count, we need to instantiate env_builder once
    # But we only need the metadata parsing part
    temp_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="train",
        action_space=action_spaces[0],
        gui_render=False,  # Just to read metadata
        override_metadata_path=OVERRIDE_METADATA_PATH,
    )
    episode_count = temp_builder.get_episode_num()
    print(f"[{env_id}] episodes={episode_count}")
    print(f"Parallel execution with max_workers={max_workers_total} on GPU(s) {gpu_ids}")

    if episode_count == 0:
        print(f"[{env_id}] No episodes to replay, skip.")
        return

    # Split the worker budget across the GPUs (pool 0 gets the remainder).
    n_gpus = len(gpu_ids)
    if n_gpus == 1:
        mw0 = max(max_workers_total, 1)
        mw1 = 0
        print(f"Pool (GPU {gpu_ids[0]}): {mw0} workers")
    else:
        mw0 = (max_workers_total + 1) // 2
        mw1 = max_workers_total // 2
        if mw0 == 0:
            mw0 = 1
        if mw1 == 0 and max_workers_total > 1:
            mw1 = 1
        # NOTE(review): with two GPUs and max_workers_total <= 1, mw1 stays 0
        # and ProcessPoolExecutor(max_workers=0) below raises ValueError —
        # the `if mw1 == 0` reroute in the loop never gets a chance to help.
        print(f"Pool 0 (GPU {gpu_ids[0]}): {mw0} workers")
        print(f"Pool 1 (GPU {gpu_ids[1]}): {mw1} workers")

    # Action spaces are replayed sequentially; episodes within one are parallel.
    for action_space in action_spaces:
        print(f"[{env_id}] >>> action_space={action_space}")
        futures = []

        if n_gpus == 1:
            g0 = gpu_ids[0]
            with concurrent.futures.ProcessPoolExecutor(max_workers=mw0, initializer=init_worker, initargs=(g0,)) as executor0:
                for episode in range(episode_count):
                    future = executor0.submit(
                        evaluate_episode,
                        env_id=env_id,
                        episode=episode,
                        dataset_root=DATASET_ROOT,
                        override_metadata_path=OVERRIDE_METADATA_PATH,
                        action_space=action_space,
                        out_video_dir=OUT_VIDEO_DIR,
                        gui_render=GUI_RENDER
                    )
                    futures.append(future)
                for future in concurrent.futures.as_completed(futures):
                    res = future.result()
                    print(res)
        else:
            g0, g1 = gpu_ids[0], gpu_ids[1]
            with concurrent.futures.ProcessPoolExecutor(max_workers=mw0, initializer=init_worker, initargs=(g0,)) as executor0, \
                 concurrent.futures.ProcessPoolExecutor(max_workers=mw1, initializer=init_worker, initargs=(g1,)) as executor1:
                for episode in range(episode_count):
                    # Alternate episodes between the two pools...
                    if episode % 2 == 0:
                        executor = executor0
                    else:
                        executor = executor1
                    # ...unless pool 1 has no workers, then everything goes to pool 0.
                    if mw1 == 0:
                        executor = executor0
                    future = executor.submit(
                        evaluate_episode,
                        env_id=env_id,
                        episode=episode,
                        dataset_root=DATASET_ROOT,
                        override_metadata_path=OVERRIDE_METADATA_PATH,
                        action_space=action_space,
                        out_video_dir=OUT_VIDEO_DIR,
                        gui_render=GUI_RENDER
                    )
                    futures.append(future)
                for future in concurrent.futures.as_completed(futures):
                    res = future.result()
                    print(res)
        print(f"[{env_id}] <<< action_space={action_space} done")
648
+
649
def main():
    """CLI entry point: configure logging/multiprocessing, then replay each env."""
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    # Force use of spawn to avoid PyTorch/CUDA fork issues
    mp.set_start_method("spawn", force=True)

    parsed = _parse_args()
    env_ids = [parsed.envid] if parsed.envid else DEFAULT_ENV_IDS
    worker_budget = parsed.max_workers
    gpu_ids = parsed.gpus
    action_spaces = parsed.action_spaces

    print(f"Plan to replay envs: {env_ids} (gpus={gpu_ids})")
    print(f"Available action spaces: {AVAILABLE_ACTION_SPACES}")
    print(f"Selected action spaces: {action_spaces}")
    for env_id in env_ids:
        print(f"=== Processing {env_id} ===")
        process_env_id(env_id, worker_budget, gpu_ids, action_spaces)


if __name__ == "__main__":
    main()
scripts/dev/evaluate_dataset_replay-parallelv4-noresolver.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Script function: Unified dataset replay entry point, supports four action_spaces: joint_angle / ee_pose / waypoint / multi_choice.
3
+ # Consistent with subgoal_evaluate_func.py's main loop and debug fields; actions are read directly from HDF5 dataset files.
4
+ # [New] Support parallel multi-process replay and alternate task assignment between two GPUs.
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import concurrent.futures
10
+ import multiprocessing as mp
11
+ from pathlib import Path
12
+ from typing import Any, Optional
13
+
14
+ import cv2
15
+ import h5py
16
+ import numpy as np
17
+ import torch
18
+
19
+ # Support running this file directly: python scripts/dev/<script>.py
20
+ REPO_ROOT = Path(__file__).resolve().parents[2]
21
+ if str(REPO_ROOT) not in sys.path:
22
+ sys.path.insert(0, str(REPO_ROOT))
23
+
24
+ from scripts.dataset_replay import _build_action_sequence
25
+ from robomme.robomme_env import *
26
+ from robomme.robomme_env.utils import *
27
+ from robomme.env_record_wrapper import (
28
+ BenchmarkEnvBuilder,
29
+ )
30
+ from robomme.env_record_wrapper.OraclePlannerDemonstrationWrapper import (
31
+ OraclePlannerDemonstrationWrapper,
32
+ )
33
+ from robomme.robomme_env.utils.choice_action_mapping import (
34
+ _unique_candidates,
35
+ extract_actor_position_xyz,
36
+ project_world_to_pixel,
37
+ select_target_with_pixel,
38
+ )
39
+ from robomme.robomme_env.utils.save_reset_video import save_robomme_video
40
+
41
# Action spaces this replay driver knows how to feed back into the env.
AVAILABLE_ACTION_SPACES = [
    "joint_angle",
    "ee_pose",
    "waypoint",
    "multi_choice",
]

# When True, env.render() is called after every replayed step.
GUI_RENDER = False

# Root directory holding the recorded HDF5 datasets (record_dataset_<env_id>.h5).
DATASET_ROOT = "/data/hongzefu/data_0226-test"
# Metadata override path handed to BenchmarkEnvBuilder.
OVERRIDE_METADATA_PATH = "/data/hongzefu/data_0226-test"

# ######## Video saving variables (output directory) start ########
# Video output directory: Independently hardcoded, not aligned with h5 path or env_id
OUT_VIDEO_DIR = "/data/hongzefu/dataset_replay-0226-test"
# ######## Video saving variables (output directory) end ########
# Hard cap on env steps per episode (passed as max_steps to the env builder).
MAX_STEPS = 2000

# Fallback env ids replayed when no single env id is specified.
DEFAULT_ENV_IDS = [
    "PickXtimes",
    "StopCube",
    "SwingXtimes",
    "BinFill",
    "VideoUnmaskSwap",
    "VideoUnmask",
    "ButtonUnmaskSwap",
    "ButtonUnmask",
    "VideoRepick",
    "VideoPlaceButton",
    "VideoPlaceOrder",
    "PickHighlight",
    "InsertPeg",
    "MoveCube",
    "PatternLock",
    "RouteStick",
]
77
+
78
+ def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
79
+ if not isinstance(choice_action, dict):
80
+ return None
81
+ choice = choice_action.get("choice")
82
+ if not isinstance(choice, str) or not choice.strip():
83
+ return None
84
+ if "point" not in choice_action:
85
+ return None
86
+ return {
87
+ "choice": choice_action.get("choice"),
88
+ "point": choice_action.get("point"),
89
+ }
90
+
91
+
92
+ def _to_numpy_copy(value: Any) -> np.ndarray:
93
+ if isinstance(value, torch.Tensor):
94
+ value = value.detach().cpu().numpy()
95
+ else:
96
+ value = np.asarray(value)
97
+ return np.array(value, copy=True)
98
+
99
+
100
+ def _to_frame_list(frames_like: Any) -> list[np.ndarray]:
101
+ if frames_like is None:
102
+ return []
103
+ if isinstance(frames_like, torch.Tensor):
104
+ arr = frames_like.detach().cpu().numpy()
105
+ if arr.ndim == 3:
106
+ return [np.array(arr, copy=True)]
107
+ if arr.ndim == 4:
108
+ return [np.array(x, copy=True) for x in arr]
109
+ return []
110
+ if isinstance(frames_like, np.ndarray):
111
+ if frames_like.ndim == 3:
112
+ return [np.array(frames_like, copy=True)]
113
+ if frames_like.ndim == 4:
114
+ return [np.array(x, copy=True) for x in frames_like]
115
+ return []
116
+ if isinstance(frames_like, (list, tuple)):
117
+ out = []
118
+ for frame in frames_like:
119
+ if frame is None:
120
+ continue
121
+ out.append(_to_numpy_copy(frame))
122
+ return out
123
+ try:
124
+ arr = np.asarray(frames_like)
125
+ except Exception:
126
+ return []
127
+ if arr.ndim == 3:
128
+ return [np.array(arr, copy=True)]
129
+ if arr.ndim == 4:
130
+ return [np.array(x, copy=True) for x in arr]
131
+ return []
132
+
133
+
134
+ def _normalize_pixel_xy(pixel_like: Any) -> Optional[list[int]]:
135
+ if not isinstance(pixel_like, (list, tuple, np.ndarray)):
136
+ return None
137
+ if len(pixel_like) < 2:
138
+ return None
139
+ try:
140
+ x = float(pixel_like[0])
141
+ y = float(pixel_like[1])
142
+ except (TypeError, ValueError):
143
+ return None
144
+ if not np.isfinite(x) or not np.isfinite(y):
145
+ return None
146
+ return [int(np.rint(x)), int(np.rint(y))]
147
+
148
+
149
+ def _normalize_point_yx_to_pixel_xy(point_like: Any) -> Optional[list[int]]:
150
+ if not isinstance(point_like, (list, tuple, np.ndarray)):
151
+ return None
152
+ if len(point_like) < 2:
153
+ return None
154
+ try:
155
+ y = float(point_like[0])
156
+ x = float(point_like[1])
157
+ except (TypeError, ValueError):
158
+ return None
159
+ if not np.isfinite(x) or not np.isfinite(y):
160
+ return None
161
+ return [int(np.rint(x)), int(np.rint(y))]
162
+
163
+
164
def _find_oracle_wrapper(env_like: Any) -> Optional[OraclePlannerDemonstrationWrapper]:
    """Search the `.env` wrapper chain (max 16 hops, cycle-safe) for the oracle wrapper."""
    visited_ids: set[int] = set()
    node = env_like
    for _hop in range(16):
        if node is None:
            return None
        if isinstance(node, OraclePlannerDemonstrationWrapper):
            return node
        marker = id(node)
        if marker in visited_ids:
            # Cycle detected — give up rather than loop forever.
            return None
        visited_ids.add(marker)
        node = getattr(node, "env", None)
    return None
178
+
179
+
180
def _collect_multi_choice_visualization(
    env_like: Any,
    command: dict[str, Any],
) -> tuple[list[list[int]], Optional[list[int]], Optional[list[int]]]:
    """Gather debug pixels for one multi_choice step.

    Returns (candidate_pixels, clicked_pixel, matched_pixel):
      candidate_pixels: [x, y] image projections of every unique candidate actor
          of the option the command resolves to.
      clicked_pixel: the command's (y, x) "point" converted to [x, y], or None.
      matched_pixel: projected pixel of the candidate the click selects, or None.
    Best-effort: any failure degrades to ([], clicked_pixel, None) so the
    visualization never breaks the replay itself.
    """
    # Stored "point" is (y, x); convert to [x, y] pixel order for drawing.
    clicked_pixel = _normalize_point_yx_to_pixel_xy(command.get("point"))
    oracle_wrapper = _find_oracle_wrapper(env_like)
    if oracle_wrapper is None:
        return [], clicked_pixel, None

    # NOTE(review): relies on private wrapper APIs (_build_step_options /
    # _resolve_command); guarded by a broad except on purpose.
    try:
        _selected_target, solve_options = oracle_wrapper._build_step_options()
        found_idx, _ = oracle_wrapper._resolve_command(command, solve_options)
    except Exception:
        return [], clicked_pixel, None

    if found_idx is None or found_idx < 0 or found_idx >= len(solve_options):
        return [], clicked_pixel, None

    option = solve_options[found_idx]
    available = option.get("available")
    # Front-camera calibration cached on the wrapper; may be None — presumably
    # the projection helpers tolerate that (TODO confirm).
    intrinsic_cv = getattr(oracle_wrapper, "_front_camera_intrinsic_cv", None)
    extrinsic_cv = getattr(oracle_wrapper, "_front_camera_extrinsic_cv", None)
    image_shape = getattr(oracle_wrapper, "_front_rgb_shape", None)

    # Project every unique candidate actor of the resolved option into the image.
    candidate_pixels: list[list[int]] = []
    if available is not None:
        for actor in _unique_candidates(available):
            actor_pos = extract_actor_position_xyz(actor)
            if actor_pos is None:
                continue
            projected = project_world_to_pixel(
                actor_pos,
                intrinsic_cv=intrinsic_cv,
                extrinsic_cv=extrinsic_cv,
                image_shape=image_shape,
            )
            if projected is None:
                continue
            candidate_pixels.append([int(projected[0]), int(projected[1])])

    # Re-run click-to-candidate matching to learn which actor was selected.
    matched_pixel: Optional[list[int]] = None
    if available is not None and clicked_pixel is not None:
        matched = select_target_with_pixel(
            available=available,
            pixel_like=clicked_pixel,
            intrinsic_cv=intrinsic_cv,
            extrinsic_cv=extrinsic_cv,
            image_shape=image_shape,
        )
        if isinstance(matched, dict):
            matched_pixel = _normalize_pixel_xy(matched.get("projected_pixel"))

    return candidate_pixels, clicked_pixel, matched_pixel
233
+
234
+
235
def _make_blackboard(frame_like: Any) -> np.ndarray:
    """Return an all-black uint8 canvas (H, W, 3) matching the input frame's size."""
    frame = _to_numpy_copy(frame_like)
    if frame.ndim >= 2:
        h, w = int(frame.shape[0]), int(frame.shape[1])
        if h > 0 and w > 0:
            return np.zeros((h, w, 3), dtype=np.uint8)
    # Degenerate input: fall back to a minimal 1x1 canvas.
    return np.zeros((1, 1, 3), dtype=np.uint8)
243
+
244
+
245
def _draw_candidate_blackboard(
    frame_like: Any,
    candidate_pixels: list[list[int]],
) -> np.ndarray:
    """Render every candidate pixel as a small circle on a black canvas sized like the frame."""
    canvas = _make_blackboard(frame_like)
    for px in candidate_pixels:
        if len(px) >= 2:
            cv2.circle(canvas, (int(px[0]), int(px[1])), 4, (0, 255, 255), 1)
    return canvas
255
+
256
+
257
def _draw_selection_blackboard(
    frame_like: Any,
    clicked_pixel: Optional[list[int]],
    matched_pixel: Optional[list[int]],
) -> np.ndarray:
    """Draw the raw click (tilted cross) and the matched target (circle) on a black canvas."""
    canvas = _make_blackboard(frame_like)
    if clicked_pixel is not None:
        cv2.drawMarker(
            canvas,
            (int(clicked_pixel[0]), int(clicked_pixel[1])),
            (255, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=10,
            thickness=1,
        )
    if matched_pixel is not None:
        cv2.circle(
            canvas,
            (int(matched_pixel[0]), int(matched_pixel[1])),
            5,
            (255, 0, 0),
            2,
        )
    return canvas
275
+
276
+
277
def init_worker(gpu_id: int):
    """
    Worker process initialization function, sets CUDA_VISIBLE_DEVICES.
    """
    # Configure logging inside the freshly spawned worker process.
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    # Pin this worker to a single GPU. NOTE(review): assumes CUDA has not been
    # initialized in this process before this point — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # print(f"[Worker] Initialized on GPU {gpu_id} (PID: {os.getpid()})")
285
+
286
+ def evaluate_episode(
287
+ env_id: str,
288
+ episode: int,
289
+ dataset_root: str,
290
+ override_metadata_path: str,
291
+ action_space: str,
292
+ out_video_dir: str,
293
+ gui_render: bool
294
+ ) -> str:
295
+ """
296
+ Evaluation logic for a single Episode.
297
+ """
298
+ # Reconstruct envs in worker process (avoid passing complex objects across processes)
299
+ env_builder = BenchmarkEnvBuilder(
300
+ env_id=env_id,
301
+ dataset="train",
302
+ action_space=action_space,
303
+ gui_render=gui_render,
304
+ override_metadata_path=override_metadata_path,
305
+ )
306
+
307
+ env = None
308
+
309
+ try:
310
+ env = env_builder.make_env_for_episode(
311
+ episode,
312
+ max_steps=MAX_STEPS,
313
+ include_maniskill_obs=True,
314
+ include_front_depth=True,
315
+ include_wrist_depth=True,
316
+ include_front_camera_extrinsic=True,
317
+ include_wrist_camera_extrinsic=True,
318
+ include_available_multi_choices=True,
319
+ include_front_camera_intrinsic=True,
320
+ include_wrist_camera_intrinsic=True,
321
+ )
322
+
323
+ file_path = Path(dataset_root) / f"record_dataset_{env_id}.h5"
324
+ if not file_path.exists():
325
+ raise FileNotFoundError(f"dataset file not found: {file_path}")
326
+ episode_key = f"episode_{episode}"
327
+ with h5py.File(file_path, "r") as data:
328
+ if episode_key not in data:
329
+ raise KeyError(f"missing key '{episode_key}' in {file_path}")
330
+ action_sequence = _build_action_sequence(data[episode_key], action_space)
331
+ print(
332
+ f"[{env_id}] episode={episode} h5={file_path} "
333
+ f"episode_key={episode_key} action_space={action_space} "
334
+ f"action_count={len(action_sequence)}"
335
+ )
336
+
337
+ # obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = env.reset()
338
+ obs_batch, info_batch = env.reset()
339
+
340
+ # Maintain debug variable semantics from subgoal_evaluate_func.py
341
+ # Note: These local variables in multi-processing can be simplified if printing is not needed, but unpacking logic is retained for consistency.
342
+ maniskill_obs = obs_batch["maniskill_obs"]
343
+ front_camera = _to_frame_list(obs_batch["front_rgb_list"])
344
+ wrist_camera = _to_frame_list(obs_batch["wrist_rgb_list"])
345
+ # Other variables unpacking skipped unless used downstream
346
+
347
+ task_goal_list = info_batch["task_goal"]
348
+ # task_goal = task_goal_list[0] if task_goal_list else None
349
+
350
+ info = {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
351
+ # terminated = bool(terminated_batch[-1].item())
352
+ # truncated = bool(truncated_batch[-1].item())
353
+
354
+ # ######## Video saving variable preparation (reset phase) start ########
355
+ reset_base_frames = [_to_numpy_copy(f) for f in front_camera]
356
+ reset_wrist_frames = [_to_numpy_copy(f) for f in wrist_camera]
357
+ reset_right_frames = (
358
+ [_make_blackboard(f) for f in reset_base_frames]
359
+ if action_space == "multi_choice"
360
+ else None
361
+ )
362
+ reset_far_right_frames = (
363
+ [_make_blackboard(f) for f in reset_base_frames]
364
+ if action_space == "multi_choice"
365
+ else None
366
+ )
367
+ _subgoal = info_batch.get("grounded_subgoal_online", "")
368
+ reset_subgoal_grounded = _subgoal if isinstance(_subgoal, list) else [_subgoal] * len(reset_base_frames)
369
+ # ######## Video saving variable preparation (reset phase) end ########
370
+
371
+ # ######## Video saving variable initialization start ########
372
+ step = 0
373
+ episode_success = False
374
+ rollout_base_frames: list[np.ndarray] = []
375
+ rollout_wrist_frames: list[np.ndarray] = []
376
+ rollout_right_frames: list[np.ndarray] = []
377
+ rollout_far_right_frames: list[np.ndarray] = []
378
+ rollout_subgoal_grounded: list[Any] = []
379
+ # ######## Video saving variable initialization end ########
380
+
381
+ for _, action in enumerate(action_sequence):
382
+ if action_space == "multi_choice":
383
+ action = _parse_oracle_command(action)
384
+ if action is None:
385
+ continue
386
+
387
+ candidate_pixels: list[list[int]] = []
388
+ clicked_pixel: Optional[list[int]] = None
389
+ matched_pixel: Optional[list[int]] = None
390
+ if action_space == "multi_choice":
391
+ candidate_pixels, clicked_pixel, matched_pixel = _collect_multi_choice_visualization(
392
+ env, action
393
+ )
394
+
395
+ obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = env.step(action)
396
+
397
+ # Maintain debug variable semantics from subgoal_evaluate_func.py
398
+ front_camera = _to_frame_list(obs_batch["front_rgb_list"])
399
+ wrist_camera = _to_frame_list(obs_batch["wrist_rgb_list"])
400
+
401
+ subgoal_grounded = info_batch["grounded_subgoal_online"]
402
+
403
+ # ######## Video saving variable preparation (replay phase) start ########
404
+ rollout_base_frames.extend(_to_numpy_copy(f) for f in front_camera)
405
+ rollout_wrist_frames.extend(_to_numpy_copy(f) for f in wrist_camera)
406
+ if action_space == "multi_choice":
407
+ for base_frame in front_camera:
408
+ rollout_right_frames.append(
409
+ _draw_candidate_blackboard(
410
+ base_frame,
411
+ candidate_pixels=candidate_pixels,
412
+ )
413
+ )
414
+ rollout_far_right_frames.append(
415
+ _draw_selection_blackboard(
416
+ base_frame,
417
+ clicked_pixel=clicked_pixel,
418
+ matched_pixel=matched_pixel,
419
+ )
420
+ )
421
+ if isinstance(subgoal_grounded, list):
422
+ rollout_subgoal_grounded.extend(subgoal_grounded)
423
+ else:
424
+ rollout_subgoal_grounded.extend([subgoal_grounded] * len(front_camera))
425
+ # ######## Video saving variable preparation (replay phase) end ########
426
+
427
+ info = {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
428
+ terminated = bool(terminated_batch.item())
429
+ truncated = bool(truncated_batch.item())
430
+
431
+ step += 1
432
+ if gui_render:
433
+ env.render()
434
+
435
+ if truncated:
436
+ # print(f"[{env_id}] episode {episode} step limit exceeded, step {step}.")
437
+ break
438
+ if terminated:
439
+ succ = info.get("success")
440
+ if succ == torch.tensor([True]) or (
441
+ isinstance(succ, torch.Tensor) and succ.item()
442
+ ):
443
+ # print(f"[{env_id}] episode {episode} success.")
444
+ episode_success = True
445
+ elif info.get("fail", False):
446
+ # print(f"[{env_id}] episode {episode} failed.")
447
+ pass
448
+ break
449
+
450
+ # ######## Video saving section start ########
451
+ save_robomme_video(
452
+ reset_base_frames=reset_base_frames,
453
+ reset_wrist_frames=reset_wrist_frames,
454
+ rollout_base_frames=rollout_base_frames,
455
+ rollout_wrist_frames=rollout_wrist_frames,
456
+ reset_subgoal_grounded=reset_subgoal_grounded,
457
+ rollout_subgoal_grounded=rollout_subgoal_grounded,
458
+ out_video_dir=out_video_dir,
459
+ action_space=action_space,
460
+ env_id=env_id,
461
+ episode=episode,
462
+ episode_success=episode_success,
463
+ reset_right_frames=reset_right_frames if action_space == "multi_choice" else None,
464
+ rollout_right_frames=rollout_right_frames if action_space == "multi_choice" else None,
465
+ reset_far_right_frames=(
466
+ reset_far_right_frames if action_space == "multi_choice" else None
467
+ ),
468
+ rollout_far_right_frames=(
469
+ rollout_far_right_frames if action_space == "multi_choice" else None
470
+ ),
471
+ )
472
+ # ######## Video saving section end ########
473
+
474
+ status = "Success" if episode_success else "Ended"
475
+ if not episode_success and info.get("fail", False):
476
+ status = "Failed"
477
+ return f"[{env_id}] episode {episode} {status} (step {step})"
478
+
479
+ except (FileNotFoundError, KeyError) as exc:
480
+ return f"[{env_id}] episode {episode} data missing, skip. {exc}"
481
+ except Exception as exc:
482
+ # import traceback
483
+ # traceback.print_exc()
484
+ return f"[{env_id}] episode {episode} replay exception, skip. {exc}"
485
+ finally:
486
+ if env is not None:
487
+ env.close()
488
+
489
+ def _parse_gpus(s: str) -> list[int]:
490
+ """Parse --gpus: '0' -> [0], '1' -> [1], '0,1' -> [0, 1]."""
491
+ allowed = {"0", "1", "0,1", "1,0"}
492
+ v = s.strip()
493
+ if v not in allowed:
494
+ raise argparse.ArgumentTypeError(
495
+ f"--gpus must be one of: 0, 1, 0,1 (got {s!r})"
496
+ )
497
+ if "," in v:
498
+ return [int(x) for x in v.split(",")]
499
+ return [int(v)]
500
+
501
def _parse_action_spaces(s: str) -> list[str]:
    """Parse the comma-separated --action_spaces value.

    Returns the requested action spaces, de-duplicated while preserving the
    first occurrence order. Raises argparse.ArgumentTypeError on an empty
    value or any token not present in AVAILABLE_ACTION_SPACES.
    """
    tokens = [part.strip() for part in s.split(",") if part.strip()]
    if not tokens:
        raise argparse.ArgumentTypeError(
            "--action_spaces cannot be empty. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )

    # Report every invalid token (duplicates included) in one error.
    invalid = [token for token in tokens if token not in AVAILABLE_ACTION_SPACES]
    if invalid:
        raise argparse.ArgumentTypeError(
            f"Invalid action space(s): {invalid}. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )

    # dict.fromkeys de-duplicates while keeping first-seen order.
    selected = list(dict.fromkeys(tokens))
    if not selected:
        # Defensive guard; unreachable because tokens is non-empty and all valid here.
        raise argparse.ArgumentTypeError(
            "--action_spaces has no valid value after parsing. "
            f"Allowed action spaces: {AVAILABLE_ACTION_SPACES}"
        )
    return selected
533
+
534
def _parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the replay script.

    Returns:
        argparse.Namespace with fields: envid, max_workers, gpus, action_spaces.
    """
    parser = argparse.ArgumentParser(description="Replay dataset for one env_id in parallel.")
    parser.add_argument(
        "--envid",
        required=False,
        type=str,
        default=None,
        help="Single environment id to replay.",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=20,
        help="Total max workers (split across GPUs when using 2 GPUs).",
    )
    parser.add_argument(
        "--gpus",
        type=_parse_gpus,
        default=[1],
        # Fix: help text previously claimed "Default: 0" while the actual default is GPU 1.
        help="GPUs to use: '0' (GPU 0 only), '1' (GPU 1 only), '0,1' (both). Default: 1.",
    )
    parser.add_argument(
        "--action_spaces",
        type=_parse_action_spaces,
        default=["multi_choice"],
        # Fix: help text previously claimed all four spaces as default, but the
        # actual default is only 'multi_choice'.
        help=(
            "Comma-separated action spaces to replay in order. "
            "Available: joint_angle,ee_pose,waypoint,multi_choice. "
            "Default: multi_choice."
        ),
    )
    return parser.parse_args()
567
+
568
def process_env_id(
    env_id: str,
    max_workers_total: int,
    gpu_ids: list[int],
    action_spaces: list[str],
):
    """Replay every episode of one environment for each requested action space.

    Episodes are fanned out to one or two process pools (one per GPU). The
    worker budget is split roughly in half when two GPUs are used.

    Args:
        env_id: Environment identifier to replay.
        max_workers_total: Total worker budget across all pools.
        gpu_ids: GPUs to use; the code supports exactly one or two entries.
        action_spaces: Action spaces to replay, in order.
    """
    # Lightweight builder instantiation just to read episode metadata.
    temp_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="train",
        action_space=action_spaces[0],
        gui_render=False,  # Just to read metadata
        override_metadata_path=OVERRIDE_METADATA_PATH,
    )
    episode_count = temp_builder.get_episode_num()
    print(f"[{env_id}] episodes={episode_count}")
    print(f"Parallel execution with max_workers={max_workers_total} on GPU(s) {gpu_ids}")

    if episode_count == 0:
        print(f"[{env_id}] No episodes to replay, skip.")
        return

    n_gpus = len(gpu_ids)
    if n_gpus == 1:
        mw0 = max(max_workers_total, 1)
        mw1 = 0
        print(f"Pool (GPU {gpu_ids[0]}): {mw0} workers")
    else:
        # Split the budget; the first pool gets the extra worker on odd totals.
        mw0 = (max_workers_total + 1) // 2
        mw1 = max_workers_total // 2
        if mw0 == 0:
            mw0 = 1
        if mw1 == 0 and max_workers_total > 1:
            mw1 = 1
        print(f"Pool 0 (GPU {gpu_ids[0]}): {mw0} workers")
        print(f"Pool 1 (GPU {gpu_ids[1]}): {mw1} workers")

    def _submit(executor, action_space: str, episode: int):
        """Submit one episode replay task; shared by both pool layouts."""
        return executor.submit(
            evaluate_episode,
            env_id=env_id,
            episode=episode,
            dataset_root=DATASET_ROOT,
            override_metadata_path=OVERRIDE_METADATA_PATH,
            action_space=action_space,
            out_video_dir=OUT_VIDEO_DIR,
            gui_render=GUI_RENDER,
        )

    for action_space in action_spaces:
        print(f"[{env_id}] >>> action_space={action_space}")
        futures = []

        # Bug fix: with two GPUs but a worker budget that leaves the second
        # pool empty (mw1 == 0, i.e. max_workers_total <= 1), the original code
        # created ProcessPoolExecutor(max_workers=0), which raises ValueError.
        # Fall back to a single pool on GPU 0 in that case.
        if n_gpus == 1 or mw1 == 0:
            g0 = gpu_ids[0]
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=mw0, initializer=init_worker, initargs=(g0,)
            ) as executor0:
                for episode in range(episode_count):
                    futures.append(_submit(executor0, action_space, episode))
                for future in concurrent.futures.as_completed(futures):
                    print(future.result())
        else:
            g0, g1 = gpu_ids[0], gpu_ids[1]
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=mw0, initializer=init_worker, initargs=(g0,)
            ) as executor0, concurrent.futures.ProcessPoolExecutor(
                max_workers=mw1, initializer=init_worker, initargs=(g1,)
            ) as executor1:
                for episode in range(episode_count):
                    # Round-robin episodes across the two GPU pools.
                    executor = executor0 if episode % 2 == 0 else executor1
                    futures.append(_submit(executor, action_space, episode))
                for future in concurrent.futures.as_completed(futures):
                    print(future.result())
        print(f"[{env_id}] <<< action_space={action_space} done")
655
+
656
def main():
    """CLI entry point: set up logging and multiprocessing, then replay each env."""
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    # Force use of spawn to avoid PyTorch/CUDA fork issues
    mp.set_start_method("spawn", force=True)

    args = _parse_args()
    env_ids = [args.envid] if args.envid else DEFAULT_ENV_IDS

    print(f"Plan to replay envs: {env_ids} (gpus={args.gpus})")
    print(f"Available action spaces: {AVAILABLE_ACTION_SPACES}")
    print(f"Selected action spaces: {args.action_spaces}")
    for env_id in env_ids:
        print(f"=== Processing {env_id} ===")
        process_env_id(env_id, args.max_workers, args.gpus, args.action_spaces)
674
+
675
# Script entry point: run the replay pipeline when executed directly.
if __name__ == "__main__":
    main()
scripts/dev/generate-dataset-control-seed-readJson-advanceV3.py ADDED
@@ -0,0 +1,878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ from concurrent.futures import ProcessPoolExecutor, as_completed
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from typing import Any, Dict, Iterable, List, Optional, Set
11
+ import h5py
12
+
13
+ import gymnasium as gym
14
+
15
+ # Import Robomme related environment wrappers and exception classes
16
+ from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout
17
+ from robomme.robomme_env import *
18
+ from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError
19
+
20
+ # from util import *
21
+ import torch
22
+
23
+ # Import planner and related exception classes
24
+ from robomme.robomme_env.utils.planner_fail_safe import (
25
+ FailAwarePandaArmMotionPlanningSolver,
26
+ FailAwarePandaStickMotionPlanningSolver,
27
+ ScrewPlanFailure,
28
+ )
29
+
30
+ """
31
+
32
+ Script function: Parallel generation of Robomme environment datasets.
33
+ This script supports multi-process parallel environment simulation, generating HDF5 datasets containing RGB, depth, segmentation, etc.
34
+ Key features include:
35
+ 1. Configure environment list and parameters.
36
+ 2. Parallel execution of multiple episode simulations.
37
+ 3. Use FailAware planner to attempt to solve tasks.
38
+ 4. Record data and save as HDF5 file.
39
+ 5. Merge multiple temporarily generated HDF5 files into a final dataset.
40
+ """
41
+
42
# List of all supported environment module names
DEFAULT_ENVS =[
    "PickXtimes",
    "StopCube",
    "SwingXtimes",
    "BinFill",

    "VideoUnmaskSwap",
    "VideoUnmask",
    "ButtonUnmaskSwap",
    "ButtonUnmask",

    "VideoRepick",
    "VideoPlaceButton",
    "VideoPlaceOrder",
    "PickHighlight",

    "InsertPeg",
    'MoveCube',
    "PatternLock",
    "RouteStick"
]

# Reference dataset metadata root directory: used to read difficulty and seed
SOURCE_METADATA_ROOT = Path("/data/hongzefu/robomme_benchmark/src/robomme/env_metadata/1206")
# Difficulty labels accepted in metadata records (raw values are lower-cased before the check).
VALID_DIFFICULTIES: Set[str] = {"easy", "medium", "hard"}
# Max attempts for screw planning before falling back to RRT*, and for RRT* itself.
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
70
+
71
+
72
def _load_env_metadata_records(
    env_id: str,
    metadata_root: Path,
) -> List[Dict[str, Any]]:
    """Load and validate the (episode, seed, difficulty) records for *env_id*.

    Reads ``record_dataset_<env_id>_metadata.json`` under *metadata_root* and
    returns its records sorted by episode, with difficulty lower-cased.

    Raises:
        FileNotFoundError: metadata file is missing.
        ValueError: records are absent, malformed, non-integer, or carry a
            difficulty outside VALID_DIFFICULTIES.
    """
    metadata_path = metadata_root / f"record_dataset_{env_id}_metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"Metadata file not found for env '{env_id}': {metadata_path}"
        )

    with metadata_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    raw_records = payload.get("records")
    if not isinstance(raw_records, list) or not raw_records:
        raise ValueError(
            f"Metadata file has no valid 'records' list: {metadata_path}"
        )

    records: List[Dict[str, Any]] = []
    for idx, entry in enumerate(raw_records):
        if not isinstance(entry, dict):
            raise ValueError(
                f"Invalid metadata record at index {idx} in {metadata_path}"
            )
        if any(key not in entry for key in ("episode", "seed", "difficulty")):
            raise ValueError(
                f"Metadata record missing episode/seed/difficulty at index {idx} in {metadata_path}"
            )

        try:
            episode_num = int(entry["episode"])
            seed_num = int(entry["seed"])
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Metadata record has non-integer episode/seed at index {idx} in {metadata_path}"
            ) from exc

        difficulty = str(entry["difficulty"]).strip().lower()
        if difficulty not in VALID_DIFFICULTIES:
            raise ValueError(
                f"Metadata record has invalid difficulty '{entry['difficulty']}' "
                f"at index {idx} in {metadata_path}. Expected one of {sorted(VALID_DIFFICULTIES)}."
            )

        records.append(
            {
                "episode": episode_num,
                "seed": seed_num,
                "difficulty": difficulty,
            }
        )

    records.sort(key=lambda rec: rec["episode"])
    print(
        f"Loaded {len(records)} metadata records for {env_id} from {metadata_path}"
    )
    return records
133
+
134
+
135
+ def _build_seed_candidates_from_metadata(
136
+ episode: int,
137
+ metadata_records: List[Dict[str, Any]],
138
+ ) -> List[Dict[str, Any]]:
139
+ """
140
+ Construct candidate (seed, difficulty) list for current episode.
141
+ Strictly use only the seed from metadata for the same episode, no cross-episode fallback.
142
+ """
143
+ if not metadata_records:
144
+ return []
145
+
146
+ same_episode_records = [rec for rec in metadata_records if rec["episode"] == episode]
147
+ if not same_episode_records:
148
+ return []
149
+ if len(same_episode_records) > 1:
150
+ raise ValueError(
151
+ f"Found duplicated metadata records for episode {episode}. "
152
+ "Strict mode requires exactly one source record per episode."
153
+ )
154
+
155
+ rec = same_episode_records[0]
156
+ return [{"seed": int(rec["seed"]), "difficulty": rec["difficulty"]}]
157
+
158
+ def _tensor_to_bool(value) -> bool:
159
+ """
160
+ Helper function: Convert Tensor or numpy array to Python bool type.
161
+ Used to handle success/failure flags from different sources.
162
+ """
163
+ if value is None:
164
+ return False
165
+ if isinstance(value, torch.Tensor):
166
+ return bool(value.detach().cpu().bool().item())
167
+ if isinstance(value, np.ndarray):
168
+ return bool(np.any(value))
169
+ return bool(value)
170
+
171
+
172
+ def _split_episode_indices(num_episodes: int, max_chunks: int) -> List[List[int]]:
173
+ """
174
+ Helper function: Split total episodes into multiple chunks for parallel processing by different processes.
175
+
176
+ Args:
177
+ num_episodes: Total number of episodes
178
+ max_chunks: Max number of chunks (usually equals number of workers)
179
+
180
+ Returns:
181
+ List containing lists of episode indices
182
+ """
183
+ if num_episodes <= 0:
184
+ return []
185
+
186
+ chunk_count = min(max_chunks, num_episodes)
187
+ base_size, remainder = divmod(num_episodes, chunk_count)
188
+
189
+ chunks: List[List[int]] = []
190
+ start = 0
191
+ for chunk_idx in range(chunk_count):
192
+ # If there is a remainder, allocate one extra episode to the first 'remainder' chunks
193
+ stop = start + base_size + (1 if chunk_idx < remainder else 0)
194
+ chunks.append(list(range(start, stop)))
195
+ start = stop
196
+
197
+ return chunks
198
+
199
+
200
def _run_episode_attempt(
    env_id: str,
    episode: int,
    seed: int,
    temp_dataset_path: Path,
    save_video: bool,
    difficulty: Optional[str],
) -> bool:
    """Run a single episode attempt and report success or failure.

    Builds the gym env, wraps it with RobommeRecordWrapper for HDF5/video
    recording, patches the planner's screw motion with a screw->RRT* retry
    chain, then executes the env's task list one task at a time.

    Args:
        env_id: Registered environment id passed to gym.make.
        episode: Episode index (also selects the failure-recovery demo mode
            for episodes <= 5).
        seed: Random seed for the env.
        temp_dataset_path: Path handed to the record wrapper as the dataset
            target (the wrapper manages actual file layout).
        save_video: Whether the record wrapper should also save videos.
        difficulty: Difficulty string forwarded to the env (may be None).

    Returns:
        True if the episode finished successfully, False otherwise
        (including planner exhaustion, task failure, or scene-generation
        errors). Exceptions other than SceneGenerationError propagate.
    """
    print(f"--- Running simulation for episode:{episode}, seed:{seed}, env: {env_id} ---")

    env: Optional[gym.Env] = None
    try:
        # 1. Environment parameter configuration
        env_kwargs = dict(
            obs_mode="rgb+depth+segmentation",  # Observation mode: RGB + Depth + Segmentation
            control_mode="pd_joint_pos",  # Control mode: Position control
            render_mode="rgb_array",  # Render mode
            reward_mode="dense",  # Reward mode
            seed=seed,  # Random seed
            difficulty=difficulty,  # Difficulty setting
        )

        # Special failure recovery settings for the first few episodes (for testing or demonstration purposes only)
        if episode <= 5:
            env_kwargs["robomme_failure_recovery"] = True
            if episode <=2:
                env_kwargs["robomme_failure_recovery_mode"] = "z"  # z-axis recovery
            else:
                env_kwargs["robomme_failure_recovery_mode"] = "xy"  # xy-axis recovery


        env = gym.make(env_id, **env_kwargs)

        # 2. Wrap environment to record data
        env = RobommeRecordWrapper(
            env,
            dataset=str(temp_dataset_path),  # Data save path
            env_id=env_id,
            episode=episode,
            seed=seed,
            save_video=save_video,

        )

        episode_successful = False


        env.reset()

        # 3. Select planner
        # PatternLock and RouteStick require Stick planner, others use Arm planner
        if env_id == "PatternLock" or env_id == "RouteStick":
            planner = FailAwarePandaStickMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            planner = FailAwarePandaArmMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )

        # Keep references to the unpatched planner methods for the retry wrapper below.
        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            # Retry screw planning up to DATASET_SCREW_MAX_ATTEMPTS times; a raised
            # ScrewPlanFailure or an int return of -1 both count as failure.
            for attempt in range(1, DATASET_SCREW_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except ScrewPlanFailure as exc:
                    print(
                        f"[DatasetGen] screw planning failed "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] screw planning returned -1 "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS})"
                    )
                    continue

                return result

            print(
                "[DatasetGen] screw planning exhausted; "
                f"fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)"
            )

            # Fallback: RRT* planning, also retried; any exception counts as failure.
            for attempt in range(1, DATASET_RRT_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except Exception as exc:
                    print(
                        f"[DatasetGen] RRT* planning failed "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] RRT* planning returned -1 "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS})"
                    )
                    continue

                return result

            # -1 is the sentinel the task loop below treats as planner exhaustion.
            print("[DatasetGen] screw->RRT* planning exhausted; return -1")
            return -1

        # Monkey-patch the planner so every screw move gains the retry/fallback chain.
        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry

        env.unwrapped.evaluate()
        # Get environment task list
        tasks = list(getattr(env.unwrapped, "task_list", []) or [])

        print(f"{env_id}: Task list has {len(tasks)} tasks")

        # 4. Iterate and execute all subtasks
        for idx, task_entry in enumerate(tasks):
            task_name = task_entry.get("name", f"Task {idx}")
            print(f"Executing task {idx + 1}/{len(tasks)}: {task_name}")

            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                raise ValueError(
                    f"Task '{task_name}' must supply a callable 'solve'."
                )

            # Evaluate once before executing solve
            env.unwrapped.evaluate(solve_complete_eval=True)
            screw_failed = False
            try:
                # 5. Call planner to solve current task
                solve_result = solve_callable(env, planner)
                if isinstance(solve_result, int) and solve_result == -1:
                    # Planner exhaustion sentinel: mark the env as failed.
                    screw_failed = True
                    print(f"Screw->RRT* planning exhausted during '{task_name}'")
                    env.unwrapped.failureflag = torch.tensor([True])
                    env.unwrapped.successflag = torch.tensor([False])
                    env.unwrapped.current_task_failure = True
            except ScrewPlanFailure as exc:
                # Planning failure handling
                screw_failed = True
                print(f"Screw plan failure during '{task_name}': {exc}")
                env.unwrapped.failureflag = torch.tensor([True])
                env.unwrapped.successflag = torch.tensor([False])
                env.unwrapped.current_task_failure = True
            except FailsafeTimeout as exc:
                # Timeout handling
                print(f"Failsafe: {exc}")
                break

            # Evaluation after task execution
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)

            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)

            # 6. Check success/failure conditions
            if _tensor_to_bool(success_flag):
                print("All tasks completed successfully.")
                episode_successful = True
                break

            if screw_failed or _tensor_to_bool(fail_flag):
                print("Encountered failure condition; stopping task sequence.")
                break

        else:
            # If loop ends normally (no break), check success again
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))

        # 7. Prioritize wrapper's success signal (double check)
        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )

    except SceneGenerationError as exc:  # Scene generation failure may occur in environments like swingxtimes
        print(
            f"Scene generation failed for env {env_id}, episode {episode}, seed {seed}: {exc}"
        )
        episode_successful = False
    finally:
        if env is not None:
            try:
                env.close()
            except Exception as close_exc:
                # Even if close() fails, return success if episode was successful
                # Because HDF5 data was written before close() (in write() method)
                print(f"Warning: Exception during env.close() for episode {episode}, seed {seed}: {close_exc}")
                # If episode was successful, close() exception should not affect return value
                # episode_successful was determined before close()

    status_text = "SUCCESS" if episode_successful else "FAILED"
    print(
        f"--- Finished Running simulation for episode:{episode}, seed:{seed}, env: {env_id} [{status_text}] ---"
    )

    return episode_successful
424
+
425
+
426
def run_env_dataset(
    env_id: str,
    episode_indices: Iterable[int],
    temp_folder: Path,
    save_video: bool,
    metadata_records: List[Dict[str, Any]],
    gpu_id: int,
) -> List[Dict[str, Any]]:
    """Run dataset generation for a batch of episodes and save data to temporary folder.

    For each episode, the seed/difficulty candidate from the reference
    metadata is tried; on failure the seed is incremented (up to 20 attempts)
    before the episode is given up on.

    Args:
        env_id: Environment ID (must be a member of DEFAULT_ENVS).
        episode_indices: List of episode indices to run.
        temp_folder: Temporary folder to save data (created if missing).
        save_video: Whether to save video.
        metadata_records: Records from reference dataset metadata.
        gpu_id: GPU ID to use (sets CUDA_VISIBLE_DEVICES for this process).

    Returns:
        Metadata records (task/episode/seed/difficulty) for the episodes that
        succeeded; failed episodes are omitted.

    Raises:
        ValueError: env_id is not in DEFAULT_ENVS.
    """
    # Set GPU used by current process.
    # NOTE(review): this mutates the whole process environment — intended to run
    # inside a dedicated worker process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    temp_folder.mkdir(parents=True, exist_ok=True)
    episode_indices = list(episode_indices)
    if not episode_indices:
        return []

    if env_id not in DEFAULT_ENVS:
        raise ValueError(f"Unsupported environment: {env_id}")

    # Pass a temporary h5 path to the wrapper.
    # Note: the wrapper actually creates separate episode files in a subfolder
    # of that path's directory.
    temp_dataset_path = temp_folder / f"temp_chunk.h5"
    episode_records: List[Dict[str, Any]] = []

    for episode in episode_indices:
        candidate_pairs = _build_seed_candidates_from_metadata(episode, metadata_records)
        if not candidate_pairs:
            print(f"Episode {episode}: no metadata candidate seeds found, skipping.")
            continue

        episode_success = False
        # Max attempts per candidate seed; each retry bumps the seed by one.
        MAX_RETRY_ATTEMPTS = 20

        for attempt_idx, candidate in enumerate(candidate_pairs, start=1):
            base_seed = int(candidate["seed"])
            difficulty = str(candidate["difficulty"])

            current_seed = base_seed
            for retry_count in range(MAX_RETRY_ATTEMPTS):
                if retry_count > 0:
                    current_seed += 1

                print(
                    f"Episode {episode} attempt {retry_count + 1}/{MAX_RETRY_ATTEMPTS} "
                    f"with seed={current_seed} (base={base_seed}, diff={difficulty})"
                )

                try:
                    success = _run_episode_attempt(
                        env_id=env_id,
                        episode=episode,
                        seed=current_seed,
                        temp_dataset_path=temp_dataset_path,
                        save_video=save_video,
                        difficulty=difficulty,
                    )

                    if success:
                        # Record successful episode information
                        episode_records.append(
                            {
                                "task": env_id,
                                "episode": episode,
                                "seed": current_seed,
                                "difficulty": difficulty,
                            }
                        )
                        episode_success = True
                        break  # Break retry loop (seed increment loop)

                    print(
                        f"Episode {episode} failed with seed {current_seed}; retrying with seed+1..."
                    )
                except Exception as exc:
                    # Best-effort: any attempt-level exception counts as a failed
                    # attempt and moves on to the next seed.
                    print(
                        f"Episode {episode} exception with seed {current_seed}: {exc}; retrying with seed+1..."
                    )

            if episode_success:
                break  # Break candidate loop

        if not episode_success:
            print(
                f"Episode {episode} failed with strict source metadata seed; "
                "metadata will not be recorded for this episode."
            )

    return episode_records
528
+
529
+
530
def _merge_dataset_from_folder(
    env_id: str,
    temp_folder: Path,
    final_dataset_path: Path,
) -> None:
    """Merge all episode files from the temporary folder into the final dataset.

    Scans "*_hdf5_files" subfolders produced by RobommeRecordWrapper, copies
    each env group / episode into the final HDF5 file (overwriting duplicate
    episodes), moves any recorded videos next to the final dataset, and then
    deletes the temporary folder. All per-file errors are logged and skipped
    (best-effort merge).

    Args:
        env_id: Environment ID. NOTE(review): currently unused in the body;
            kept for interface symmetry with the other per-env helpers.
        temp_folder: Temporary folder containing episode files.
        final_dataset_path: Final output HDF5 file path (opened in append mode).
    """
    if not temp_folder.exists() or not temp_folder.is_dir():
        print(f"Warning: Temporary folder {temp_folder} does not exist")
        return

    final_dataset_path.parent.mkdir(parents=True, exist_ok=True)

    # Find subfolders created by RobommeRecordWrapper
    # It usually creates directories ending with "_hdf5_files"
    hdf5_folders = list(temp_folder.glob("*_hdf5_files"))

    if not hdf5_folders:
        print(f"Warning: No HDF5 folders found in {temp_folder}")
        return

    print(f"Merging episodes from {temp_folder} into {final_dataset_path}")

    # Open final HDF5 file for append mode writing
    with h5py.File(final_dataset_path, "a") as final_file:
        for hdf5_folder in sorted(hdf5_folders):
            # Get all h5 files in folder
            h5_files = sorted(hdf5_folder.glob("*.h5"))

            if not h5_files:
                print(f"Warning: No h5 files found in {hdf5_folder}")
                continue

            print(f"Found {len(h5_files)} episode files in {hdf5_folder.name}")

            # Merge each episode file
            for h5_file in h5_files:
                print(f" - Merging {h5_file.name}")

                try:
                    with h5py.File(h5_file, "r") as episode_file:
                        file_keys = list(episode_file.keys())
                        if len(file_keys) == 0:
                            print(f" Warning: {h5_file.name} is empty, skipping...")
                            continue

                        for env_group_name, src_env_group in episode_file.items():
                            episode_keys = list(src_env_group.keys()) if isinstance(src_env_group, h5py.Group) else []
                            if len(episode_keys) == 0:
                                print(f" Warning: {env_group_name} in {h5_file.name} has no episodes, skipping...")
                                continue

                            # If environment group (e.g. 'PickXtimes') does not exist, copy directly
                            if env_group_name not in final_file:
                                final_file.copy(src_env_group, env_group_name)
                                continue

                            dest_env_group = final_file[env_group_name]
                            if not isinstance(dest_env_group, h5py.Group):
                                print(f" Warning: {env_group_name} is not a group, skipping...")
                                continue

                            # If environment group exists, copy episodes one by one
                            for episode_name in src_env_group.keys():
                                if episode_name in dest_env_group:
                                    # Last writer wins on duplicate episode names.
                                    print(f" Warning: Episode {episode_name} already exists, overwriting...")
                                    del dest_env_group[episode_name]
                                src_env_group.copy(episode_name, dest_env_group, name=episode_name)
                except Exception as e:
                    # Best-effort merge: a corrupt episode file must not abort the rest.
                    print(f" Error merging {h5_file.name}: {e}")
                    continue

    # Keep videos: wrapper writes videos to 'videos' under temp dir, move to final dir before cleanup
    temp_videos_dir = temp_folder / "videos"
    final_videos_dir = final_dataset_path.parent / "videos"
    if temp_videos_dir.exists() and temp_videos_dir.is_dir():
        final_videos_dir.mkdir(parents=True, exist_ok=True)
        moved_count = 0
        for video_path in sorted(temp_videos_dir.glob("*.mp4")):
            target_path = final_videos_dir / video_path.name
            if target_path.exists():
                # Name collision: pick the first free "<stem>_dupN<suffix>" path.
                stem = target_path.stem
                suffix = target_path.suffix
                index = 1
                while True:
                    candidate = final_videos_dir / f"{stem}_dup{index}{suffix}"
                    if not candidate.exists():
                        target_path = candidate
                        break
                    index += 1
            try:
                shutil.move(str(video_path), str(target_path))
                moved_count += 1
            except Exception as exc:
                print(f"Warning: Failed to move video {video_path.name}: {exc}")
        if moved_count > 0:
            print(f"Moved {moved_count} videos to {final_videos_dir}")

    # Clean up temporary folder after successful merge
    try:
        shutil.rmtree(temp_folder)
        print(f"Cleaned up temporary folder: {temp_folder}")
    except Exception as e:
        print(f"Warning: Failed to remove temporary folder {temp_folder}: {e}")
640
+
641
+
642
+ def _save_episode_metadata(
643
+ records: List[Dict[str, Any]],
644
+ metadata_path: Path,
645
+ env_id: str,
646
+ ) -> None:
647
+ """Save seed/difficulty metadata for each episode to JSON file."""
648
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
649
+ sorted_records = sorted(records, key=lambda rec: rec.get("episode", -1))
650
+ metadata = {
651
+ "env_id": env_id,
652
+ "record_count": len(sorted_records),
653
+ "records": sorted_records,
654
+ }
655
+ try:
656
+ with metadata_path.open("w", encoding="utf-8") as metadata_file:
657
+ json.dump(metadata, metadata_file, indent=2)
658
+ print(f"Saved episode metadata to {metadata_path}")
659
+ except Exception as exc:
660
+ print(f"Warning: Failed to save episode metadata to {metadata_path}: {exc}")
661
+
662
+
663
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the dataset generator.

    Returns:
        Namespace with fields: ``env`` (list[str] | None), ``episodes`` (int),
        ``save_video`` (bool, default True), ``max_workers`` (int), and
        ``gpus`` (str GPU spec, parsed later by ``_parse_gpu_ids``).
    """
    parser = argparse.ArgumentParser(description="Robomme Dataset Generator")
    parser.add_argument(
        "--env",
        "-e",
        type=str,
        nargs="+",
        default=None,
        help="Environment IDs to run. Provide one or more values; defaults to all built-in Robomme environments.",
    )
    parser.add_argument(
        "--episodes",
        "-n",
        type=int,
        default=100,
        help="Number of episodes generated per environment (Default: 100)",
    )
    # --save-video / --no-save-video form an on/off pair sharing dest=save_video.
    parser.add_argument(
        "--save-video",
        dest="save_video",
        action="store_true",
        default=True,
        help="Enable video recording via RobommeRecordWrapper (Default: Enabled).",
    )
    parser.add_argument(
        "--no-save-video",
        dest="save_video",
        action="store_false",
        help="Disable video recording.",
    )
    parser.add_argument(
        "--max-workers",
        "-w",
        type=int,
        default=20,
        help="Number of parallel workers when running multiple environments.",
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="1",
        # Fixed: the help text previously claimed the default was '0',
        # contradicting the actual default of '1'.
        help="GPU selection. Supported values: '0', '1', '0,1' (or '1,0'). Default: '1'.",
    )
    return parser.parse_args()
707
+
708
+
709
+ def _parse_gpu_ids(gpu_spec: str) -> List[int]:
710
+ """Parse user GPU spec string to a deduplicated GPU id list."""
711
+ valid_gpu_ids = {0, 1}
712
+ raw_tokens = [token.strip() for token in gpu_spec.split(",") if token.strip()]
713
+ if not raw_tokens:
714
+ raise ValueError("GPU spec is empty. Use one of: 0, 1, 0,1")
715
+
716
+ gpu_ids: List[int] = []
717
+ for token in raw_tokens:
718
+ try:
719
+ gpu_id = int(token)
720
+ except ValueError as exc:
721
+ raise ValueError(
722
+ f"Invalid GPU id '{token}'. Supported values are 0 and 1."
723
+ ) from exc
724
+
725
+ if gpu_id not in valid_gpu_ids:
726
+ raise ValueError(
727
+ f"Unsupported GPU id '{gpu_id}'. Supported values are 0 and 1."
728
+ )
729
+ if gpu_id not in gpu_ids:
730
+ gpu_ids.append(gpu_id)
731
+
732
+ if not gpu_ids:
733
+ raise ValueError("No valid GPU id provided. Use one of: 0, 1, 0,1")
734
+ return gpu_ids
735
+
736
+
737
def main() -> None:
    """Entry point: generate episode datasets for each requested environment.

    For every environment id the pipeline:
      1. loads per-episode source metadata,
      2. runs episodes (single worker, or chunked across a process pool with
         round-robin GPU assignment),
      3. merges the per-worker HDF5 outputs into one final dataset file,
      4. writes a JSON metadata sidecar next to the final dataset.
    """
    args = parse_args()
    env_inputs = args.env or DEFAULT_ENVS
    env_ids: List[str] = []
    # Parse environment list arguments; each CLI token may itself be a
    # comma-separated list, so split and strip every token.
    for raw_env in env_inputs:
        env_ids.extend(env.strip() for env in raw_env.split(",") if env.strip())

    if not env_ids:
        env_ids = DEFAULT_ENVS.copy()

    num_workers = max(1, args.max_workers)
    gpu_spec = args.gpus
    gpu_ids = _parse_gpu_ids(gpu_spec)
    episode_indices = list(range(args.episodes))

    for env_id in env_ids:
        source_metadata_records = _load_env_metadata_records(
            env_id=env_id,
            metadata_root=SOURCE_METADATA_ROOT,
        )

        # Shared temporary folder where all workers drop per-episode files;
        # NOTE(review): output paths are hard-coded to a machine-specific
        # directory — confirm before running elsewhere.
        temp_folder = Path(f"/data/hongzefu/data_0226/temp_{env_id}_episodes")
        final_dataset_path = Path(f"/data/hongzefu/data_0226/record_dataset_{env_id}.h5")
        #final_dataset_path = Path(f"/data/hongzefu/dataset_generate/record_dataset_{env_id}.h5")

        # Run banner.
        print(f"\n{'='*80}")
        print(f"Environment: {env_id}")
        print(f"Episodes: {args.episodes}")
        print(f"Workers: {num_workers}")
        if len(gpu_ids) == 1:
            print(f"GPU mode: Single GPU ({gpu_ids[0]})")
        else:
            print(f"GPU mode: Multi GPU ({','.join(str(gpu) for gpu in gpu_ids)})")
        print(f"Temporary folder: {temp_folder}")
        print(f"Final dataset: {final_dataset_path}")
        print(f"{'='*80}\n")

        episode_records: List[Dict[str, Any]] = []

        if num_workers > 1:
            # 1. Split episodes into one chunk per worker.
            episode_chunks = _split_episode_indices(args.episodes, num_workers)

            if len(episode_chunks) <= 1:
                # Only one chunk produced: run it in-process on the first GPU.
                chunk = episode_chunks[0] if episode_chunks else []
                episode_records = run_env_dataset(
                    env_id,
                    chunk,
                    temp_folder,
                    args.save_video,
                    source_metadata_records,
                    gpu_ids[0],
                )
            else:
                worker_count = len(episode_chunks)
                print(
                    f"Running {env_id} with {worker_count} workers across {args.episodes} episodes..."
                )

                future_to_chunk = {}
                futures = []
                if len(gpu_ids) == 1:
                    print(
                        f"Assigning all {len(episode_chunks)} chunks to GPU {gpu_ids[0]} ({num_workers} workers)"
                    )
                else:
                    print(
                        f"Assigning {len(episode_chunks)} chunks across GPUs {','.join(str(gpu) for gpu in gpu_ids)}"
                    )

                # 2. Fan chunks out to worker processes, assigning GPUs
                # round-robin, then collect results as they complete.
                with ProcessPoolExecutor(max_workers=num_workers) as executor:
                    for chunk_idx, chunk in enumerate(episode_chunks):
                        assigned_gpu = gpu_ids[chunk_idx % len(gpu_ids)]
                        f = executor.submit(
                            run_env_dataset,
                            env_id,
                            chunk,
                            temp_folder,
                            args.save_video,
                            source_metadata_records,
                            assigned_gpu,
                        )
                        future_to_chunk[f] = (chunk, assigned_gpu)
                        futures.append(f)

                    for future in as_completed(futures):
                        chunk, assigned_gpu = future_to_chunk[future]
                        chunk_label = (chunk[0], chunk[-1]) if chunk else ("?", "?")
                        try:
                            records = future.result()
                            episode_records.extend(records)
                            print(
                                f"✓ Completed episodes {chunk_label[0]}-{chunk_label[1]} for {env_id} on GPU {assigned_gpu}"
                            )
                        except Exception as exc:
                            # A failed chunk is reported but does not stop the
                            # remaining chunks or the merge step.
                            print(
                                f"✗ Environment {env_id} failed on episodes "
                                f"{chunk_label[0]}-{chunk_label[1]} (GPU {assigned_gpu}) with error: {exc}"
                            )

            # 3. Merge all episode files into final dataset
            print(f"\nMerging all episodes into final dataset...")
            _merge_dataset_from_folder(
                env_id,
                temp_folder,
                final_dataset_path,
            )
        else:
            # Single worker mode: run every episode sequentially in-process.
            episode_records = run_env_dataset(
                env_id,
                episode_indices,
                temp_folder,
                args.save_video,
                source_metadata_records,
                gpu_ids[0],  # gpu_id
            )

            # Merge episodes into final dataset
            print(f"\nMerging all episodes into final dataset...")
            _merge_dataset_from_folder(
                env_id,
                temp_folder,
                final_dataset_path,
            )

        # 4. Save per-episode metadata next to the dataset file.
        metadata_path = final_dataset_path.with_name(
            f"{final_dataset_path.stem}_metadata.json"
        )
        _save_episode_metadata(episode_records, metadata_path, env_id)

        print(f"\n✓ Finished! Final dataset saved to: {final_dataset_path}\n")

    print("✓ All requested environments processed.")
875
+
876
+
877
+ if __name__ == "__main__":
878
+ main()