File size: 13,967 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
# -*- coding: utf-8 -*-
"""
test_step_error_handling.py
============================
Lightweight test: verify that DemonstrationWrapper.step() reports internal exceptions
as a structured error (info["status"] = "error" plus info["error_message"]) instead of
propagating them to the caller.

Also verify that the step loops in run_example.py and dataset_replay.py
check info["status"] instead of wrapping env.step() in a bare try/except.

Run from the repository root (uv is required):
    cd /data/hongzefu/robomme_benchmark   # replace with your own checkout path
    uv run python tests/lightweight/test_step_error_handling.py
"""

from __future__ import annotations

import ast
import sys
import types
import unittest.mock as mock
from pathlib import Path

# ---------------------------------------------------------------------------
# Repo path helpers
# ---------------------------------------------------------------------------
# Put the directory two levels above this file on sys.path so that the shared
# `tests._shared` package is importable when this file is run directly as a
# script (assumes the file lives at tests/lightweight/<name>.py, as the module
# docstring's run command suggests — TODO confirm if the file is moved).
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from tests._shared.repo_paths import find_repo_root, ensure_src_on_path

# Repository root; every path helper below is resolved relative to it.
_PROJECT_ROOT = find_repo_root(__file__)
# NOTE(review): presumably adds the repo's `src/` directory to sys.path — see
# tests._shared.repo_paths to confirm.
ensure_src_on_path(__file__)


# ---------------------------------------------------------------------------
# Helpers: resolve repository paths and load source files for AST inspection
# ---------------------------------------------------------------------------

def _demo_wrapper_path() -> Path:
    """Return the path of ``DemonstrationWrapper.py`` under the repository's ``src`` tree."""
    relative_parts = ("src", "robomme", "env_record_wrapper", "DemonstrationWrapper.py")
    return _PROJECT_ROOT.joinpath(*relative_parts)


def _load_step_source() -> str:
    """Return the full text of ``DemonstrationWrapper.py``, decoded as UTF-8."""
    wrapper_file = _demo_wrapper_path()
    text = wrapper_file.read_text(encoding="utf-8")
    return text


def _script_path(name: str) -> Path:
    """Return the path of the file *name* inside the repository's ``scripts`` directory."""
    scripts_dir = _PROJECT_ROOT / "scripts"
    return scripts_dir / name


# ---------------------------------------------------------------------------
# Test 1: DemonstrationWrapper.step() catches exceptions and returns status="error"
# ---------------------------------------------------------------------------

def _find_method(tree: ast.AST, class_name: str, method_name: str) -> "ast.FunctionDef | ast.AsyncFunctionDef | None":
    """Return method *method_name* defined directly in class *class_name*, or None.

    Classes are visited in ``ast.walk`` order and only the first class with a
    matching name is inspected.
    """
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return next(
                (
                    item
                    for item in node.body
                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and item.name == method_name
                ),
                None,
            )
    return None


def test_step_error_returns_status_error() -> None:
    """
    Statically inspect ``DemonstrationWrapper.step()`` through its AST.

    The module is parsed, never imported or executed. The test checks that:
    1. class ``DemonstrationWrapper`` defines a ``step`` method;
    2. ``step`` contains at least one ``try`` statement;
    3. the ``except`` handlers of those ``try`` statements reference the string
       constants ``"error"`` and ``"error_message"``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    source = _load_step_source()
    tree = ast.parse(source, filename=str(_demo_wrapper_path()))

    step_method = _find_method(tree, "DemonstrationWrapper", "step")
    assert step_method is not None, "DemonstrationWrapper.step method not found"

    try_nodes = [n for n in ast.walk(step_method) if isinstance(n, ast.Try)]
    assert try_nodes, "DemonstrationWrapper.step() should contain a try/except block"

    # Gather every string constant found anywhere inside an except handler in a
    # single pass, rather than re-walking the tree once per expected key.
    handler_strings = {
        n.value
        for try_node in try_nodes
        for handler in try_node.handlers
        for n in ast.walk(handler)
        if isinstance(n, ast.Constant) and isinstance(n.value, str)
    }

    assert "error" in handler_strings, "There should be a string constant status='error' in the except block"
    assert "error_message" in handler_strings, "There should be an 'error_message' key in the except block"

    print("  ✓ DemonstrationWrapper.step() contains try/except and except returns status='error' + error_message")


# ---------------------------------------------------------------------------
# Tests 2-3: Runtime behavior verification against local stubs that mirror step()
# ---------------------------------------------------------------------------

def test_step_error_runtime_behavior() -> None:
    """
    Verify the error contract of ``step()`` at runtime, using a local stub.

    ``FakeDemoWrapper`` below re-implements the try/except shape of
    ``DemonstrationWrapper.step()`` with a ``_step_batch`` that always raises
    ``RuntimeError``. The real class is not imported or called. The test
    checks that ``step()`` returns ``({}, 0.0, True, False, info)``, where
    ``info["status"] == "error"`` and ``info["error_message"]`` contains both
    the exception type and its message.

    The placeholder modules and the extra ``sys.path`` entry exist only while
    the stub runs and are restored afterwards. This keeps empty stub modules
    from leaking into later imports in the same interpreter.

    Raises:
        AssertionError: if the returned tuple violates the contract.
    """
    # Directory holding DemonstrationWrapper, so that its sibling imports
    # (e.g. ``from episode... import``) would resolve if it were imported.
    wrapper_dir = str(_PROJECT_ROOT / "src" / "robomme" / "env_record_wrapper")
    scoped_path = sys.path if wrapper_dir in sys.path else [wrapper_dir, *sys.path]

    # patch.dict with no values snapshots sys.modules and restores it on exit,
    # undoing every placeholder _inject_mock_dependencies() registers.
    with mock.patch.dict(sys.modules), mock.patch.object(sys, "path", scoped_path):
        # Inject placeholder dependencies; ManiSkill is never really imported.
        _inject_mock_dependencies()

        class FakeDemoWrapper:
            """Minimal stub, only implements the logic used in step()."""

            @staticmethod
            def _step_batch(action):
                raise RuntimeError("IK failed: no solution found")

            @staticmethod
            def _flatten_info_batch(info_batch):
                return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

            def step(self, action):
                try:
                    batch = self._step_batch(action)
                    obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
                    info_flat = self._flatten_info_batch(info_batch)
                    return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat)
                except Exception as exc:
                    error_info = {
                        "status": "error",
                        "error_message": f"{type(exc).__name__}: {exc}",
                    }
                    return ({}, 0.0, True, False, error_info)

        wrapper = FakeDemoWrapper()
        obs, reward, terminated, truncated, info = wrapper.step(action=[0.0] * 8)

    assert obs == {}, f"obs should be an empty dict on error, got {obs!r}"
    assert reward == 0.0, f"reward should be 0.0 on error, got {reward!r}"
    assert terminated is True, f"terminated should be True on error, got {terminated!r}"
    assert truncated is False, f"truncated should be False on error, got {truncated!r}"
    assert info.get("status") == "error", f"status should be 'error', got {info.get('status')!r}"
    assert "RuntimeError" in info.get("error_message", ""), (
        f"error_message should contain exception type, got {info.get('error_message')!r}"
    )
    assert "IK failed" in info.get("error_message", ""), (
        f"error_message should contain original exception info, got {info.get('error_message')!r}"
    )

    print("  ✓ step() returns status='error' + correct error_message when throwing an exception")


def test_step_normal_returns_ongoing_status() -> None:
    """
    Happy-path counterpart of the runtime error test.

    A local stub whose ``_step_batch`` succeeds is stepped once. The stub
    returns torch tensors for reward/terminated/truncated and an info batch
    whose ``status`` list is ``["ongoing"]``. The flattened info must keep
    ``status == "ongoing"`` and must not carry an ``error_message`` key.
    """
    import torch

    class _HappyPathStub:
        """Normal _step_batch, info["status"] = "ongoing"."""

        def _step_batch(self, action):
            observations = {"front_rgb_list": [None]}
            rewards = torch.tensor([0.1])
            terminations = torch.tensor([False])
            truncations = torch.tensor([False])
            infos = {"status": ["ongoing"], "success": [False]}
            return observations, rewards, terminations, truncations, infos

        def _flatten_info_batch(self, info_batch):
            flat = {}
            for key, value in info_batch.items():
                flat[key] = value[-1] if isinstance(value, list) and value else value
            return flat

        def step(self, action):
            try:
                obs_b, rew_b, term_b, trunc_b, info_b = self._step_batch(action)
                flat_info = self._flatten_info_batch(info_b)
                return obs_b, rew_b[-1], term_b[-1], trunc_b[-1], flat_info
            except Exception as exc:
                return ({}, 0.0, True, False, {
                    "status": "error",
                    "error_message": f"{type(exc).__name__}: {exc}",
                })

    *_, info = _HappyPathStub().step(action=[0.0] * 8)

    status = info.get("status")
    assert status == "ongoing", (
        f"Normal step status should be 'ongoing', got {status!r}"
    )
    assert "error_message" not in info, "Normal step should not contain error_message"

    print("  ✓ Normal step returns status='ongoing', no error_message")


# ---------------------------------------------------------------------------
# Test 4: AST check that scripts no longer have bare try/except wrapping env.step(action)
# ---------------------------------------------------------------------------

def test_scripts_use_status_check_not_bare_try_except() -> None:
    """
    Parse run_example.py and dataset_replay.py, verify:
    1. There is an info.get("status") or status == "error" check in the script
    2. env.step(action) calls are no longer directly wrapped by try/except Exception
    """
    scripts = ["run_example.py", "dataset_replay.py"]

    for script_name in scripts:
        script_path = _script_path(script_name)
        source = script_path.read_text(encoding="utf-8")
        tree = ast.parse(source, filename=str(script_path))

        # ---- Check 1: Contains status related checks ----
        has_status_check = (
            'info.get("status")' in source
            or "status == \"error\"" in source
            or "status==" in source.replace(" ", "")
        )
        assert has_status_check, (
            f"{script_name}: Should have info.get('status') or status==\"error\" check"
        )

        # ---- Check 2: env.step is not wrapped by bare try/except ----
        # Exact lookup: env.step is directly called in the Try block and handler catches Exception
        _assert_no_bare_step_try_except(tree, script_name)

        print(f"  ✓ {script_name}: Use status check, no bare try/except wrapping env.step")


def _assert_no_bare_step_try_except(tree: ast.AST, script_name: str) -> None:
    """Check that there is no 'try block contains env.step and except catches Exception' structure in the AST."""
    for node in ast.walk(tree):
        if not isinstance(node, ast.Try):
            continue
        # Check if there is an env.step(action) call in the try body
        step_in_try = False
        for n in ast.walk(ast.Module(body=node.body, type_ignores=[])):
            if (
                isinstance(n, ast.Call)
                and isinstance(getattr(n, "func", None), ast.Attribute)
                and n.func.attr == "step"
            ):
                step_in_try = True
                break

        if not step_in_try:
            continue

        # Check if handler is a bare Exception catch
        for handler in node.handlers:
            if handler.type is None:
                assert False, (
                    f"{script_name}: env.step is still wrapped by bare try/except: (no exception type), should be changed to status check"
                )
            if isinstance(handler.type, ast.Name) and handler.type.id == "Exception":
                assert False, (
                    f"{script_name}: env.step is still wrapped by bare try/except Exception: should be changed to status check"
                )


# ---------------------------------------------------------------------------
# Utility: inject mock modules so imports inside DemonstrationWrapper don't fail
# ---------------------------------------------------------------------------

def _inject_mock_dependencies() -> None:
    """Inject placeholder mock modules to prevent import DemonstrationWrapper from failing due to missing ManiSkill."""
    mock_mods = [
        "mani_skill",
        "mani_skill.envs",
        "mani_skill.envs.sapien_env",
        "mani_skill.utils",
        "mani_skill.utils.common",
        "mani_skill.utils.gym_utils",
        "mani_skill.utils.sapien_utils",
        "mani_skill.utils.io_utils",
        "mani_skill.utils.logging_utils",
        "mani_skill.utils.structs",
        "mani_skill.utils.structs.types",
        "mani_skill.utils.wrappers",
        "mani_skill.examples",
        "mani_skill.examples.motionplanning",
        "mani_skill.examples.motionplanning.panda",
        "mani_skill.examples.motionplanning.panda.motionplanner",
        "mani_skill.examples.motionplanning.panda.motionplanner_stick",
        "mani_skill.examples.motionplanning.base_motionplanner",
        "mani_skill.examples.motionplanning.base_motionplanner.utils",
        "sapien",
        "sapien.physx",
        "gymnasium",
        "h5py",
        "imageio",
        "colorsys",
    ]
    for mod_name in mock_mods:
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Run every check in order and print progress; the first failing assertion aborts the run."""
    print("\n[TEST] DemonstrationWrapper step error handling")

    wrapper_checks = (
        (test_step_error_returns_status_error, "  test1: AST structure verification passed"),
        (test_step_error_runtime_behavior, "  test2: Runtime behavior verification passed"),
        (test_step_normal_returns_ongoing_status, "  test3: Normal path verification passed"),
    )
    for check, banner in wrapper_checks:
        check()
        print(banner)

    print("\n[TEST] Script status check verification")
    test_scripts_use_status_check_not_bare_try_except()

    print("\nPASS: All step error handling tests passed")


# Entry point when the file is executed directly, e.g. with the
# `uv run python ...` command shown in the module docstring.
if __name__ == "__main__":
    main()