| from sim_reveal.dataset import RGBD_PROXY_DATASET_VERSION, collect_teacher_dataset, dataset_from_bundle | |
def test_dataset_v6_keys():
    """Check that the first dataset item exposes every expected key.

    Collects a minimal teacher dataset (``episodes_per_proxy=1``,
    ``resolution=16``, ``history_steps=2``, ``planner_candidates=3``) at
    ``RGBD_PROXY_DATASET_VERSION``, wraps it with ``dataset_from_bundle`` and
    inspects ``dataset[0]``.

    All missing keys are reported together in the failure message rather
    than stopping at the first one.
    """
    bundle = collect_teacher_dataset(
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION,
    )
    dataset = dataset_from_bundle(bundle, resolution=16)
    item = dataset[0]
    expected_keys = (
        "images",
        "depths",
        "depth_valid",
        "belief_map",
        "visibility_map",
        "clearance_map",
        "support_stability",
        "reocclusion_target",
        "candidate_rollout_belief_map",
    )
    # Gather every absent key so one failing run shows the full gap,
    # instead of a bare assert that only reveals the first missing key.
    missing = [key for key in expected_keys if key not in item]
    assert not missing, f"dataset item is missing keys: {missing}"
def test_phase_dataset_version_keeps_rgbd_path():
    """Check that a ``_phase``-suffixed dataset version still carries RGB-D data.

    Builds a single-episode teacher dataset for ``bag_proxy`` using
    ``RGBD_PROXY_DATASET_VERSION + "_phase"``, then verifies two things about
    the first item:
    - ``depth_valid`` has a positive sum;
    - the phase-related entries are present.
    """
    phase_version = RGBD_PROXY_DATASET_VERSION + "_phase"
    teacher_bundle = collect_teacher_dataset(
        proxy_names=["bag_proxy"],
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=phase_version,
    )
    first_item = dataset_from_bundle(teacher_bundle, resolution=16)[0]

    # The phase variant must still produce some depth marked as valid.
    valid_depth_total = float(first_item["depth_valid"].sum())
    assert valid_depth_total > 0.0

    for phase_key in ("phase", "rollout_phase", "candidate_rollout_phase"):
        assert phase_key in first_item