"""Smoke tests for datasets built via ``sim_reveal.dataset``.

Each test collects a tiny teacher bundle, wraps it with ``dataset_from_bundle``
and inspects the first item.
"""

from sim_reveal.dataset import (
    RGBD_PROXY_DATASET_VERSION,
    collect_teacher_dataset,
    dataset_from_bundle,
)

# Kept small so collection stays fast; used for both collection and the
# dataset wrapper, matching the original tests.
_RESOLUTION = 16


def _first_item(**collect_kwargs):
    """Collect a tiny teacher bundle, wrap it as a dataset, and return ``dataset[0]``.

    ``collect_kwargs`` (e.g. ``dataset_version``, ``proxy_names``) are forwarded
    to ``collect_teacher_dataset`` alongside the small shared defaults below.
    """
    bundle = collect_teacher_dataset(
        episodes_per_proxy=1,
        resolution=_RESOLUTION,
        history_steps=2,
        planner_candidates=3,
        **collect_kwargs,
    )
    dataset = dataset_from_bundle(bundle, resolution=_RESOLUTION)
    return dataset[0]


def _assert_has_keys(item, keys):
    """Assert every key in ``keys`` is in ``item``, listing all missing ones on failure."""
    missing = [key for key in keys if key not in item]
    # Report every absent key at once instead of stopping at the first one.
    assert not missing, f"dataset item is missing keys: {missing}"


def test_dataset_v6_keys():
    """An item of the default RGB-D proxy dataset version exposes the expected fields."""
    item = _first_item(dataset_version=RGBD_PROXY_DATASET_VERSION)
    _assert_has_keys(
        item,
        (
            "images",
            "depths",
            "depth_valid",
            "belief_map",
            "visibility_map",
            "clearance_map",
            "support_stability",
            "reocclusion_target",
            "candidate_rollout_belief_map",
        ),
    )


def test_phase_dataset_version_keeps_rgbd_path():
    """The ``_phase`` dataset variant still yields depth data and adds phase fields."""
    item = _first_item(
        proxy_names=["bag_proxy"],
        dataset_version=RGBD_PROXY_DATASET_VERSION + "_phase",
    )
    # A zero sum would mean the depth-validity mask is entirely empty.
    assert float(item["depth_valid"].sum()) > 0.0, "depth_valid sums to zero"
    _assert_has_keys(item, ("phase", "rollout_phase", "candidate_rollout_phase"))