# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_planner_reocclusion_gating.py
# Last commit by lsnu: 31ade1f (verified) - "Add files using upload-large-folder tool"
import torch
from models.planner import CascadePlanner
def test_planner_reocclusion_gating_prefers_maintain_first(tiny_policy_config, tiny_state):
    """CascadePlanner should select the "maintain_gap" proposal over "retrieve".

    Two candidate action chunks are scored against a hand-built rollout:
    candidate 0 ("retrieve") sees low access, negative support, low
    persistence and high reocclusion; candidate 1 ("maintain_gap") sees higher
    access, positive support, higher persistence and low reocclusion. All other
    rollout fields are identical for both candidates. The test asserts that
    candidate 0 receives a positive feasibility penalty and that candidate 1
    is chosen as best.
    """
    field_size = 4
    config = tiny_policy_config(num_candidates=2, top_k=2, field_size=field_size)
    planner = CascadePlanner(config.planner)
    # Zero every residual parameter, presumably so the residual head's random
    # initialisation cannot influence selection -- confirm against CascadePlanner.
    # Done under no_grad instead of via ``.data`` so autograd is not silently bypassed.
    with torch.no_grad():
        for parameter in planner.residual.parameters():
            parameter.zero_()

    initial_state = tiny_state(batch_size=1, field_size=field_size)

    # Candidate 0 sets only the last action dimension; candidate 1 only the first.
    candidate_chunks = torch.zeros(1, 2, config.decoder.chunk_size, config.decoder.action_dim)
    candidate_chunks[:, 0, :, -1] = 1.0
    candidate_chunks[:, 1, :, 0] = 0.2

    # Assumed layout: (batch, candidate, horizon, channel, field, field) -- TODO confirm.
    horizon = config.world_model.rollout_horizon
    single_channel = (1, 2, horizon, 1, field_size, field_size)
    low = torch.full(single_channel, 0.05)
    high = torch.full(single_channel, 0.90)

    # Two-channel clearance (channel 0 low, channel 1 high), same for both candidates.
    clearance = torch.cat([low, high], dim=3)

    # Fields below differ between candidates: index 1 gets the higher
    # access/support/persistence values and the lower reocclusion value.
    # NOTE(review): candidate 1 is better on every differing field, so this
    # test does not isolate reocclusion as the deciding signal on its own.
    access = torch.full((1, 2, horizon, 3, field_size, field_size), 0.05)
    access[:, 1] = 0.90
    support = torch.full(single_channel, -4.0)
    support[:, 1] = 4.0
    persistence = torch.full(single_channel, 0.05)
    persistence[:, 1] = 0.95
    reocclusion = torch.full(single_channel, 0.95)
    reocclusion[:, 1] = 0.05

    rollout_state = {
        "target_belief_field": high.clone(),
        "visibility_field": high.clone(),
        "clearance_field": clearance,
        "occluder_contact_field": high.clone(),
        "grasp_affordance_field": high.clone(),
        "support_stability_field": support,
        "persistence_field": persistence,
        "reocclusion_field": reocclusion,
        "disturbance_field": low.clone(),
        "access_field": access,
    }
    selected = planner.select_best(
        initial_state=initial_state,
        candidate_chunks=candidate_chunks,
        rollout_state=rollout_state,
        proposal_mode_names=[["retrieve", "maintain_gap"]],
    )

    # Indexed as [batch, candidate] -- presumably shape (B, num_candidates); confirm.
    assert selected["feasibility_penalty"][0, 0].item() > 0.0
    assert selected["best_indices"].item() == 1