| | |
| | """ |
| | test_step_error_handling.py |
| | ============================ |
| | Lightweight test: Verify that DemonstrationWrapper.step() returns a structured error |
| | via info["status"] = "error" upon internal exceptions, instead of propagating upwards. |
| | |
| | Also verify that the step loops in run_example.py and dataset_replay.py |
| | have been changed to check info["status"] rather than bare try/except. |
| | |
| | Run (must use uv): |
| | cd /data/hongzefu/robomme_benchmark |
| | uv run python tests/lightweight/test_step_error_handling.py |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import ast |
| | import sys |
| | import types |
| | import unittest.mock as mock |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
# Put parents[2] of this file (the directory two levels above the folder that
# contains it) at the front of sys.path, so that the `tests` package imported
# on the next line resolves when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from tests._shared.repo_paths import find_repo_root, ensure_src_on_path

# Root directory returned by the shared helper; the path helpers below build on it.
_PROJECT_ROOT = find_repo_root(__file__)
# NOTE(review): presumably makes the project's src/ directory importable —
# confirm against tests/_shared/repo_paths.py.
ensure_src_on_path(__file__)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _demo_wrapper_path() -> Path:
    """Return the location of ``DemonstrationWrapper.py`` below the project root."""
    relative_location = "src/robomme/env_record_wrapper/DemonstrationWrapper.py"
    return _PROJECT_ROOT / relative_location
| |
|
| |
|
def _load_step_source() -> str:
    """Read ``DemonstrationWrapper.py`` as UTF-8 text.

    Despite the name, the text of the whole module is returned, not only the
    source of its ``step`` method.
    """
    wrapper_file = _demo_wrapper_path()
    return wrapper_file.read_text(encoding="utf-8")
| |
|
| |
|
def _script_path(name: str) -> Path:
    """Return the path of the file called *name* in the project's ``scripts`` directory."""
    scripts_dir = _PROJECT_ROOT / "scripts"
    return scripts_dir / name
| |
|
| |
|
| | |
| | |
| | |
| |
|
def test_step_error_returns_status_error() -> None:
    """Statically check that ``DemonstrationWrapper.step()`` reports errors via ``info``.

    The wrapper module is parsed with :mod:`ast`; it is never imported or
    executed, so no simulator dependency is required.  The check passes when
    the ``step`` method defined directly in the ``DemonstrationWrapper`` class

    1. contains at least one ``try`` statement, and
    2. has, somewhere inside the ``except`` handlers of those statements, the
       string constants ``"error"`` and ``"error_message"``.

    This is a structural heuristic: it confirms that the constants appear in
    the handlers, not that they end up in the returned ``info`` dict.

    Raises:
        AssertionError: if the method cannot be found or a structural check fails.
    """
    source = _load_step_source()
    tree = ast.parse(source, filename=str(_demo_wrapper_path()))

    # Look for ``step`` among the direct members of the class.
    step_method = next(
        (
            member
            for node in ast.walk(tree)
            if isinstance(node, ast.ClassDef) and node.name == "DemonstrationWrapper"
            for member in node.body
            if isinstance(member, ast.FunctionDef) and member.name == "step"
        ),
        None,
    )
    assert step_method is not None, "DemonstrationWrapper.step method not found"

    try_nodes = [n for n in ast.walk(step_method) if isinstance(n, ast.Try)]
    assert try_nodes, "DemonstrationWrapper.step() should contain a try/except block"

    # Gather every string constant used inside any except handler in a single
    # pass; both contract checks below read from this set.
    handler_strings = {
        n.value
        for try_node in try_nodes
        for handler in try_node.handlers
        for n in ast.walk(handler)
        if isinstance(n, ast.Constant) and isinstance(n.value, str)
    }
    assert "error" in handler_strings, (
        "There should be a string constant status='error' in the except block"
    )
    assert "error_message" in handler_strings, (
        "There should be an 'error_message' key in the except block"
    )

    print(" ✓ DemonstrationWrapper.step() contains try/except and except returns status='error' + error_message")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def test_step_error_runtime_behavior() -> None:
    """Exercise the error-return contract on a local stand-in for the wrapper.

    The real ``DemonstrationWrapper`` is not imported here.  ``FakeDemoWrapper``
    is a local copy of the try/except pattern that ``step()`` is expected to
    follow, with ``_step_batch`` rigged to raise.  Because nothing from the
    project is imported, the test needs no placeholder modules and leaves
    ``sys.modules`` and ``sys.path`` untouched.

    Contract checked on the error path:
        ``obs == {}``, ``reward == 0.0``, ``terminated is True``,
        ``truncated is False``, ``info["status"] == "error"`` and
        ``info["error_message"]`` contains the exception type and its text.

    Raises:
        AssertionError: if any element of the returned 5-tuple violates the contract.
    """

    class FakeDemoWrapper:
        """Minimal stub, only implements the logic used in step()."""

        @staticmethod
        def _step_batch(action):
            # Simulate an internal failure inside the batched step.
            raise RuntimeError("IK failed: no solution found")

        @staticmethod
        def _flatten_info_batch(info_batch):
            # Keep the last element of non-empty list values; pass the rest through.
            return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

        def step(self, action):
            try:
                batch = self._step_batch(action)
                obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
                info_flat = self._flatten_info_batch(info_batch)
                return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat)
            except Exception as exc:
                # Convert the failure into a structured result instead of re-raising.
                error_info = {
                    "status": "error",
                    "error_message": f"{type(exc).__name__}: {exc}",
                }
                return ({}, 0.0, True, False, error_info)

    wrapper = FakeDemoWrapper()
    obs, reward, terminated, truncated, info = wrapper.step(action=[0.0] * 8)

    assert obs == {}, f"obs should be an empty dict on error, got {obs!r}"
    assert reward == 0.0, f"reward should be 0.0 on error, got {reward!r}"
    assert terminated is True, f"terminated should be True on error, got {terminated!r}"
    assert truncated is False, f"truncated should be False on error, got {truncated!r}"
    assert info.get("status") == "error", f"status should be 'error', got {info.get('status')!r}"
    assert "RuntimeError" in info.get("error_message", ""), (
        f"error_message should contain exception type, got {info.get('error_message')!r}"
    )
    assert "IK failed" in info.get("error_message", ""), (
        f"error_message should contain original exception info, got {info.get('error_message')!r}"
    )

    print(" ✓ step() returns status='error' + correct error_message when throwing an exception")
| |
|
| |
|
def test_step_normal_returns_ongoing_status() -> None:
    """Check the success path of the ``step()`` pattern on a local stand-in.

    ``FakeDemoWrapperNormal._step_batch`` returns a one-element batch whose
    ``info`` carries ``status == "ongoing"``.  The test verifies that the
    try/except in ``step()`` does not rewrite a normal result as an error and,
    mirroring the error-path test, that each element of the returned 5-tuple
    is the last entry of the corresponding batch.

    This is an indirect test: the real ``DemonstrationWrapper`` is not imported.

    Raises:
        AssertionError: if the flattened result differs from the batch's last entry.
    """
    import torch

    class FakeDemoWrapperNormal:
        """Normal _step_batch, info["status"] = "ongoing"."""

        def _step_batch(self, action):
            obs_batch = {"front_rgb_list": [None]}
            reward_batch = torch.tensor([0.1])
            terminated_batch = torch.tensor([False])
            truncated_batch = torch.tensor([False])
            info_batch = {"status": ["ongoing"], "success": [False]}
            return (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)

        def _flatten_info_batch(self, info_batch):
            # Keep the last element of non-empty list values; pass the rest through.
            return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

        def step(self, action):
            try:
                batch = self._step_batch(action)
                obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
                info_flat = self._flatten_info_batch(info_batch)
                return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat)
            except Exception as exc:
                error_info = {
                    "status": "error",
                    "error_message": f"{type(exc).__name__}: {exc}",
                }
                return ({}, 0.0, True, False, error_info)

    wrapper = FakeDemoWrapperNormal()
    obs, reward, terminated, truncated, info = wrapper.step(action=[0.0] * 8)

    assert info.get("status") == "ongoing", (
        f"Normal step status should be 'ongoing', got {info.get('status')!r}"
    )
    assert "error_message" not in info, "Normal step should not contain error_message"

    # Same coverage as the error-path test: the remaining tuple elements must
    # come from the batch, not from the error fallback.
    assert obs == {"front_rgb_list": [None]}, (
        f"obs should be passed through unchanged, got {obs!r}"
    )
    # The reward is a float32 tensor element, so compare with a tolerance.
    assert abs(float(reward) - 0.1) < 1e-6, f"reward should be 0.1, got {reward!r}"
    assert not bool(terminated), f"terminated should be False, got {terminated!r}"
    assert not bool(truncated), f"truncated should be False, got {truncated!r}"
    assert info.get("success") is False, f"success should be False, got {info.get('success')!r}"

    print(" ✓ Normal step returns status='ongoing', no error_message")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def test_scripts_use_status_check_not_bare_try_except() -> None: |
| | """ |
| | Parse run_example.py and dataset_replay.py, verify: |
| | 1. There is an info.get("status") or status == "error" check in the script |
| | 2. env.step(action) calls are no longer directly wrapped by try/except Exception |
| | """ |
| | scripts = ["run_example.py", "dataset_replay.py"] |
| |
|
| | for script_name in scripts: |
| | script_path = _script_path(script_name) |
| | source = script_path.read_text(encoding="utf-8") |
| | tree = ast.parse(source, filename=str(script_path)) |
| |
|
| | |
| | has_status_check = ( |
| | 'info.get("status")' in source |
| | or "status == \"error\"" in source |
| | or "status==" in source.replace(" ", "") |
| | ) |
| | assert has_status_check, ( |
| | f"{script_name}: Should have info.get('status') or status==\"error\" check" |
| | ) |
| |
|
| | |
| | |
| | _assert_no_bare_step_try_except(tree, script_name) |
| |
|
| | print(f" ✓ {script_name}: Use status check, no bare try/except wrapping env.step") |
| |
|
| |
|
def _assert_no_bare_step_try_except(tree: ast.AST, script_name: str) -> None:
    """Fail if a ``.step(...)`` call sits in a ``try`` body guarded by a catch-all handler.

    Every ``try`` statement in *tree* whose body (searched recursively, but not
    its ``except``/``else``/``finally`` clauses) contains a method call named
    ``step`` is inspected.  A handler counts as catch-all when it is a bare
    ``except:`` or names ``Exception`` / ``BaseException``, either alone or as
    a member of a tuple.

    The receiver of ``.step`` is not examined, so any ``<expr>.step(...)`` call
    is treated as an environment step.

    Args:
        tree: Parsed AST of the script to inspect.
        script_name: Script name used as the prefix of the failure message.

    Raises:
        AssertionError: on the first offending handler.  It is raised
            explicitly instead of via ``assert False`` so the check still
            runs when Python is started with ``-O``.
    """
    broad_names = ("Exception", "BaseException")

    for node in ast.walk(tree):
        if not isinstance(node, ast.Try):
            continue

        calls_step = any(
            isinstance(inner, ast.Call)
            and isinstance(inner.func, ast.Attribute)
            and inner.func.attr == "step"
            for stmt in node.body
            for inner in ast.walk(stmt)
        )
        if not calls_step:
            continue

        for handler in node.handlers:
            if handler.type is None:
                raise AssertionError(
                    f"{script_name}: env.step is still wrapped by bare try/except: (no exception type), should be changed to status check"
                )
            # ``except (A, B):`` lists several types; check each of them.
            if isinstance(handler.type, ast.Tuple):
                caught = handler.type.elts
            else:
                caught = [handler.type]
            for exc_type in caught:
                if isinstance(exc_type, ast.Name) and exc_type.id in broad_names:
                    raise AssertionError(
                        f"{script_name}: env.step is still wrapped by bare try/except {exc_type.id}: should be changed to status check"
                    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _inject_mock_dependencies() -> None:
    """Register empty placeholder modules for simulator dependencies that are not loaded.

    The stated purpose is to keep an import of ``DemonstrationWrapper`` from
    failing when ManiSkill and related packages are missing.  For every name
    in the list that is not already in ``sys.modules`` an empty
    :class:`types.ModuleType` is registered; modules that are already loaded
    are left alone.

    Each stub created by this call is also set as an attribute on its parent
    stub (when the parent was created by the same call), so dotted access such
    as ``mani_skill.utils.common`` works after ``import mani_skill``.

    Only third-party names are stubbed.  The standard-library module
    ``colorsys`` is deliberately absent: it is always importable, and an empty
    stand-in would hide its real functions from later importers.

    Side effects:
        Mutates ``sys.modules`` for the remainder of the process.
    """
    mock_mods = [
        "mani_skill",
        "mani_skill.envs",
        "mani_skill.envs.sapien_env",
        "mani_skill.utils",
        "mani_skill.utils.common",
        "mani_skill.utils.gym_utils",
        "mani_skill.utils.sapien_utils",
        "mani_skill.utils.io_utils",
        "mani_skill.utils.logging_utils",
        "mani_skill.utils.structs",
        "mani_skill.utils.structs.types",
        "mani_skill.utils.wrappers",
        "mani_skill.examples",
        "mani_skill.examples.motionplanning",
        "mani_skill.examples.motionplanning.panda",
        "mani_skill.examples.motionplanning.panda.motionplanner",
        "mani_skill.examples.motionplanning.panda.motionplanner_stick",
        "mani_skill.examples.motionplanning.base_motionplanner",
        "mani_skill.examples.motionplanning.base_motionplanner.utils",
        "sapien",
        "sapien.physx",
        "gymnasium",
        "h5py",
        "imageio",
    ]
    # Parents precede their children in the list, so a parent stub already
    # exists by the time one of its children is created.
    created: dict[str, types.ModuleType] = {}
    for mod_name in mock_mods:
        if mod_name in sys.modules:
            continue
        stub = types.ModuleType(mod_name)
        sys.modules[mod_name] = stub
        created[mod_name] = stub

        parent_name, _, child_name = mod_name.rpartition(".")
        parent_stub = created.get(parent_name)
        if parent_stub is not None:
            # Only stubs made here are touched; real packages are never modified.
            setattr(parent_stub, child_name, stub)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main() -> None:
    """Run every check in this file in order, printing progress along the way.

    A failing check raises ``AssertionError`` and stops the run; the final
    ``PASS`` line is printed only when all checks have completed.
    """
    print("\n[TEST] DemonstrationWrapper step error handling")

    # Each wrapper check is followed by its own one-line summary.
    wrapper_checks = (
        (test_step_error_returns_status_error, " test1: AST structure verification passed"),
        (test_step_error_runtime_behavior, " test2: Runtime behavior verification passed"),
        (test_step_normal_returns_ongoing_status, " test3: Normal path verification passed"),
    )
    for run_check, summary in wrapper_checks:
        run_check()
        print(summary)

    print("\n[TEST] Script status check verification")
    test_scripts_use_status_check_not_bare_try_except()

    print("\nPASS: All step error handling tests passed")
| |
|
| |
|
# Run all checks when this file is executed directly as a script.
if __name__ == "__main__":
    main()
| |
|