File size: 2,281 Bytes
16405f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch

from models.planner import CascadePlanner


def test_planner_structured_utility(tiny_policy_config, tiny_state):
    """Check that the planner picks the candidate with the stronger rollout fields.

    Two candidate chunks are scored. Candidate 0 receives larger values in
    the target-belief, visibility, clearance, persistence and access fields
    (and in the last action channel); candidate 1 is left at the baseline.
    The test asserts that ``select_best`` returns index 0 for the single
    batch element and that ``utility_structured`` has shape ``(1, 2)``.
    """
    cfg = tiny_policy_config()
    planner = CascadePlanner(cfg.planner)

    grid = cfg.reveal_head.field_size
    horizon = cfg.decoder.chunk_size
    batch, num_candidates = 1, 2

    initial_state = tiny_state(batch_size=batch, field_size=grid)

    # NOTE(review): 14 is presumably the per-step action dimension - confirm.
    candidate_chunks = torch.zeros(batch, num_candidates, horizon, 14)
    candidate_chunks[:, 0, :, -1] = 2.0

    # Field name -> (channel count, tensor factory). Every field starts at
    # zero except support stability, which starts at one.
    field_specs = {
        "target_belief_field": (1, torch.zeros),
        "visibility_field": (1, torch.zeros),
        "clearance_field": (2, torch.zeros),
        "occluder_contact_field": (1, torch.zeros),
        "grasp_affordance_field": (1, torch.zeros),
        "support_stability_field": (1, torch.ones),
        "persistence_field": (1, torch.zeros),
        "reocclusion_field": (1, torch.zeros),
        "disturbance_field": (1, torch.zeros),
        "access_field": (3, torch.zeros),
    }
    rollout_state = {
        name: make(batch, num_candidates, horizon, channels, grid, grid)
        for name, (channels, make) in field_specs.items()
    }

    # Overrides applied to candidate 0 only; candidate 1 keeps the baseline.
    candidate_zero_values = {
        "target_belief_field": 2.0,
        "visibility_field": 1.5,
        "clearance_field": 1.0,
        "persistence_field": 1.0,
        "access_field": 2.0,
    }
    for name, value in candidate_zero_values.items():
        rollout_state[name][:, 0] = value

    selected = planner.select_best(initial_state, candidate_chunks, rollout_state)

    assert int(selected["best_indices"][0]) == 0
    assert selected["utility_structured"].shape == (batch, num_candidates)