| |
| """ |
| test_obs_config.py |
| =================== |
| Integration test: verify that make_env_for_episode include_* flags |
| correctly control which obs/info fields are present in reset() and step() output. |
| |
| Tests: |
| 1. Default (all True): all 8 optional fields present in obs/info |
| 2. All disabled (all False): none of the 8 optional fields present |
| 3. Selective: only front_depth enabled, others False -> only front_depth present |
| 4. Always-present fields unaffected by any flag combination |
| |
| Run with: |
| cd /path/to/robomme_benchmark  (the repository root) |
| uv run python -m pytest tests/dataset/test_obs_config.py -v -s |
| """ |
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pytest |
|
|
| from tests._shared.repo_paths import find_repo_root |
|
|
| pytestmark = pytest.mark.dataset |
|
|
| _PROJECT_ROOT = find_repo_root(__file__) |
| sys.path.insert(0, str(_PROJECT_ROOT / "src")) |
|
|
| from robomme.robomme_env import * |
| from robomme.robomme_env.utils import * |
| from robomme.env_record_wrapper import BenchmarkEnvBuilder, EpisodeDatasetResolver |
|
|
| |
| |
| |
# Environment id and episode index used by every test in this module.
TEST_ENV_ID = "VideoUnmaskSwap"
TEST_EPISODE = 0
# Forwarded as ``max_steps`` when the env is built in _make_env.
MAX_STEPS_ENV = 1000


# Keys the tests expect to appear or disappear with the include_* flags
# (five in obs, three in info: the "8 optional fields").
OBS_OPTIONAL_FIELDS = [
    "maniskill_obs",
    "front_depth_list",
    "wrist_depth_list",
    "front_camera_extrinsic_list",
    "wrist_camera_extrinsic_list",
]
INFO_OPTIONAL_FIELDS = [
    "available_multi_choices",
    "front_camera_intrinsic",
    "wrist_camera_intrinsic",
]


# Keys the tests expect in obs/info regardless of any include_* flag.
OBS_ALWAYS_FIELDS = [
    "front_rgb_list",
    "wrist_rgb_list",
    "joint_state_list",
    "eef_state_list",
    "gripper_state_list",
]
INFO_ALWAYS_FIELDS = [
    "simple_subgoal_online",
    "grounded_subgoal_online",
    "task_goal",
]


# Compared against ``shape[:2]`` of the front camera RGB/depth arrays;
# presumably (height, width) -- TODO confirm against the env camera config.
EXPECTED_FRONT_CAMERA_HW = (256, 256)
|
|
| |
| |
| |
|
|
def _make_env(
    include_maniskill_obs=True,
    include_front_depth=True,
    include_wrist_depth=True,
    include_front_camera_extrinsic=True,
    include_wrist_camera_extrinsic=True,
    include_available_multi_choices=True,
    include_front_camera_intrinsic=True,
    include_wrist_camera_intrinsic=True,
):
    """Build the test env for TEST_ENV_ID / TEST_EPISODE with the given flags.

    Every ``include_*`` argument is forwarded unchanged, by keyword, to
    ``BenchmarkEnvBuilder.make_env_for_episode``; all of them default to True.

    Returns:
        Whatever ``make_env_for_episode`` returns (the tests call ``reset``,
        ``step`` and ``close`` on it).
    """
    include_flags = {
        "include_maniskill_obs": include_maniskill_obs,
        "include_front_depth": include_front_depth,
        "include_wrist_depth": include_wrist_depth,
        "include_front_camera_extrinsic": include_front_camera_extrinsic,
        "include_wrist_camera_extrinsic": include_wrist_camera_extrinsic,
        "include_available_multi_choices": include_available_multi_choices,
        "include_front_camera_intrinsic": include_front_camera_intrinsic,
        "include_wrist_camera_intrinsic": include_wrist_camera_intrinsic,
    }
    env_builder = BenchmarkEnvBuilder(
        env_id=TEST_ENV_ID,
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    )
    return env_builder.make_env_for_episode(
        TEST_EPISODE,
        max_steps=MAX_STEPS_ENV,
        **include_flags,
    )
|
|
|
|
| def _get_first_step_action(): |
| """Return a simple no-op joint action for testing.""" |
| return np.zeros(8, dtype=np.float64) |
|
|
|
|
def _check_always_present(obs, info, tag):
    """Assert that the always-present obs and info fields exist.

    Each name in OBS_ALWAYS_FIELDS must be a key of ``obs`` mapping to a
    non-empty list; each name in INFO_ALWAYS_FIELDS must be a key of ``info``.
    ``tag`` is only used as a prefix in the failure messages.
    """
    for name in OBS_ALWAYS_FIELDS:
        assert name in obs, f"[{tag}] always-present obs field '{name}' is missing"
        value = obs[name]
        is_nonempty_list = isinstance(value, list) and len(value) > 0
        assert is_nonempty_list, (
            f"[{tag}] obs['{name}'] should be non-empty list, got {type(value)}"
        )

    for name in INFO_ALWAYS_FIELDS:
        assert name in info, f"[{tag}] always-present info field '{name}' is missing"
|
|
|
|
def _check_optional_present(obs, info, tag):
    """Assert that every optional obs and info field is present.

    ``tag`` is only used as a prefix in the failure messages.
    """
    for kind, container, names in (
        ("obs", obs, OBS_OPTIONAL_FIELDS),
        ("info", info, INFO_OPTIONAL_FIELDS),
    ):
        for name in names:
            assert name in container, (
                f"[{tag}] optional {kind} field '{name}' should be present but missing"
            )
|
|
|
|
def _check_optional_absent(obs, info, tag):
    """Assert that no optional obs or info field is present.

    ``tag`` is only used as a prefix in the failure messages.
    """
    for kind, container, names in (
        ("obs", obs, OBS_OPTIONAL_FIELDS),
        ("info", info, INFO_OPTIONAL_FIELDS),
    ):
        for name in names:
            assert name not in container, (
                f"[{tag}] optional {kind} field '{name}' should be absent but is present"
            )
|
|
|
|
def _check_front_camera_shapes(obs, tag):
    """Assert the front camera arrays match EXPECTED_FRONT_CAMERA_HW.

    The last entry of ``obs["front_rgb_list"]`` must be an ndarray whose first
    two dimensions equal EXPECTED_FRONT_CAMERA_HW. The same is required of the
    last entry of ``obs["front_depth_list"]``, but only when that key exists.
    ``tag`` is only used as a prefix in the failure messages.
    """
    rgb_frame = obs["front_rgb_list"][-1]
    assert isinstance(rgb_frame, np.ndarray), (
        f"[{tag}] front_rgb_list item should be ndarray, got {type(rgb_frame)}"
    )
    rgb_hw = rgb_frame.shape[:2]
    assert rgb_hw == EXPECTED_FRONT_CAMERA_HW, (
        f"[{tag}] front_rgb_list shape={rgb_hw}, expected {EXPECTED_FRONT_CAMERA_HW}"
    )

    # Depth is optional: nothing more to verify when it was not requested.
    if "front_depth_list" not in obs:
        return

    depth_frame = obs["front_depth_list"][-1]
    assert isinstance(depth_frame, np.ndarray), (
        f"[{tag}] front_depth_list item should be ndarray, got {type(depth_frame)}"
    )
    depth_hw = depth_frame.shape[:2]
    assert depth_hw == EXPECTED_FRONT_CAMERA_HW, (
        f"[{tag}] front_depth_list shape={depth_hw}, expected {EXPECTED_FRONT_CAMERA_HW}"
    )
|
|
|
|
| |
| |
| |
|
|
def test_all_included(video_unmaskswap_train_ep0_dataset):
    """Default flags (all True): all 8 optional fields are present.

    Checks the output of ``reset()`` and, when the recorded episode yields a
    first action, the output of one ``step()``.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is read, to point the
            EpisodeDatasetResolver at the recorded episode.
    """
    print("\n[TEST 1] All flags True (default behavior)")
    env = _make_env()
    try:
        # Built inside the try so that a resolver failure still closes the env.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )

        obs, info = env.reset()
        _check_always_present(obs, info, "reset/all-included")
        _check_optional_present(obs, info, "reset/all-included")
        _check_front_camera_shapes(obs, "reset/all-included")
        print(" RESET: all optional fields present β")

        action = resolver.get_step("joint_angle", 0)
        if action is not None:
            # Only obs and info are inspected; the other step results are unused.
            obs, _, _, _, info = env.step(action)
            _check_always_present(obs, info, "step/all-included")
            _check_optional_present(obs, info, "step/all-included")
            _check_front_camera_shapes(obs, "step/all-included")
            print(" STEP: all optional fields present β")
        else:
            # Make the skipped phase visible instead of passing silently.
            print(" STEP: skipped (resolver returned no action for step 0)")

        # Runs on the most recent obs/info: the step output when a step was
        # taken, otherwise the reset output.
        _check_optional_dtypes(obs, info, "all-included")
    finally:
        env.close()
    print(" [TEST 1] PASS")
|
|
|
|
| def _check_optional_dtypes(obs, info, tag): |
| """Spot-check dtypes of optional fields when present.""" |
| if "front_depth_list" in obs: |
| item = obs["front_depth_list"][-1] |
| assert isinstance(item, np.ndarray) and item.dtype == np.int16, ( |
| f"[{tag}] front_depth_list dtype={item.dtype}, expected int16" |
| ) |
| if "wrist_depth_list" in obs: |
| item = obs["wrist_depth_list"][-1] |
| assert isinstance(item, np.ndarray) and item.dtype == np.int16, ( |
| f"[{tag}] wrist_depth_list dtype={item.dtype}, expected int16" |
| ) |
| if "front_camera_extrinsic_list" in obs: |
| item = obs["front_camera_extrinsic_list"][-1] |
| assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 4), ( |
| f"[{tag}] front_camera_extrinsic_list shape={item.shape} dtype={item.dtype}" |
| ) |
| if "wrist_camera_extrinsic_list" in obs: |
| item = obs["wrist_camera_extrinsic_list"][-1] |
| assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 4), ( |
| f"[{tag}] wrist_camera_extrinsic_list shape={item.shape} dtype={item.dtype}" |
| ) |
| if "front_camera_intrinsic" in info: |
| item = info["front_camera_intrinsic"] |
| assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 3), ( |
| f"[{tag}] front_camera_intrinsic shape={item.shape} dtype={item.dtype}" |
| ) |
| if "wrist_camera_intrinsic" in info: |
| item = info["wrist_camera_intrinsic"] |
| assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 3), ( |
| f"[{tag}] wrist_camera_intrinsic shape={item.shape} dtype={item.dtype}" |
| ) |
| if "available_multi_choices" in info: |
| choices = info["available_multi_choices"] |
| assert isinstance(choices, list), ( |
| f"[{tag}] available_multi_choices expected list, got {type(choices)}" |
| ) |
|
|
|
|
def test_all_excluded(video_unmaskswap_train_ep0_dataset):
    """All flags False: no optional field is present, always-present ones remain.

    Checks the output of ``reset()`` and, when the recorded episode yields a
    first action, the output of one ``step()``.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is read, to point the
            EpisodeDatasetResolver at the recorded episode.
    """
    print("\n[TEST 2] All flags False")
    env = _make_env(
        include_maniskill_obs=False,
        include_front_depth=False,
        include_wrist_depth=False,
        include_front_camera_extrinsic=False,
        include_wrist_camera_extrinsic=False,
        include_available_multi_choices=False,
        include_front_camera_intrinsic=False,
        include_wrist_camera_intrinsic=False,
    )
    try:
        # Built inside the try so that a resolver failure still closes the env.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )

        obs, info = env.reset()
        _check_always_present(obs, info, "reset/all-excluded")
        _check_optional_absent(obs, info, "reset/all-excluded")
        _check_front_camera_shapes(obs, "reset/all-excluded")
        print(" RESET: all optional fields absent, always-present fields ok β")

        action = resolver.get_step("joint_angle", 0)
        if action is not None:
            # Only obs and info are inspected; the other step results are unused.
            obs, _, _, _, info = env.step(action)
            _check_always_present(obs, info, "step/all-excluded")
            _check_optional_absent(obs, info, "step/all-excluded")
            _check_front_camera_shapes(obs, "step/all-excluded")
            print(" STEP: all optional fields absent, always-present fields ok β")
        else:
            # Make the skipped phase visible instead of passing silently.
            print(" STEP: skipped (resolver returned no action for step 0)")
    finally:
        env.close()
    print(" [TEST 2] PASS")
|
|
|
|
def test_selective_front_depth_only(video_unmaskswap_train_ep0_dataset):
    """Only ``include_front_depth`` is True: front depth is the sole optional field.

    Checks the output of ``reset()`` and, when the recorded episode yields a
    first action, the output of one ``step()``.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is read, to point the
            EpisodeDatasetResolver at the recorded episode.
    """
    print("\n[TEST 3] Only include_front_depth=True, others False")
    env = _make_env(
        include_maniskill_obs=False,
        include_front_depth=True,
        include_wrist_depth=False,
        include_front_camera_extrinsic=False,
        include_wrist_camera_extrinsic=False,
        include_available_multi_choices=False,
        include_front_camera_intrinsic=False,
        include_wrist_camera_intrinsic=False,
    )
    # Derived from the shared constant so the list cannot drift out of sync.
    absent_obs_fields = [
        name for name in OBS_OPTIONAL_FIELDS if name != "front_depth_list"
    ]

    def _assert_front_depth_int16(current_obs, suffix):
        # Type is asserted separately so the dtype message is only built for
        # a real ndarray.
        assert "front_depth_list" in current_obs, (
            f"front_depth_list should be present{suffix}"
        )
        item = current_obs["front_depth_list"][-1]
        assert isinstance(item, np.ndarray), (
            f"front_depth_list item should be ndarray, got {type(item)}"
        )
        assert item.dtype == np.int16, (
            f"front_depth_list dtype={item.dtype}, expected int16"
        )

    try:
        # Built inside the try so that a resolver failure still closes the env.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )

        obs, info = env.reset()
        _check_always_present(obs, info, "reset/selective")
        _check_front_camera_shapes(obs, "reset/selective")
        _assert_front_depth_int16(obs, "")
        for field in absent_obs_fields:
            assert field not in obs, f"obs['{field}'] should be absent"
        for field in INFO_OPTIONAL_FIELDS:
            assert field not in info, f"info['{field}'] should be absent"
        print(" RESET: front_depth present, others absent β")

        action = resolver.get_step("joint_angle", 0)
        if action is not None:
            # Only obs and info are inspected; the other step results are unused.
            obs, _, _, _, info = env.step(action)
            _check_always_present(obs, info, "step/selective")
            _check_front_camera_shapes(obs, "step/selective")
            _assert_front_depth_int16(obs, " in step")
            for field in absent_obs_fields:
                assert field not in obs, f"obs['{field}'] should be absent in step"
            for field in INFO_OPTIONAL_FIELDS:
                assert field not in info, f"info['{field}'] should be absent in step"
            print(" STEP: front_depth present, others absent β")
        else:
            # Make the skipped phase visible instead of passing silently.
            print(" STEP: skipped (resolver returned no action for step 0)")
    finally:
        env.close()
    print(" [TEST 3] PASS")
|
|
|
|
def test_always_present_unaffected():
    """Always-present fields show up after reset with all flags on and all off."""
    print("\n[TEST 4] Always-present fields unaffected by flag combinations")
    flag_names = (
        "include_maniskill_obs",
        "include_front_depth",
        "include_wrist_depth",
        "include_front_camera_extrinsic",
        "include_wrist_camera_extrinsic",
        "include_available_multi_choices",
        "include_front_camera_intrinsic",
        "include_wrist_camera_intrinsic",
    )
    # Two combinations are exercised: every flag True, then every flag False.
    for enabled in (True, False):
        flag_desc = "all-true" if enabled else "all-false"
        env = _make_env(**dict.fromkeys(flag_names, enabled))
        try:
            obs, info = env.reset()
            _check_always_present(obs, info, f"reset/{flag_desc}")
            _check_front_camera_shapes(obs, f"reset/{flag_desc}")
            print(f" RESET [{flag_desc}]: always-present fields ok β")
        finally:
            env.close()
    print(" [TEST 4] PASS")
|
|
|
|
| |
| |
| |
|
|
# (label, test function) pairs for the four tests defined in this module.
# NOTE(review): not referenced by main() in this file; three of the functions
# take a pytest fixture argument, so they cannot be called bare from here.
TESTS = [
    ("all_included", test_all_included),
    ("all_excluded", test_all_excluded),
    ("selective_front_depth_only", test_selective_front_depth_only),
    ("always_present_unaffected", test_always_present_unaffected),
]
|
|
|
|
def main():
    """Print how to run this module through pytest, then exit with status 2."""
    usage_lines = (
        "test_obs_config main() now relies on pytest fixture-generated dataset.",
        "Run with: uv run python -m pytest tests/dataset/test_obs_config.py -v -s",
    )
    for line in usage_lines:
        print(line)
    # Same effect as sys.exit(2), which raises SystemExit with that code.
    raise SystemExit(2)


if __name__ == "__main__":
    main()
|
|