File size: 2,281 Bytes
16405f2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
from models.planner import CascadePlanner
def test_planner_structured_utility(tiny_policy_config, tiny_state):
    """Check that ``CascadePlanner.select_best`` prefers the stronger candidate.

    Builds a batch of one with two candidates. Candidate 0 is given larger
    values than candidate 1 in the target-belief, visibility, clearance,
    persistence and access fields; every other rollout field is identical
    across the two candidates. Candidate 0's chunk also has its last action
    channel set to 2.0. The planner must pick index 0 and report a structured
    utility of shape (batch, num_candidates).
    """
    config = tiny_policy_config()
    planner = CascadePlanner(config.planner)

    batch, num_candidates = 1, 2
    horizon = config.decoder.chunk_size
    side = config.reveal_head.field_size

    initial_state = tiny_state(batch_size=batch, field_size=side)

    # Action dimension 14 is hard-coded by the test; presumably it matches
    # the tiny policy's action size -- TODO confirm against the config.
    candidate_chunks = torch.zeros(batch, num_candidates, horizon, 14)
    candidate_chunks[:, 0, :, -1] = 2.0

    # (field name, channel count, tensor factory). Order mirrors the keys the
    # planner is handed; only support stability starts at ones.
    field_specs = (
        ("target_belief_field", 1, torch.zeros),
        ("visibility_field", 1, torch.zeros),
        ("clearance_field", 2, torch.zeros),
        ("occluder_contact_field", 1, torch.zeros),
        ("grasp_affordance_field", 1, torch.zeros),
        ("support_stability_field", 1, torch.ones),
        ("persistence_field", 1, torch.zeros),
        ("reocclusion_field", 1, torch.zeros),
        ("disturbance_field", 1, torch.zeros),
        ("access_field", 3, torch.zeros),
    )
    rollout_state = {
        name: make(batch, num_candidates, horizon, channels, side, side)
        for name, channels, make in field_specs
    }

    # Raise candidate 0 on the "good" fields; candidate 1 keeps its baseline.
    candidate_zero_boosts = {
        "target_belief_field": 2.0,
        "visibility_field": 1.5,
        "clearance_field": 1.0,
        "persistence_field": 1.0,
        "access_field": 2.0,
    }
    for name, value in candidate_zero_boosts.items():
        rollout_state[name][:, 0] = value

    selected = planner.select_best(initial_state, candidate_chunks, rollout_state)

    best_indices = selected["best_indices"]
    assert int(best_indices[0]) == 0
    assert selected["utility_structured"].shape == (batch, num_candidates)
|