import torch

from models.planner import CascadePlanner


def test_planner_structured_utility(tiny_policy_config, tiny_state):
    """Check that ``CascadePlanner.select_best`` picks the favourable candidate.

    Two candidates are scored. Candidate 0 receives non-zero values in the
    target-belief, visibility, clearance, persistence and access fields, and
    a value of 2.0 in the last channel of its action chunk. Candidate 1 keeps
    those fields at zero. ``support_stability_field`` is all ones for both
    candidates, and every other field is all zeros.

    The test asserts that the planner selects candidate 0. It also asserts
    that the returned structured utility has shape ``(1, 2)``.
    """
    config = tiny_policy_config()
    planner = CascadePlanner(config.planner)

    grid = config.reveal_head.field_size
    horizon = config.decoder.chunk_size
    num_candidates = 2

    initial_state = tiny_state(batch_size=1, field_size=grid)

    # Action chunks of shape (1, num_candidates, chunk_size, 14).
    # NOTE(review): the meaning of the 14 action channels, and of the last
    # one in particular, is defined by the planner — confirm there.
    candidate_chunks = torch.zeros(1, num_candidates, horizon, 14)
    candidate_chunks[:, 0, :, -1] = 2.0

    # Each entry is (key, channel count, constructor). Every field is shaped
    # (1, num_candidates, chunk_size, channels, grid, grid). List order
    # matches the original dict insertion order.
    field_specs = [
        ("target_belief_field", 1, torch.zeros),
        ("visibility_field", 1, torch.zeros),
        ("clearance_field", 2, torch.zeros),
        ("occluder_contact_field", 1, torch.zeros),
        ("grasp_affordance_field", 1, torch.zeros),
        ("support_stability_field", 1, torch.ones),
        ("persistence_field", 1, torch.zeros),
        ("reocclusion_field", 1, torch.zeros),
        ("disturbance_field", 1, torch.zeros),
        ("access_field", 3, torch.zeros),
    ]
    rollout_state = {
        key: make(1, num_candidates, horizon, channels, grid, grid)
        for key, channels, make in field_specs
    }

    # Fill every timestep/channel/cell of candidate 0 with a constant so that
    # it should dominate candidate 1 on these fields.
    candidate_zero_values = {
        "target_belief_field": 2.0,
        "visibility_field": 1.5,
        "clearance_field": 1.0,
        "persistence_field": 1.0,
        "access_field": 2.0,
    }
    for key, value in candidate_zero_values.items():
        rollout_state[key][:, 0] = value

    selected = planner.select_best(initial_state, candidate_chunks, rollout_state)

    best_index = int(selected["best_indices"][0])
    assert best_index == 0
    assert selected["utility_structured"].shape == (1, num_candidates)