# RoboMME / tests/dataset/test_obs_numpy.py (uploaded by HongzeFu; HF Space: code-only, no binary assets; revision 06c11b0)
# -*- coding: utf-8 -*-
"""
test_obs_numpy.py
===================
Integration test: runs the real environment against a unified temporary dataset generated at test runtime,
and checks the native type conversions performed inside DemonstrationWrapper._augment_obs_and_info,
i.e. whether the dtypes and shapes of the obs/info fields are correct under all four ActionSpaces.
Covered ActionSpaces:
joint_angle / ee_pose / waypoint / multi_choice
What is asserted:
1. Returned arrays have the specified dtype (uint8, int16, float32, float64, etc.) and, where fixed, the expected shape
2. Non-Tensor fields in info have the expected Python types
Run (must use uv):
cd /data/hongzefu/robomme_benchmark
uv run python -m pytest tests/dataset/test_obs_numpy.py -v -s
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Literal, Optional
import numpy as np
import pytest
from tests._shared.repo_paths import find_repo_root
# Module-level pytest marker: tags the tests in this module as ``dataset``.
pytestmark = pytest.mark.dataset
# Ensure the src path can be found: <repo root>/src is put at the front of
# sys.path before the ``robomme`` imports that follow.
# NOTE(review): find_repo_root is presumably returning the repository root as a
# Path (it is combined with ``/ "src"`` below) - confirm in tests/_shared.
_PROJECT_ROOT = find_repo_root(__file__)
sys.path.insert(0, str(_PROJECT_ROOT / "src"))
from robomme.robomme_env import * # noqa: F401,F403 Register all custom environments
from robomme.robomme_env.utils import * # noqa: F401,F403
from robomme.env_record_wrapper import BenchmarkEnvBuilder, EpisodeDatasetResolver
# ──────────────────────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────────────────────
# Environment id handed to both the env builder and the dataset resolver.
TEST_ENV_ID = "VideoUnmaskSwap"
# Episode index that is built and replayed.
TEST_EPISODE = 0
MAX_STEPS_PER_ACTION_SPACE = 3  # Max steps to verify per ActionSpace
# Passed as ``max_steps`` when the environment for the episode is created.
MAX_STEPS_ENV = 1000
# Names of the four action spaces covered by this module.
ActionSpaceType = Literal["joint_angle", "ee_pose", "waypoint", "multi_choice"]
# ──────────────────────────────────────────────────────────────────────────────
# Assertion helpers
# ──────────────────────────────────────────────────────────────────────────────
def _assert_ndarray(val: Any, dtype: np.dtype, tag: str) -> None:
    """Check that ``val`` is a NumPy array whose dtype equals ``dtype``.

    Args:
        val: Object under test.
        dtype: Expected dtype, compared with ``val.dtype == dtype``.
        tag: Label placed in brackets at the start of the failure message.

    Raises:
        AssertionError: If ``val`` is not an ``np.ndarray`` or its dtype differs.
    """
    # The type check comes first so the dtype access below is only reached for arrays.
    assert isinstance(val, np.ndarray), f"[{tag}] expected ndarray, got {type(val).__name__}"
    assert val.dtype == dtype, f"[{tag}] expected dtype={dtype}, got {val.dtype}"
def _assert_ndarray_loose(val: Any, tag: str) -> None:
    """Check that ``val`` is a NumPy array, whatever its dtype.

    Args:
        val: Object under test.
        tag: Label placed in brackets at the start of the failure message.

    Raises:
        AssertionError: If ``val`` is not an ``np.ndarray``.
    """
    assert isinstance(val, np.ndarray), f"[{tag}] expected ndarray, got {type(val).__name__}"
# ──────────────────────────────────────────────────────────────────────────────
# Core assertion: native output type is correct
# ──────────────────────────────────────────────────────────────────────────────
def assert_obs(obs: dict, tag: str) -> None:
    """Validate the type, dtype and shape of every per-frame entry in ``obs``.

    The number of frames is taken from ``obs["front_rgb_list"]``; all other
    lists are indexed with the same positions.

    Args:
        obs: Observation dict holding the ``*_list`` fields checked below.
        tag: Label used as the prefix of every failure message.

    Raises:
        AssertionError: If ``front_rgb_list`` is missing or empty, or if any
            entry has an unexpected type, dtype or shape.
    """
    frame_count = len(obs.get("front_rgb_list", []))
    assert frame_count > 0, f"[{tag}] obs front_rgb_list is empty"

    # Fields with a fixed dtype, listed in verification order:
    # RGB -> uint8, depth -> int16.
    fixed_dtype_fields = (
        ("front_rgb_list", np.uint8),
        ("wrist_rgb_list", np.uint8),
        ("front_depth_list", np.int16),
        ("wrist_depth_list", np.int16),
    )
    # Fields that only need to be ndarrays (dtype and shape are not checked).
    ndarray_only_fields = ("joint_state_list", "gripper_state_list")
    # Camera extrinsics -> float32 with shape (3, 4).
    extrinsic_fields = ("front_camera_extrinsic_list", "wrist_camera_extrinsic_list")

    for idx in range(frame_count):
        label = f"{tag}[{idx}]"

        for field, expected_dtype in fixed_dtype_fields:
            _assert_ndarray(obs[field][idx], expected_dtype, f"{label} {field}")

        # eef_state_list -> float64 with shape (6,).
        eef = obs["eef_state_list"][idx]
        _assert_ndarray(eef, np.float64, f"{label} eef_state_list")
        assert eef.shape == (6,), f"[{label} eef_state_list] expected shape (6,), got {eef.shape}"

        for field in ndarray_only_fields:
            _assert_ndarray_loose(obs[field][idx], f"{label} {field}")

        for field in extrinsic_fields:
            extrinsic = obs[field][idx]
            _assert_ndarray(extrinsic, np.float32, f"{label} {field}")
            assert extrinsic.shape == (3, 4), f"[{label} {field}] expected (3, 4), got {extrinsic.shape}"
def assert_info(info: dict, tag: str) -> None:
    """Validate the camera intrinsics and the non-array fields of ``info``.

    Args:
        info: Info dict returned by ``reset`` / ``step``.
        tag: Label used as the prefix of every failure message.

    Raises:
        AssertionError: If an intrinsic matrix is missing, is not a float32
            ndarray of shape (3, 3), or if ``task_goal`` / ``status`` has an
            unexpected type.
    """
    for name in ("front_camera_intrinsic", "wrist_camera_intrinsic"):
        assert name in info, f"[{tag}] info missing key '{name}'"
        intrinsic = info[name]
        _assert_ndarray(intrinsic, np.float32, f"{tag} info['{name}']")
        assert intrinsic.shape == (3, 3), f"[{tag} info['{name}']] expected (3, 3), got {intrinsic.shape}"

    # Non-Tensor fields: may be absent (None); otherwise they must keep one of
    # the plain Python types listed here.
    allowed_types = (
        ("task_goal", (str, list)),
        ("status", (str,)),
    )
    for field, types in allowed_types:
        value = info.get(field)
        assert value is None or isinstance(value, types), (
            f"[{tag}] info['{field}'] unexpected type {type(value)}"
        )
# ──────────────────────────────────────────────────────────────────────────────
# Full episode test for a single ActionSpace
# ──────────────────────────────────────────────────────────────────────────────
def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict]:
    """Reduce a replayed multi-choice action to its ``choice`` / ``point`` pair.

    Oracle command parsing kept consistent with dataset_replay—printType.py.

    Args:
        choice_action: Raw action read from the dataset; anything that is not
            a dict is rejected.

    Returns:
        ``{"choice": ..., "point": ...}`` when ``choice`` is a non-blank string
        and a ``"point"`` key is present, otherwise ``None``. The choice string
        is returned as stored (it is not stripped).
    """
    if isinstance(choice_action, dict):
        label = choice_action.get("choice")
        label_ok = isinstance(label, str) and bool(label.strip())
        if label_ok and "point" in choice_action:
            return {"choice": label, "point": choice_action["point"]}
    return None
def run_one_action_space(action_space: ActionSpaceType, dataset_root: str | Path) -> None:
    """Replay the first recorded steps of one episode and validate obs/info.

    Builds the environment for ``TEST_ENV_ID`` / ``TEST_EPISODE`` with the
    given action space, checks the output of ``reset``, then feeds up to
    ``MAX_STEPS_PER_ACTION_SPACE`` actions read from the dataset and checks the
    output of every ``step`` with ``assert_obs`` / ``assert_info``.

    The environment is closed in a ``finally`` clause so that a failing
    assertion (or any other error) does not leave it open.

    Args:
        action_space: Action space passed to ``BenchmarkEnvBuilder``; it is
            also the key used when reading actions from the dataset.
        dataset_root: Directory handed to ``EpisodeDatasetResolver``.

    Raises:
        AssertionError: If any obs/info field fails validation.
    """
    print(f"\n{'='*60}")
    print(f"[TEST] ActionSpace = {action_space}")
    print(f"{'='*60}")

    # multi_choice uses OraclePlannerDemonstrationWrapper,
    # BenchmarkEnvBuilder directly uses unified action_space naming.
    env_builder = BenchmarkEnvBuilder(
        env_id=TEST_ENV_ID,
        dataset="train",
        action_space=action_space,
        gui_render=False,
    )
    env = env_builder.make_env_for_episode(
        TEST_EPISODE,
        max_steps=MAX_STEPS_ENV,
        include_maniskill_obs=True,
        include_front_depth=True,
        include_wrist_depth=True,
        include_front_camera_extrinsic=True,
        include_wrist_camera_extrinsic=True,
        include_available_multi_choices=True,
        include_front_camera_intrinsic=True,
        include_wrist_camera_intrinsic=True,
    )

    try:
        dataset_resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(dataset_root),
        )

        # ── RESET ──────────────────────────────────────────────────────────
        obs, info = env.reset()
        reset_tag = f"{TEST_ENV_ID} ep{TEST_EPISODE} RESET [{action_space}]"
        assert_obs(obs, reset_tag)
        assert_info(info, reset_tag)
        print(f" RESET assertion passed (obs list len={len(obs['front_rgb_list'])}, dtype ✓)")

        # ── STEP LOOP ──────────────────────────────────────────────────────
        # ``step`` counts the steps actually executed; it is reported below.
        step = 0
        while step < MAX_STEPS_PER_ACTION_SPACE:
            action = dataset_resolver.get_step(action_space, step)
            if action_space == "multi_choice":
                # Stored multi-choice actions are reduced to a choice/point command.
                action = _parse_oracle_command(action)
            if action is None:
                print(f" step {step}: action=None (dataset ended), breaking out")
                break

            obs, _reward, terminated, truncated, info = env.step(action)
            step_tag = f"{TEST_ENV_ID} ep{TEST_EPISODE} STEP{step} [{action_space}]"
            assert_obs(obs, step_tag)
            assert_info(info, step_tag)
            print(f" STEP {step} assertion passed (obs list len={len(obs['front_rgb_list'])}, dtype ✓)")

            # NOTE(review): assumes terminated/truncated expose ``.item()``
            # (single-element tensor or array) - confirm against the wrapper.
            terminated_flag = bool(terminated.item())
            truncated_flag = bool(truncated.item())
            step += 1
            if terminated_flag or truncated_flag:
                print(f" terminated={terminated_flag} truncated={truncated_flag}, exiting early")
                break
    finally:
        # Always release the environment, including when an assertion fails.
        env.close()

    print(f" [{action_space}] ✓ All assertions passed (total {step} steps)")
# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────
# Action spaces fed to the parametrized test below; one test case per entry.
ACTION_SPACES: list[ActionSpaceType] = [
    "joint_angle",
    "ee_pose",
    "waypoint",
    "multi_choice",
]
@pytest.mark.parametrize("action_space", ACTION_SPACES)
def test_obs_numpy_action_space(action_space: ActionSpaceType, video_unmaskswap_train_ep0_dataset) -> None:
    """Run the obs/info type checks once for each entry of ``ACTION_SPACES``.

    ``video_unmaskswap_train_ep0_dataset`` is not defined in this module
    (presumably a pytest fixture from a conftest); only its
    ``resolver_dataset_dir`` attribute is used here.
    """
    dataset_dir = video_unmaskswap_train_ep0_dataset.resolver_dataset_dir
    run_one_action_space(action_space, dataset_dir)
def main() -> None:
    """Tell the user to run this module through pytest, then exit with status 2."""
    notice = (
        "test_obs_numpy main() now relies on pytest fixture-generated dataset.",
        "Run with: uv run python -m pytest tests/dataset/test_obs_numpy.py -v -s",
    )
    for line in notice:
        print(line)
    sys.exit(2)


# Direct execution only prints the hint above; the checks themselves run under pytest.
if __name__ == "__main__":
    main()