File size: 2,077 Bytes
31ade1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch

from models.planner import CascadePlanner


def test_planner_reocclusion_gating_prefers_maintain_first(tiny_policy_config, tiny_state):
    """The planner should pick the "maintain_gap" candidate over a risky "retrieve" one.

    Two candidate action chunks are scored against hand-built rollout fields:

    * candidate 0 ("retrieve"): access 0.05, support_stability -4.0,
      persistence 0.05, reocclusion 0.95;
    * candidate 1 ("maintain_gap"): access 0.90, support_stability 4.0,
      persistence 0.95, reocclusion 0.05.

    All other fields are identical for both candidates. The planner's residual
    parameters are zeroed (presumably so only the hand-built fields drive the
    score). The test expects candidate 0 to receive a positive feasibility
    penalty and candidate 1 to be selected.
    """
    field_size = 4
    num_candidates = 2
    config = tiny_policy_config(num_candidates=num_candidates, top_k=2, field_size=field_size)
    planner = CascadePlanner(config.planner)
    # Zero the residual in place without autograd tracking; preferred over
    # mutating ``parameter.data`` directly.
    with torch.no_grad():
        for parameter in planner.residual.parameters():
            parameter.zero_()

    initial_state = tiny_state(batch_size=1, field_size=field_size)
    # Give the two candidates distinct actions: candidate 0 sets the last
    # action dimension, candidate 1 the first.
    candidate_chunks = torch.zeros(1, num_candidates, config.decoder.chunk_size, config.decoder.action_dim)
    candidate_chunks[:, 0, :, -1] = 1.0
    candidate_chunks[:, 1, :, 0] = 0.2

    horizon = config.world_model.rollout_horizon

    def _field(fill_value, channels=1):
        # Rollout field laid out as (batch, candidate, rollout step, channel,
        # field, field) -- assumed layout, TODO confirm against CascadePlanner.
        return torch.full((1, num_candidates, horizon, channels, field_size, field_size), fill_value)

    def _split_field(risky_value, safe_value, channels=1):
        # Candidate 0 gets ``risky_value``; candidate 1 is overwritten with ``safe_value``.
        field = _field(risky_value, channels)
        field[:, 1] = safe_value
        return field

    low = _field(0.05)
    high = _field(0.90)

    rollout_state = {
        "target_belief_field": high.clone(),
        "visibility_field": high.clone(),
        # Two channels: a low channel followed by a high channel, same for both candidates.
        "clearance_field": torch.cat([low, high], dim=3),
        "occluder_contact_field": high.clone(),
        "grasp_affordance_field": high.clone(),
        "support_stability_field": _split_field(-4.0, 4.0),
        "persistence_field": _split_field(0.05, 0.95),
        "reocclusion_field": _split_field(0.95, 0.05),
        "disturbance_field": low.clone(),
        "access_field": _split_field(0.05, 0.90, channels=3),
    }

    selected = planner.select_best(
        initial_state=initial_state,
        candidate_chunks=candidate_chunks,
        rollout_state=rollout_state,
        proposal_mode_names=[["retrieve", "maintain_gap"]],
    )
    assert selected["feasibility_penalty"][0, 0].item() > 0.0, "risky 'retrieve' candidate should be penalized"
    assert selected["best_indices"].item() == 1, "planner should select the 'maintain_gap' candidate"