| | |
| | """ |
test_obs_numpy.py
===================
Integration test: runs the real environment against the shared temporary dataset
generated at test time, and checks the native type conversions performed inside
DemonstrationWrapper._augment_obs_and_info, i.e. whether the obs/info fields have
the correct output types and shapes under all four ActionSpaces.

Covered ActionSpaces:
    joint_angle / ee_pose / waypoint / multi_choice

Assertions:
    1. Returned dtypes comply with the specification (e.g. uint8, int16, float32, float64, etc.)
    2. Non-Tensor fields in info have the expected types

Run (must use uv):
    cd /data/hongzefu/robomme_benchmark
    uv run python -m pytest tests/dataset/test_obs_numpy.py -v -s
| | """ |
| | from __future__ import annotations |
| |
|
| | import sys |
| | from pathlib import Path |
| | from typing import Any, Literal, Optional |
| |
|
| | import numpy as np |
| | import pytest |
| |
|
| | from tests._shared.repo_paths import find_repo_root |
| |
|
| | |
# Apply the ``dataset`` pytest marker to every test in this module.
pytestmark = pytest.mark.dataset

# Prepend ``<repo root>/src`` to sys.path. This has to stay ahead of the
# ``robomme`` imports that follow -- presumably so they resolve from the source
# tree rather than an installed copy (confirm against the project layout).
_PROJECT_ROOT = find_repo_root(__file__)
sys.path.insert(0, str(_PROJECT_ROOT / "src"))
| |
|
| | from robomme.robomme_env import * |
| | from robomme.robomme_env.utils import * |
| | from robomme.env_record_wrapper import BenchmarkEnvBuilder, EpisodeDatasetResolver |
| |
|
| | |
| | |
| | |
# Environment id and episode index used by every test case in this module.
TEST_ENV_ID = "VideoUnmaskSwap"
TEST_EPISODE = 0

# Step budgets: how many replayed steps are asserted per action space, and the
# ``max_steps`` value handed to the environment factory.
MAX_STEPS_PER_ACTION_SPACE = 3
MAX_STEPS_ENV = 1000

# Closed set of action-space names exercised by this module.
ActionSpaceType = Literal[
    "joint_angle",
    "ee_pose",
    "waypoint",
    "multi_choice",
]
| |
|
| | |
| | |
| | |
| |
|
def _assert_ndarray(val: Any, dtype: np.dtype, tag: str) -> None:
    """Check that ``val`` is a NumPy array whose dtype equals ``dtype``.

    Args:
        val: Object under test.
        dtype: Expected dtype, compared against ``val.dtype`` with ``==``.
        tag: Label placed in brackets at the start of the failure message.

    Raises:
        AssertionError: If ``val`` is not an ``np.ndarray`` or its dtype differs.
    """
    is_array = isinstance(val, np.ndarray)
    assert is_array, f"[{tag}] expected ndarray, got {type(val).__name__}"

    # ``.dtype`` is only read once the array check above has passed.
    actual = val.dtype
    assert actual == dtype, f"[{tag}] expected dtype={dtype}, got {actual}"
| |
|
| |
|
def _assert_ndarray_loose(val: Any, tag: str) -> None:
    """Check only that ``val`` is an ``np.ndarray``; the dtype is not inspected.

    Args:
        val: Object under test.
        tag: Label placed in brackets at the start of the failure message.

    Raises:
        AssertionError: If ``val`` is not an ``np.ndarray``.
    """
    kind = type(val).__name__
    assert isinstance(val, np.ndarray), f"[{tag}] expected ndarray, got {kind}"
| |
|
| | |
| | |
| | |
| |
|
def assert_obs(obs: dict, tag: str) -> None:
    """Validate the type, dtype and shape of every per-frame entry in ``obs``.

    The frame count is taken from ``obs["front_rgb_list"]``; every other list
    is indexed over that same range.

    Args:
        obs: Observation dict holding parallel ``*_list`` sequences.
        tag: Label used as the prefix of assertion messages.

    Raises:
        AssertionError: If ``front_rgb_list`` is missing or empty, or any entry
            has an unexpected type, dtype or shape.
    """
    frame_count = len(obs.get("front_rgb_list", []))
    assert frame_count > 0, f"[{tag}] obs front_rgb_list is empty"

    # (key, expected dtype) pairs, checked in this order for every frame.
    image_specs = (
        ("front_rgb_list", np.uint8),
        ("wrist_rgb_list", np.uint8),
        ("front_depth_list", np.int16),
        ("wrist_depth_list", np.int16),
    )
    # Keys that only need to be ndarrays; their dtype is not pinned.
    loose_keys = ("joint_state_list", "gripper_state_list")
    extrinsic_keys = ("front_camera_extrinsic_list", "wrist_camera_extrinsic_list")

    for i in range(frame_count):
        pfx = f"{tag}[{i}]"

        for key, dtype in image_specs:
            _assert_ndarray(obs[key][i], dtype, f"{pfx} {key}")

        eef_state = obs["eef_state_list"][i]
        _assert_ndarray(eef_state, np.float64, f"{pfx} eef_state_list")
        eef_shape = eef_state.shape
        assert eef_shape == (6,), (
            f"[{pfx} eef_state_list] expected shape (6,), got {eef_shape}"
        )

        for key in loose_keys:
            _assert_ndarray_loose(obs[key][i], f"{pfx} {key}")

        for key in extrinsic_keys:
            extrinsic = obs[key][i]
            _assert_ndarray(extrinsic, np.float32, f"{pfx} {key}")
            assert extrinsic.shape == (3, 4), (
                f"[{pfx} {key}] expected (3, 4), got {extrinsic.shape}"
            )
| |
|
| |
|
def assert_info(info: dict, tag: str) -> None:
    """Validate the camera intrinsics and the non-array fields of ``info``.

    Both intrinsic entries must be present, be ``float32`` ndarrays and have
    shape ``(3, 3)``. ``task_goal`` may be a ``str``, a ``list`` or ``None``;
    ``status`` may be a ``str`` or ``None`` (a missing key counts as ``None``).

    Args:
        info: Info dict returned by the environment.
        tag: Label used as the prefix of assertion messages.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    for key in ("front_camera_intrinsic", "wrist_camera_intrinsic"):
        assert key in info, f"[{tag}] info missing key '{key}'"
        intrinsic = info[key]
        _assert_ndarray(intrinsic, np.float32, f"{tag} info['{key}']")
        assert intrinsic.shape == (3, 3), (
            f"[{tag} info['{key}']] expected (3, 3), got {intrinsic.shape}"
        )

    task_goal = info.get("task_goal")
    goal_ok = task_goal is None or isinstance(task_goal, (str, list))
    assert goal_ok, f"[{tag}] info['task_goal'] unexpected type {type(task_goal)}"

    status = info.get("status")
    status_ok = status is None or isinstance(status, str)
    assert status_ok, f"[{tag}] info['status'] unexpected type {type(status)}"
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict]:
    """Reduce a recorded multi-choice action to its ``choice``/``point`` pair.

    Kept consistent with the oracle-command parsing of the dataset replay
    script.

    Args:
        choice_action: Raw action read from the dataset; anything other than a
            dict is rejected.

    Returns:
        ``{"choice": ..., "point": ...}`` when ``choice`` is a non-blank string
        and the ``"point"`` key is present; otherwise ``None``.
    """
    if isinstance(choice_action, dict):
        label = choice_action.get("choice")
        has_label = isinstance(label, str) and bool(label.strip())
        if has_label and "point" in choice_action:
            # The label is returned as stored (not stripped).
            return {"choice": label, "point": choice_action.get("point")}
    return None
| |
|
| |
|
def run_one_action_space(action_space: ActionSpaceType, dataset_root: str | Path) -> None:
    """Replay a few recorded steps in one action space and assert obs/info types.

    Builds the ``TEST_ENV_ID`` environment for ``TEST_EPISODE``, resets it and
    replays up to ``MAX_STEPS_PER_ACTION_SPACE`` actions fetched from the
    dataset. ``assert_obs`` / ``assert_info`` run after the reset and after
    every step.

    Args:
        action_space: Action-space name; passed to the env builder and also
            used as the key when fetching recorded actions.
        dataset_root: Directory handed to ``EpisodeDatasetResolver``.

    Raises:
        AssertionError: If any obs/info check fails. The environment is closed
            before the error propagates.
    """
    print(f"\n{'='*60}")
    print(f"[TEST] ActionSpace = {action_space}")
    print(f"{'='*60}")

    env_builder = BenchmarkEnvBuilder(
        env_id=TEST_ENV_ID,
        dataset="train",
        action_space=action_space,
        gui_render=False,
    )
    env = env_builder.make_env_for_episode(
        TEST_EPISODE,
        max_steps=MAX_STEPS_ENV,
        include_maniskill_obs=True,
        include_front_depth=True,
        include_wrist_depth=True,
        include_front_camera_extrinsic=True,
        include_wrist_camera_extrinsic=True,
        include_available_multi_choices=True,
        include_front_camera_intrinsic=True,
        include_wrist_camera_intrinsic=True,
    )

    # Everything after env creation runs under try/finally so that a failing
    # assertion (or any other error) can no longer skip env.close().
    # NOTE(review): the 'β' glyph in the messages below may be a mis-encoded
    # check mark; it is left unchanged here -- confirm against the original.
    try:
        dataset_resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(dataset_root),
        )

        obs, info = env.reset()

        reset_tag = f"{TEST_ENV_ID} ep{TEST_EPISODE} RESET [{action_space}]"
        assert_obs(obs, reset_tag)
        assert_info(info, reset_tag)
        print(f" RESET assertion passed (obs list len={len(obs['front_rgb_list'])}, dtype β)")

        step = 0
        while step < MAX_STEPS_PER_ACTION_SPACE:
            action = dataset_resolver.get_step(action_space, step)
            if action_space == "multi_choice":
                # Recorded multi-choice entries are reduced to choice/point.
                action = _parse_oracle_command(action)
            if action is None:
                print(f" step {step}: action=None (dataset ended), breaking out")
                break

            obs, reward, terminated, truncated, info = env.step(action)

            step_tag = f"{TEST_ENV_ID} ep{TEST_EPISODE} STEP{step} [{action_space}]"
            assert_obs(obs, step_tag)
            assert_info(info, step_tag)
            print(f" STEP {step} assertion passed (obs list len={len(obs['front_rgb_list'])}, dtype β)")

            # Assumes the flags expose ``.item()`` (tensor / NumPy scalar), as
            # the code has always required.
            terminated_flag = bool(terminated.item())
            truncated_flag = bool(truncated.item())
            step += 1
            if terminated_flag or truncated_flag:
                print(f" terminated={terminated_flag} truncated={truncated_flag}, exiting early")
                break
    finally:
        env.close()

    print(f" [{action_space}] β All assertions passed (total {step} steps)")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Action spaces passed to ``pytest.mark.parametrize``; one test case per entry.
ACTION_SPACES: list[ActionSpaceType] = ["joint_angle", "ee_pose", "waypoint", "multi_choice"]
| |
|
| |
|
@pytest.mark.parametrize("action_space", ACTION_SPACES)
def test_obs_numpy_action_space(action_space: ActionSpaceType, video_unmaskswap_train_ep0_dataset) -> None:
    """Run the obs/info type checks once for each entry of ``ACTION_SPACES``.

    ``video_unmaskswap_train_ep0_dataset`` is supplied by pytest (it is not
    defined in the visible part of this file); only its
    ``resolver_dataset_dir`` attribute is read here.
    """
    dataset_dir = video_unmaskswap_train_ep0_dataset.resolver_dataset_dir
    run_one_action_space(action_space, dataset_root=dataset_dir)
| |
|
| |
|
def main() -> None:
    """Tell the user to run this module through pytest, then exit with status 2."""
    notice = (
        "test_obs_numpy main() now relies on pytest fixture-generated dataset.",
        "Run with: uv run python -m pytest tests/dataset/test_obs_numpy.py -v -s",
    )
    for line in notice:
        print(line)
    sys.exit(2)


# Direct execution only prints the notice above; the checks run under pytest.
if __name__ == "__main__":
    main()
| |
|