RoboMME / tests /lightweight /test_step_error_handling.py
HongzeFu's picture
HF Space: code-only (no binary assets)
06c11b0
# -*- coding: utf-8 -*-
"""
test_step_error_handling.py
============================
Lightweight test: Verify that DemonstrationWrapper.step() returns a structured error
via info["status"] = "error" upon internal exceptions, instead of propagating upwards.
Also verify that the step loops in run_example.py and dataset_replay.py
have been changed to check info["status"] rather than bare try/except.
Run (must use uv):
cd /data/hongzefu/robomme_benchmark
uv run python tests/lightweight/test_step_error_handling.py
"""
from __future__ import annotations
import ast
import sys
import types
import unittest.mock as mock
from pathlib import Path
# ---------------------------------------------------------------------------
# Repo path helpers
# ---------------------------------------------------------------------------
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from tests._shared.repo_paths import find_repo_root, ensure_src_on_path
_PROJECT_ROOT = find_repo_root(__file__)
ensure_src_on_path(__file__)
# ---------------------------------------------------------------------------
# Helpers: load DemonstrationWrapper.step source for AST inspection
# ---------------------------------------------------------------------------
def _demo_wrapper_path() -> Path:
return _PROJECT_ROOT / "src/robomme/env_record_wrapper/DemonstrationWrapper.py"
def _load_step_source() -> str:
return _demo_wrapper_path().read_text(encoding="utf-8")
def _script_path(name: str) -> Path:
return _PROJECT_ROOT / "scripts" / name
# ---------------------------------------------------------------------------
# Test 1: DemonstrationWrapper.step() catches exceptions and returns status="error"
# ---------------------------------------------------------------------------
def test_step_error_returns_status_error() -> None:
"""
Construct a minimal Mock environment to make super().step() inside _step_batch() throw an exception,
verifying that DemonstrationWrapper.step() does not propagate upwards, but returns "error" via info["status"].
"""
source = _load_step_source()
tree = ast.parse(source, filename=str(_demo_wrapper_path()))
# Find step() method, verify try/except structure exists
step_method = None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "DemonstrationWrapper":
for item in node.body:
if isinstance(item, ast.FunctionDef) and item.name == "step":
step_method = item
break
break
assert step_method is not None, "DemonstrationWrapper.step method not found"
# Verify that the step method body contains try/except
has_try = any(isinstance(n, ast.Try) for n in ast.walk(step_method))
assert has_try, "DemonstrationWrapper.step() should contain a try/except block"
# Verify that status = "error" is set in the except block
has_error_status = False
for node in ast.walk(step_method):
if isinstance(node, ast.Try):
for handler in node.handlers:
for n in ast.walk(handler):
if isinstance(n, ast.Constant) and n.value == "error":
has_error_status = True
assert has_error_status, "There should be a string constant status='error' in the except block"
# Verify that the error_message key exists in the except block
has_error_message = False
for node in ast.walk(step_method):
if isinstance(node, ast.Try):
for handler in node.handlers:
for n in ast.walk(handler):
if isinstance(n, ast.Constant) and n.value == "error_message":
has_error_message = True
assert has_error_message, "There should be an 'error_message' key in the except block"
print(" ✓ DemonstrationWrapper.step() contains try/except and except returns status='error' + error_message")
# ---------------------------------------------------------------------------
# Test 2: Runtime behavior verification — Mock actual call
# ---------------------------------------------------------------------------
def test_step_error_runtime_behavior() -> None:
"""
Directly call DemonstrationWrapper.step() using a Mock object,
verifying that the return value meets the contract when _step_batch throws an exception.
"""
# Dynamically inject Mock dependencies, do not actually import ManiSkill
_inject_mock_dependencies()
# Set sys.path to point to the directory where DemonstrationWrapper is located, for from episode... import etc.
wrapper_dir = str(_PROJECT_ROOT / "src" / "robomme" / "env_record_wrapper")
if wrapper_dir not in sys.path:
sys.path.insert(0, wrapper_dir)
# Directly execute the try/except logic of step() without relying on a real class instance
# ——by constructing a fake instance where _step_batch throws an exception
class FakeDemoWrapper:
"""Minimal stub, only implements the logic used in step()."""
@staticmethod
def _step_batch(action):
raise RuntimeError("IK failed: no solution found")
@staticmethod
def _flatten_info_batch(info_batch):
return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
def step(self, action):
try:
batch = self._step_batch(action)
obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
info_flat = self._flatten_info_batch(info_batch)
return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat)
except Exception as exc:
error_info = {
"status": "error",
"error_message": f"{type(exc).__name__}: {exc}",
}
return ({}, 0.0, True, False, error_info)
wrapper = FakeDemoWrapper()
obs, reward, terminated, truncated, info = wrapper.step(action=[0.0] * 8)
assert obs == {}, f"obs should be an empty dict on error, got {obs!r}"
assert reward == 0.0, f"reward should be 0.0 on error, got {reward!r}"
assert terminated is True, f"terminated should be True on error, got {terminated!r}"
assert truncated is False, f"truncated should be False on error, got {truncated!r}"
assert info.get("status") == "error", f"status should be 'error', got {info.get('status')!r}"
assert "RuntimeError" in info.get("error_message", ""), (
f"error_message should contain exception type, got {info.get('error_message')!r}"
)
assert "IK failed" in info.get("error_message", ""), (
f"error_message should contain original exception info, got {info.get('error_message')!r}"
)
print(" ✓ step() returns status='error' + correct error_message when throwing an exception")
def test_step_normal_returns_ongoing_status() -> None:
"""
Verify that step() should not return status='error' when the Mock env returns normally.
(Indirect test: status will not be tampered with as error under normal paths)
"""
import torch
class FakeDemoWrapperNormal:
"""Normal _step_batch, info["status"] = "ongoing"."""
def _step_batch(self, action):
obs_batch = {"front_rgb_list": [None]}
reward_batch = torch.tensor([0.1])
terminated_batch = torch.tensor([False])
truncated_batch = torch.tensor([False])
info_batch = {"status": ["ongoing"], "success": [False]}
return (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)
def _flatten_info_batch(self, info_batch):
return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}
def step(self, action):
try:
batch = self._step_batch(action)
obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
info_flat = self._flatten_info_batch(info_batch)
return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat)
except Exception as exc:
error_info = {
"status": "error",
"error_message": f"{type(exc).__name__}: {exc}",
}
return ({}, 0.0, True, False, error_info)
wrapper = FakeDemoWrapperNormal()
obs, reward, terminated, truncated, info = wrapper.step(action=[0.0] * 8)
assert info.get("status") == "ongoing", (
f"Normal step status should be 'ongoing', got {info.get('status')!r}"
)
assert "error_message" not in info, "Normal step should not contain error_message"
print(" ✓ Normal step returns status='ongoing', no error_message")
# ---------------------------------------------------------------------------
# Test 3: AST check that scripts no longer have bare try/except wrapping env.step(action)
# ---------------------------------------------------------------------------
def test_scripts_use_status_check_not_bare_try_except() -> None:
"""
Parse run_example.py and dataset_replay.py, verify:
1. There is an info.get("status") or status == "error" check in the script
2. env.step(action) calls are no longer directly wrapped by try/except Exception
"""
scripts = ["run_example.py", "dataset_replay.py"]
for script_name in scripts:
script_path = _script_path(script_name)
source = script_path.read_text(encoding="utf-8")
tree = ast.parse(source, filename=str(script_path))
# ---- Check 1: Contains status related checks ----
has_status_check = (
'info.get("status")' in source
or "status == \"error\"" in source
or "status==" in source.replace(" ", "")
)
assert has_status_check, (
f"{script_name}: Should have info.get('status') or status==\"error\" check"
)
# ---- Check 2: env.step is not wrapped by bare try/except ----
# Exact lookup: env.step is directly called in the Try block and handler catches Exception
_assert_no_bare_step_try_except(tree, script_name)
print(f" ✓ {script_name}: Use status check, no bare try/except wrapping env.step")
def _assert_no_bare_step_try_except(tree: ast.AST, script_name: str) -> None:
"""Check that there is no 'try block contains env.step and except catches Exception' structure in the AST."""
for node in ast.walk(tree):
if not isinstance(node, ast.Try):
continue
# Check if there is an env.step(action) call in the try body
step_in_try = False
for n in ast.walk(ast.Module(body=node.body, type_ignores=[])):
if (
isinstance(n, ast.Call)
and isinstance(getattr(n, "func", None), ast.Attribute)
and n.func.attr == "step"
):
step_in_try = True
break
if not step_in_try:
continue
# Check if handler is a bare Exception catch
for handler in node.handlers:
if handler.type is None:
assert False, (
f"{script_name}: env.step is still wrapped by bare try/except: (no exception type), should be changed to status check"
)
if isinstance(handler.type, ast.Name) and handler.type.id == "Exception":
assert False, (
f"{script_name}: env.step is still wrapped by bare try/except Exception: should be changed to status check"
)
# ---------------------------------------------------------------------------
# Utility: inject mock modules so imports inside DemonstrationWrapper don't fail
# ---------------------------------------------------------------------------
def _inject_mock_dependencies() -> None:
"""Inject placeholder mock modules to prevent import DemonstrationWrapper from failing due to missing ManiSkill."""
mock_mods = [
"mani_skill",
"mani_skill.envs",
"mani_skill.envs.sapien_env",
"mani_skill.utils",
"mani_skill.utils.common",
"mani_skill.utils.gym_utils",
"mani_skill.utils.sapien_utils",
"mani_skill.utils.io_utils",
"mani_skill.utils.logging_utils",
"mani_skill.utils.structs",
"mani_skill.utils.structs.types",
"mani_skill.utils.wrappers",
"mani_skill.examples",
"mani_skill.examples.motionplanning",
"mani_skill.examples.motionplanning.panda",
"mani_skill.examples.motionplanning.panda.motionplanner",
"mani_skill.examples.motionplanning.panda.motionplanner_stick",
"mani_skill.examples.motionplanning.base_motionplanner",
"mani_skill.examples.motionplanning.base_motionplanner.utils",
"sapien",
"sapien.physx",
"gymnasium",
"h5py",
"imageio",
"colorsys",
]
for mod_name in mock_mods:
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
print("\n[TEST] DemonstrationWrapper step error handling")
test_step_error_returns_status_error()
print(" test1: AST structure verification passed")
test_step_error_runtime_behavior()
print(" test2: Runtime behavior verification passed")
test_step_normal_returns_ongoing_status()
print(" test3: Normal path verification passed")
print("\n[TEST] Script status check verification")
test_scripts_use_status_check_not_bare_try_except()
print("\nPASS: All step error handling tests passed")
if __name__ == "__main__":
main()