import torch

from models.planner import CascadePlanner


def test_planner_reocclusion_gating_prefers_maintain_first(tiny_policy_config, tiny_state):
    """Check that the planner selects candidate 1 over candidate 0.

    Two candidates are scored against a hand-built rollout. Candidate 0
    (mode name "retrieve") gets low access, support and persistence values
    together with a high reocclusion value; candidate 1 (mode name
    "maintain_gap") gets the opposite. The test expects a positive
    ``feasibility_penalty`` at index ``[0, 0]`` and ``best_indices == 1``.
    """
    batch, num_candidates, field_size = 1, 2, 4

    config = tiny_policy_config(
        num_candidates=num_candidates,
        top_k=2,
        field_size=field_size,
    )
    planner = CascadePlanner(config.planner)

    # Zero the residual parameters in place before scoring.
    # NOTE(review): presumably so only the hand-built fields drive the
    # selection -- confirm against CascadePlanner.
    for weight in planner.residual.parameters():
        weight.data.zero_()

    initial_state = tiny_state(batch_size=batch, field_size=field_size)

    # Candidate 0 writes 1.0 into the last action dimension, candidate 1
    # writes 0.2 into the first; everything else stays zero.
    candidate_chunks = torch.zeros(
        batch,
        num_candidates,
        config.decoder.chunk_size,
        config.decoder.action_dim,
    )
    candidate_chunks[:, 0, :, -1] = 1.0
    candidate_chunks[:, 1, :, 0] = 0.2

    horizon = config.world_model.rollout_horizon

    def uniform_field(value, channels=1):
        """Return a rollout field filled with a single value."""
        shape = (batch, num_candidates, horizon, channels, field_size, field_size)
        return torch.full(shape, value)

    def split_field(first_value, second_value, channels=1):
        """Return a field holding one value for candidate 0, another for candidate 1."""
        field = uniform_field(first_value, channels)
        field[:, 1] = second_value
        return field

    low = uniform_field(0.05)
    high = uniform_field(0.90)

    # Two-channel clearance: channel 0 is uniformly low, channel 1 high.
    clearance = torch.cat((low, high), dim=3)

    access = split_field(0.05, 0.90, channels=3)
    support = split_field(-4.0, 4.0)
    persistence = split_field(0.05, 0.95)
    reocclusion = split_field(0.95, 0.05)

    # Each shared-value entry gets its own clone so no two keys alias the
    # same tensor.
    rollout_state = {
        "target_belief_field": high.clone(),
        "visibility_field": high.clone(),
        "clearance_field": clearance,
        "occluder_contact_field": high.clone(),
        "grasp_affordance_field": high.clone(),
        "support_stability_field": support,
        "persistence_field": persistence,
        "reocclusion_field": reocclusion,
        "disturbance_field": low.clone(),
        "access_field": access,
    }

    selected = planner.select_best(
        initial_state=initial_state,
        candidate_chunks=candidate_chunks,
        rollout_state=rollout_state,
        proposal_mode_names=[["retrieve", "maintain_gap"]],
    )

    assert selected["feasibility_penalty"][0, 0] > 0.0
    assert selected["best_indices"].item() == 1