File size: 2,954 Bytes
31ade1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch

from train.trainer import build_policy


def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Planning narrows ``num_candidates`` decoded candidates down to ``top_k``.

    Builds an ``elastic_reveal`` policy with 4 decoder candidates and a planner
    ``top_k`` of 2, runs it with ``plan=True`` and checks three things:

    * the planner's top-k index tensor has ``top_k`` entries along dim 1;
    * the planned rollout's belief field has ``top_k`` entries along dim 1;
    * every selected best-candidate index lies in ``[0, num_candidates)``.
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    output = policy(
        images=batch["images"],
        depths=batch["depths"],
        depth_valid=batch["depth_valid"],
        camera_intrinsics=batch["camera_intrinsics"],
        camera_extrinsics=batch["camera_extrinsics"],
        proprio=batch["proprio"],
        texts=batch["texts"],
        history_images=batch["history_images"],
        history_depths=batch["history_depths"],
        history_depth_valid=batch["history_depth_valid"],
        history_proprio=batch["history_proprio"],
        history_actions=batch["history_actions"],
        plan=True,
    )
    top_k = config.planner.top_k
    num_candidates = config.decoder.num_candidates

    assert output["planner_topk_indices"].shape[1] == top_k
    assert output["planned_rollout"]["target_belief_field"].shape[1] == top_k

    best_indices = output["best_candidate_indices"]
    # A valid candidate index lies in [0, num_candidates). Checking only the
    # upper bound would let a negative sentinel (e.g. -1) slip through.
    assert (best_indices >= 0).all()
    assert (best_indices < num_candidates).all()


def test_policy_null_rollout_ablation_keeps_planner_interface(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Disabling the world model must leave the planner's outputs intact.

    The policy is run with ``use_world_model=False`` but ``use_planner=True``.
    It must report an ``"identity"`` rollout source. It must still return
    ``top_k`` planner candidates and a ``(batch, top_k)`` utility tensor.
    Its planned belief field and phase logits must equal the current
    interaction state with two dims inserted (dim 1 is the top-k axis, then
    dim 2), broadcast to the rollout tensors' shapes.
    """

    def tile_like(state_tensor, reference):
        # Insert dims 1 and 2, then broadcast to the rollout tensor's shape.
        # Detached so the comparison target carries no autograd history.
        return state_tensor.detach().unsqueeze(1).unsqueeze(2).expand_as(reference)

    cfg = tiny_policy_config(num_candidates=4, top_k=2)
    sample = tiny_batch(chunk_size=cfg.decoder.chunk_size)
    model = build_policy(cfg, tiny_trainer_config(policy_type="elastic_reveal"))
    observation_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    result = model(
        **{key: sample[key] for key in observation_keys},
        plan=True,
        use_world_model=False,
        use_planner=True,
    )
    planned = result["planned_rollout"]
    state_now = result["interaction_state"]
    k = cfg.planner.top_k

    assert result["rollout_source"] == "identity"
    assert result["planner_topk_indices"].shape[1] == k
    assert planned["target_belief_field"].shape[1] == k

    expected_belief = tile_like(
        state_now["target_belief_field"], planned["target_belief_field"]
    )
    expected_phase = tile_like(state_now["phase_logits"], planned["phase_logits"])

    batch_size = sample["images"].shape[0]
    assert result["utility_total"].shape == (batch_size, k)
    assert torch.allclose(planned["target_belief_field"], expected_belief)
    assert torch.allclose(planned["phase_logits"], expected_phase)