File size: 1,357 Bytes
b14c4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from train.trainer import build_policy


def test_rgb_backward_compat(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Run two policy variants on one tiny batch with ``plan=True`` and check output shapes.

    The test builds an ``interaction_state`` policy and an ``elastic_reveal``
    policy from the same tiny policy config. Each one is run forward on the
    same batch. The elastic policy is called with ``use_depth=False``.

    For both variants, the last dimension of ``action_mean`` must be 14.

    Variant-specific checks:
      * ``interaction_state``: ``candidate_chunks`` has
        ``config.decoder.num_candidates`` entries along dim 1.
      * ``elastic_reveal``: ``planned_chunk`` has the same shape as
        ``action_mean``.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)

    # Both forward passes receive these batch entries under identical keyword
    # names, so gather them once instead of spelling them out twice.
    input_names = (
        "images",
        "proprio",
        "texts",
        "history_images",
        "history_proprio",
        "history_actions",
    )
    shared_inputs = {name: batch[name] for name in input_names}

    # Width of the last axis of ``action_mean`` that both variants must produce.
    expected_action_dim = 14

    # --- interaction_state variant ---
    interaction_trainer_config = tiny_trainer_config(policy_type="interaction_state")
    interaction_policy = build_policy(config, interaction_trainer_config)
    interaction_output = interaction_policy(**shared_inputs, plan=True)
    assert interaction_output["action_mean"].shape[-1] == expected_action_dim
    assert interaction_output["candidate_chunks"].shape[1] == config.decoder.num_candidates

    # --- elastic_reveal variant, depth input disabled ---
    elastic_trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    elastic_policy = build_policy(config, elastic_trainer_config)
    elastic_output = elastic_policy(**shared_inputs, plan=True, use_depth=False)
    elastic_action_mean = elastic_output["action_mean"]
    assert elastic_action_mean.shape[-1] == expected_action_dim
    assert elastic_output["planned_chunk"].shape == elastic_action_mean.shape