| import torch | |
| from models.planner import CascadePlanner | |
def test_planner_structured_utility(tiny_policy_config, tiny_state):
    """The planner should pick the candidate whose rollout fields are favourable.

    Two candidates are scored. Candidate 0 gets a non-zero last action channel
    and raised target-belief, visibility, clearance, persistence and access
    fields. Candidate 1 is all zeros except the support-stability field, which
    is all ones for both candidates. The planner is expected to select
    candidate 0 and to report one structured utility per (batch, candidate).
    """
    config = tiny_policy_config()
    planner = CascadePlanner(config.planner)
    side = config.reveal_head.field_size
    horizon = config.decoder.chunk_size
    batch, num_candidates = 1, 2

    initial_state = tiny_state(batch_size=batch, field_size=side)

    candidate_chunks = torch.zeros(batch, num_candidates, horizon, 14)
    candidate_chunks[:, 0, :, -1] = 2.0

    # (channel count, tensor factory) for each rollout field. Every field
    # shares candidate_chunks' leading (batch, candidate, chunk) dims followed
    # by channels and a square side x side grid. Key order matches the order
    # the fields were originally declared in.
    layout = {
        "target_belief_field": (1, torch.zeros),
        "visibility_field": (1, torch.zeros),
        "clearance_field": (2, torch.zeros),
        "occluder_contact_field": (1, torch.zeros),
        "grasp_affordance_field": (1, torch.zeros),
        "support_stability_field": (1, torch.ones),
        "persistence_field": (1, torch.zeros),
        "reocclusion_field": (1, torch.zeros),
        "disturbance_field": (1, torch.zeros),
        "access_field": (3, torch.zeros),
    }
    rollout_state = {
        name: factory(batch, num_candidates, horizon, channels, side, side)
        for name, (channels, factory) in layout.items()
    }

    # Make candidate 0 look clearly better on the fields the planner scores.
    favoured_values = {
        "target_belief_field": 2.0,
        "visibility_field": 1.5,
        "clearance_field": 1.0,
        "persistence_field": 1.0,
        "access_field": 2.0,
    }
    for name, value in favoured_values.items():
        rollout_state[name][:, 0] = value

    selected = planner.select_best(initial_state, candidate_chunks, rollout_state)

    assert int(selected["best_indices"][0]) == 0
    assert selected["utility_structured"].shape == (batch, num_candidates)