File size: 1,563 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from models.action_decoder import proposal_mode_vocab
from train.trainer import build_policy


def test_candidate_macro_coverage(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that each task's proposal-mode vocabulary is covered by the policy output.

    Builds an ``elastic_reveal`` policy configured with 7 candidates and puts it
    in eval mode. For each task prompt, it runs a single-sample forward pass with
    ``plan=False``. It then asserts that the names in
    ``output["proposal_mode_names"][0]`` include every name returned by
    ``proposal_mode_vocab(task_name, config.decoder.num_proposal_modes)``.
    """
    config = tiny_policy_config(num_candidates=7)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # The observation inputs are the same for every task, so slice the first
    # sample out of the batch once rather than on every loop iteration. Each
    # batch key is also the keyword argument name the policy call expects.
    observation_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    single_sample = {key: batch[key][:1] for key in observation_keys}

    task_to_text = {
        "foliage": ["create a gap in the foliage and retrieve the target"],
        "bag": ["open the bag mouth and retrieve the target object"],
        "cloth": ["lift the top layer enough to retrieve the hidden object"],
    }
    for task_name, texts in task_to_text.items():
        output = policy(**single_sample, texts=texts, plan=False)
        expected = set(proposal_mode_vocab(task_name, config.decoder.num_proposal_modes))
        observed = set(output["proposal_mode_names"][0])
        # List the exact modes that are absent so a failure is actionable.
        missing = expected - observed
        assert not missing, (
            f"task {task_name!r}: proposal modes {sorted(missing, key=str)} "
            f"missing from candidates {sorted(observed, key=str)}"
        )