File size: 1,751 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from train.trainer import build_policy


def test_explicit_task_metadata_overrides_text(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Explicit task metadata passed to the policy should win over the text prompt.

    The batch's text prompt ("foliage canopy leaves snail") and its explicit
    ``task_name`` ("bag") deliberately disagree. The assertions require:

    * the reported ``task_names`` and ``proposal_task_names`` to be "bag" for
      every sample;
    * ``proposal_task_ids`` to be 1 for every sample;
    * the first sample's proposal mode names to contain no "sweep" mode and at
      least one name mentioning "rim" or "mouth".
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    batch_size = batch["images"].shape[0]

    # Conflicting inputs: the text describes one scene, the explicit task name another.
    batch["texts"] = ["foliage canopy leaves snail"] * batch_size
    batch["task_name"] = ["bag"] * batch_size

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # Batch entries forwarded under the same keyword name, kept in the original
    # keyword order (current observation first, history last).
    current_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
    )
    history_keys = (
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_camera_intrinsics",
        "history_camera_extrinsics",
        "history_camera_valid_mask",
        "history_proprio",
        "history_actions",
    )

    forward_kwargs = {key: batch[key] for key in current_keys}
    # The batch stores task metadata under singular keys; the policy takes plural names.
    forward_kwargs["task_names"] = batch["task_name"]
    forward_kwargs["task_ids"] = batch["task_id"]
    forward_kwargs.update({key: batch[key] for key in history_keys})

    output = policy(**forward_kwargs, plan=False)

    expected_task_names = ["bag"] * batch_size
    assert output["task_names"] == expected_task_names
    assert output["proposal_task_names"] == expected_task_names
    assert output["proposal_task_ids"].tolist() == [1] * batch_size

    # Only the first sample's proposal modes are inspected.
    first_sample_modes = output["proposal_mode_names"][0]
    assert all("sweep" not in name for name in first_sample_modes)
    assert any("rim" in name or "mouth" in name for name in first_sample_modes)