| import torch | |
| from models.planner import CascadePlanner | |
def test_planner_reocclusion_gating_prefers_maintain_first(tiny_policy_config, tiny_state):
    """The planner should pass over "retrieve" and select "maintain_gap".

    Two candidates are scored from hand-built rollout fields:

    * candidate 0 ("retrieve"): access 0.05, support_stability -4.0,
      persistence 0.05, reocclusion 0.95;
    * candidate 1 ("maintain_gap"): access 0.90, support_stability 4.0,
      persistence 0.95, reocclusion 0.05.

    Every other rollout field is identical for both candidates. The planner's
    residual module is zeroed. The test then expects:

    * a positive feasibility penalty on candidate 0;
    * candidate 1 to be chosen as the best index.
    """
    field_size = 4
    num_candidates = 2
    config = tiny_policy_config(num_candidates=num_candidates, top_k=2, field_size=field_size)
    horizon = config.world_model.rollout_horizon

    planner = CascadePlanner(config.planner)
    # Zero the residual module in place. Presumably this removes its
    # contribution so the outcome is driven by the rollout fields below.
    with torch.no_grad():
        for parameter in planner.residual.parameters():
            parameter.zero_()

    initial_state = tiny_state(batch_size=1, field_size=field_size)

    # Give each candidate a distinct, non-zero action chunk.
    candidate_chunks = torch.zeros(
        1, num_candidates, config.decoder.chunk_size, config.decoder.action_dim
    )
    candidate_chunks[:, 0, :, -1] = 1.0
    candidate_chunks[:, 1, :, 0] = 0.2

    def constant_field(value, channels=1):
        """Return a (1, candidates, horizon, channels, field, field) tensor filled with value."""
        return torch.full(
            (1, num_candidates, horizon, channels, field_size, field_size), value
        )

    def split_field(retrieve_value, maintain_value, channels=1):
        """Return a constant field whose candidate-1 slice is set to maintain_value."""
        field = constant_field(retrieve_value, channels)
        field[:, 1] = maintain_value
        return field

    rollout_state = {
        "target_belief_field": constant_field(0.90),
        "visibility_field": constant_field(0.90),
        # Two channels: channel 0 at 0.05, channel 1 at 0.90, same for both candidates.
        "clearance_field": torch.cat([constant_field(0.05), constant_field(0.90)], dim=3),
        "occluder_contact_field": constant_field(0.90),
        "grasp_affordance_field": constant_field(0.90),
        "support_stability_field": split_field(-4.0, 4.0),
        "persistence_field": split_field(0.05, 0.95),
        "reocclusion_field": split_field(0.95, 0.05),
        "disturbance_field": constant_field(0.05),
        "access_field": split_field(0.05, 0.90, channels=3),
    }

    selected = planner.select_best(
        initial_state=initial_state,
        candidate_chunks=candidate_chunks,
        rollout_state=rollout_state,
        proposal_mode_names=[["retrieve", "maintain_gap"]],
    )

    retrieve_penalty = selected["feasibility_penalty"][0, 0].item()
    assert retrieve_penalty > 0.0, (
        f"expected a positive feasibility penalty on 'retrieve', got {retrieve_penalty}"
    )
    best_index = selected["best_indices"].item()
    assert best_index == 1, (
        f"expected 'maintain_gap' (index 1) to be selected, got index {best_index}"
    )