RoboMME / tests /dataset /test_obs_config.py
HongzeFu's picture
change to 256
467d2ce
# -*- coding: utf-8 -*-
"""
test_obs_config.py
===================
Integration test: verify that make_env_for_episode include_* flags
correctly control which obs/info fields are present in reset() and step() output.
Tests:
1. Default (all True): all 8 optional fields present in obs/info
2. All disabled (all False): none of the 8 optional fields present
3. Selective: only front_depth enabled, others False -> only front_depth present
4. Always-present fields unaffected by any flag combination
Run with:
cd /data/hongzefu/robomme_benchmark
uv run python -m pytest tests/dataset/test_obs_config.py -v -s
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any
import numpy as np
import pytest
from tests._shared.repo_paths import find_repo_root
pytestmark = pytest.mark.dataset
_PROJECT_ROOT = find_repo_root(__file__)
sys.path.insert(0, str(_PROJECT_ROOT / "src"))
from robomme.robomme_env import * # noqa: F401,F403
from robomme.robomme_env.utils import * # noqa: F401,F403
from robomme.env_record_wrapper import BenchmarkEnvBuilder, EpisodeDatasetResolver
# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────
# Environment id handed to both BenchmarkEnvBuilder and EpisodeDatasetResolver.
TEST_ENV_ID = "VideoUnmaskSwap"
# Episode index passed to make_env_for_episode() and to the resolver.
TEST_EPISODE = 0
# Forwarded as max_steps= when the env is built.
MAX_STEPS_ENV = 1000
# The 8 optional fields and where they live: 5 are looked up in `obs` ...
OBS_OPTIONAL_FIELDS = [
"maniskill_obs",
"front_depth_list",
"wrist_depth_list",
"front_camera_extrinsic_list",
"wrist_camera_extrinsic_list",
]
# ... and the remaining 3 are looked up in `info`.
INFO_OPTIONAL_FIELDS = [
"available_multi_choices",
"front_camera_intrinsic",
"wrist_camera_intrinsic",
]
# Fields that must ALWAYS be present regardless of flags.
# The obs entries are additionally required to be non-empty lists
# (see _check_always_present).
OBS_ALWAYS_FIELDS = [
"front_rgb_list",
"wrist_rgb_list",
"joint_state_list",
"eef_state_list",
"gripper_state_list",
]
# The info entries are only checked for key presence.
INFO_ALWAYS_FIELDS = [
"simple_subgoal_online",
"grounded_subgoal_online",
"task_goal",
]
# Compared against shape[:2] of the latest front RGB frame (and the front
# depth frame when present) in _check_front_camera_shapes.
EXPECTED_FRONT_CAMERA_HW = (256, 256)
# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _make_env(
    include_maniskill_obs=True,
    include_front_depth=True,
    include_wrist_depth=True,
    include_front_camera_extrinsic=True,
    include_wrist_camera_extrinsic=True,
    include_available_multi_choices=True,
    include_front_camera_intrinsic=True,
    include_wrist_camera_intrinsic=True,
):
    """Build the env for the test episode with the requested ``include_*`` flags.

    Every parameter is forwarded, under the same keyword name, to
    ``BenchmarkEnvBuilder.make_env_for_episode``; all default to ``True``.

    Returns:
        Whatever ``make_env_for_episode`` returns for ``TEST_EPISODE``.
    """
    # Collect the flags once so the forwarding call stays short.
    include_flags = dict(
        include_maniskill_obs=include_maniskill_obs,
        include_front_depth=include_front_depth,
        include_wrist_depth=include_wrist_depth,
        include_front_camera_extrinsic=include_front_camera_extrinsic,
        include_wrist_camera_extrinsic=include_wrist_camera_extrinsic,
        include_available_multi_choices=include_available_multi_choices,
        include_front_camera_intrinsic=include_front_camera_intrinsic,
        include_wrist_camera_intrinsic=include_wrist_camera_intrinsic,
    )
    env_builder = BenchmarkEnvBuilder(
        env_id=TEST_ENV_ID,
        dataset="train",
        action_space="joint_angle",
        gui_render=False,
    )
    return env_builder.make_env_for_episode(
        TEST_EPISODE,
        max_steps=MAX_STEPS_ENV,
        **include_flags,
    )
def _get_first_step_action():
"""Return a simple no-op joint action for testing."""
return np.zeros(8, dtype=np.float64)
def _check_always_present(obs, info, tag):
    """Verify the fields that do not depend on any ``include_*`` flag.

    Each name in ``OBS_ALWAYS_FIELDS`` must be a key of ``obs`` whose value is
    a non-empty list; each name in ``INFO_ALWAYS_FIELDS`` must be a key of
    ``info``. ``tag`` is prefixed to every failure message.
    """
    for field in OBS_ALWAYS_FIELDS:
        assert field in obs, f"[{tag}] always-present obs field '{field}' is missing"
        lst = obs[field]
        is_nonempty_list = isinstance(lst, list) and len(lst) > 0
        assert is_nonempty_list, (
            f"[{tag}] obs['{field}'] should be non-empty list, got {type(lst)}"
        )

    # info entries are only required to exist; their values are not inspected.
    for field in INFO_ALWAYS_FIELDS:
        assert field in info, f"[{tag}] always-present info field '{field}' is missing"
def _check_optional_present(obs, info, tag):
    """Verify every optional field exists: obs names in ``obs``, info names in ``info``.

    ``tag`` is prefixed to every failure message.
    """
    for field in OBS_OPTIONAL_FIELDS:
        found = field in obs
        assert found, f"[{tag}] optional obs field '{field}' should be present but missing"

    for field in INFO_OPTIONAL_FIELDS:
        found = field in info
        assert found, f"[{tag}] optional info field '{field}' should be present but missing"
def _check_optional_absent(obs, info, tag):
    """Verify no optional field leaks through: none in ``obs``, none in ``info``.

    ``tag`` is prefixed to every failure message.
    """
    for field in OBS_OPTIONAL_FIELDS:
        found = field in obs
        assert not found, f"[{tag}] optional obs field '{field}' should be absent but is present"

    for field in INFO_OPTIONAL_FIELDS:
        found = field in info
        assert not found, f"[{tag}] optional info field '{field}' should be absent but is present"
def _check_front_camera_shapes(obs, tag):
    """Verify the latest front-camera frames have the expected height and width.

    The last entry of ``obs["front_rgb_list"]`` must be an ndarray whose first
    two dimensions equal ``EXPECTED_FRONT_CAMERA_HW``. The same is required of
    the last entry of ``obs["front_depth_list"]`` when that key exists.
    ``tag`` is prefixed to every failure message.
    """
    front_rgb = obs["front_rgb_list"][-1]
    assert isinstance(front_rgb, np.ndarray), (
        f"[{tag}] front_rgb_list item should be ndarray, got {type(front_rgb)}"
    )
    assert front_rgb.shape[:2] == EXPECTED_FRONT_CAMERA_HW, (
        f"[{tag}] front_rgb_list shape={front_rgb.shape[:2]}, expected {EXPECTED_FRONT_CAMERA_HW}"
    )

    # Depth is optional: nothing more to verify when it was not requested.
    if "front_depth_list" not in obs:
        return

    front_depth = obs["front_depth_list"][-1]
    assert isinstance(front_depth, np.ndarray), (
        f"[{tag}] front_depth_list item should be ndarray, got {type(front_depth)}"
    )
    assert front_depth.shape[:2] == EXPECTED_FRONT_CAMERA_HW, (
        f"[{tag}] front_depth_list shape={front_depth.shape[:2]}, expected {EXPECTED_FRONT_CAMERA_HW}"
    )
# ──────────────────────────────────────────────────────────────────────────────
# Test cases
# ──────────────────────────────────────────────────────────────────────────────
def test_all_included(video_unmaskswap_train_ep0_dataset):
    """Default: all flags True -> all 8 optional fields present.

    Checks the output of ``reset()`` and, when the resolver yields a first
    ``joint_angle`` action, the output of one ``step()``. Finally spot-checks
    the dtypes/shapes of the optional fields on the most recent obs/info.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is used here.
    """
    print("\n[TEST 1] All flags True (default behavior)")
    env = _make_env()  # every include_* flag defaults to True
    try:
        # Built inside the try block so the env is still closed if the
        # resolver cannot be constructed.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )
        obs, info = env.reset()
        _check_always_present(obs, info, "reset/all-included")
        _check_optional_present(obs, info, "reset/all-included")
        _check_front_camera_shapes(obs, "reset/all-included")
        print(" RESET: all optional fields present ✓")

        action = resolver.get_step("joint_angle", 0)
        # NOTE(review): when no action is returned the step checks are
        # skipped silently -- confirm this is intended.
        if action is not None:
            obs, reward, terminated, truncated, info = env.step(action)
            _check_always_present(obs, info, "step/all-included")
            _check_optional_present(obs, info, "step/all-included")
            _check_front_camera_shapes(obs, "step/all-included")
            print(" STEP: all optional fields present ✓")

        # Spot-check dtypes of optional fields from last obs/info
        _check_optional_dtypes(obs, info, "all-included")
    finally:
        env.close()
    print(" [TEST 1] PASS")
def _check_optional_dtypes(obs, info, tag):
"""Spot-check dtypes of optional fields when present."""
if "front_depth_list" in obs:
item = obs["front_depth_list"][-1]
assert isinstance(item, np.ndarray) and item.dtype == np.int16, (
f"[{tag}] front_depth_list dtype={item.dtype}, expected int16"
)
if "wrist_depth_list" in obs:
item = obs["wrist_depth_list"][-1]
assert isinstance(item, np.ndarray) and item.dtype == np.int16, (
f"[{tag}] wrist_depth_list dtype={item.dtype}, expected int16"
)
if "front_camera_extrinsic_list" in obs:
item = obs["front_camera_extrinsic_list"][-1]
assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 4), (
f"[{tag}] front_camera_extrinsic_list shape={item.shape} dtype={item.dtype}"
)
if "wrist_camera_extrinsic_list" in obs:
item = obs["wrist_camera_extrinsic_list"][-1]
assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 4), (
f"[{tag}] wrist_camera_extrinsic_list shape={item.shape} dtype={item.dtype}"
)
if "front_camera_intrinsic" in info:
item = info["front_camera_intrinsic"]
assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 3), (
f"[{tag}] front_camera_intrinsic shape={item.shape} dtype={item.dtype}"
)
if "wrist_camera_intrinsic" in info:
item = info["wrist_camera_intrinsic"]
assert isinstance(item, np.ndarray) and item.dtype == np.float32 and item.shape == (3, 3), (
f"[{tag}] wrist_camera_intrinsic shape={item.shape} dtype={item.dtype}"
)
if "available_multi_choices" in info:
choices = info["available_multi_choices"]
assert isinstance(choices, list), (
f"[{tag}] available_multi_choices expected list, got {type(choices)}"
)
def test_all_excluded(video_unmaskswap_train_ep0_dataset):
    """All flags False -> none of the 8 optional fields present; always-present fields still there.

    Checks the output of ``reset()`` and, when the resolver yields a first
    ``joint_angle`` action, the output of one ``step()``.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is used here.
    """
    print("\n[TEST 2] All flags False")
    env = _make_env(
        include_maniskill_obs=False,
        include_front_depth=False,
        include_wrist_depth=False,
        include_front_camera_extrinsic=False,
        include_wrist_camera_extrinsic=False,
        include_available_multi_choices=False,
        include_front_camera_intrinsic=False,
        include_wrist_camera_intrinsic=False,
    )
    try:
        # Built inside the try block so the env is still closed if the
        # resolver cannot be constructed.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )
        obs, info = env.reset()
        _check_always_present(obs, info, "reset/all-excluded")
        _check_optional_absent(obs, info, "reset/all-excluded")
        _check_front_camera_shapes(obs, "reset/all-excluded")
        print(" RESET: all optional fields absent, always-present fields ok ✓")

        action = resolver.get_step("joint_angle", 0)
        # NOTE(review): when no action is returned the step checks are
        # skipped silently -- confirm this is intended.
        if action is not None:
            obs, reward, terminated, truncated, info = env.step(action)
            _check_always_present(obs, info, "step/all-excluded")
            _check_optional_absent(obs, info, "step/all-excluded")
            _check_front_camera_shapes(obs, "step/all-excluded")
            print(" STEP: all optional fields absent, always-present fields ok ✓")
    finally:
        env.close()
    print(" [TEST 2] PASS")
def test_selective_front_depth_only(video_unmaskswap_train_ep0_dataset):
    """Only front_depth enabled; others disabled.

    After ``reset()`` (and after one ``step()`` when the resolver yields a
    first ``joint_angle`` action) ``front_depth_list`` must be in ``obs``
    while every other optional obs/info field must be missing.

    Args:
        video_unmaskswap_train_ep0_dataset: pytest fixture; only its
            ``resolver_dataset_dir`` attribute is used here.
    """
    print("\n[TEST 3] Only include_front_depth=True, others False")
    # Optional obs fields whose flags are switched off in this test.
    absent_obs_fields = (
        "maniskill_obs",
        "wrist_depth_list",
        "front_camera_extrinsic_list",
        "wrist_camera_extrinsic_list",
    )
    env = _make_env(
        include_maniskill_obs=False,
        include_front_depth=True,
        include_wrist_depth=False,
        include_front_camera_extrinsic=False,
        include_wrist_camera_extrinsic=False,
        include_available_multi_choices=False,
        include_front_camera_intrinsic=False,
        include_wrist_camera_intrinsic=False,
    )
    try:
        # Built inside the try block so the env is still closed if the
        # resolver cannot be constructed.
        resolver = EpisodeDatasetResolver(
            env_id=TEST_ENV_ID,
            episode=TEST_EPISODE,
            dataset_directory=str(video_unmaskswap_train_ep0_dataset.resolver_dataset_dir),
        )
        obs, info = env.reset()
        _check_always_present(obs, info, "reset/selective")
        _check_front_camera_shapes(obs, "reset/selective")

        # front_depth should be present
        assert "front_depth_list" in obs, "front_depth_list should be present"
        item = obs["front_depth_list"][-1]
        # Type is asserted separately so the dtype message below never
        # touches .dtype on a non-array value.
        assert isinstance(item, np.ndarray), (
            f"front_depth_list item should be ndarray, got {type(item)}"
        )
        assert item.dtype == np.int16, (
            f"front_depth_list dtype={item.dtype}, expected int16"
        )

        # all others should be absent
        for field in absent_obs_fields:
            assert field not in obs, f"obs['{field}'] should be absent"
        for field in INFO_OPTIONAL_FIELDS:
            assert field not in info, f"info['{field}'] should be absent"
        print(" RESET: front_depth present, others absent ✓")

        action = resolver.get_step("joint_angle", 0)
        # NOTE(review): when no action is returned the step checks are
        # skipped silently -- confirm this is intended.
        if action is not None:
            obs, reward, terminated, truncated, info = env.step(action)
            _check_always_present(obs, info, "step/selective")
            _check_front_camera_shapes(obs, "step/selective")
            assert "front_depth_list" in obs, "front_depth_list should be present in step"
            for field in absent_obs_fields:
                assert field not in obs, f"obs['{field}'] should be absent in step"
            for field in INFO_OPTIONAL_FIELDS:
                assert field not in info, f"info['{field}'] should be absent in step"
            print(" STEP: front_depth present, others absent ✓")
    finally:
        env.close()
    print(" [TEST 3] PASS")
def test_always_present_unaffected():
    """Always-present fields appear regardless of which flags are set.

    Builds the env twice -- once with every ``include_*`` flag True and once
    with every flag False -- and checks the ``reset()`` output each time.
    Each env is closed before the next one is created.
    """
    print("\n[TEST 4] Always-present fields unaffected by flag combinations")
    flag_names = (
        "include_maniskill_obs",
        "include_front_depth",
        "include_wrist_depth",
        "include_front_camera_extrinsic",
        "include_wrist_camera_extrinsic",
        "include_available_multi_choices",
        "include_front_camera_intrinsic",
        "include_wrist_camera_intrinsic",
    )
    # Pair each label with its flag value explicitly instead of deriving the
    # label from a single flag after the fact.
    for flag_desc, enabled in (("all-true", True), ("all-false", False)):
        flags = dict.fromkeys(flag_names, enabled)
        env = _make_env(**flags)
        try:
            obs, info = env.reset()
            _check_always_present(obs, info, f"reset/{flag_desc}")
            _check_front_camera_shapes(obs, f"reset/{flag_desc}")
            print(f" RESET [{flag_desc}]: always-present fields ok ✓")
        finally:
            env.close()
    print(" [TEST 4] PASS")
# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────
# (short name, test function) pairs for the tests defined in this module.
# NOTE(review): main() below does not iterate this list; it may be a leftover
# from a pre-pytest runner -- confirm before removing.
TESTS = [
("all_included", test_all_included),
("all_excluded", test_all_excluded),
("selective_front_depth_only", test_selective_front_depth_only),
("always_present_unaffected", test_always_present_unaffected),
]
def main():
    """Tell the user to run this module through pytest, then exit with status 2."""
    notice_lines = (
        "test_obs_config main() now relies on pytest fixture-generated dataset.",
        "Run with: uv run python -m pytest tests/dataset/test_obs_config.py -v -s",
    )
    for line in notice_lines:
        print(line)
    sys.exit(2)


if __name__ == "__main__":
    main()