File size: 1,165 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from train.trainer import build_policy


def test_backbone_only_policy_accepts_rgbd_batch(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Run one forward pass of a ``backbone_only`` policy on a full tiny batch.

    A small policy config is built first. A batch is then sized with that
    config's ``decoder.chunk_size``, and the policy comes from ``build_policy``
    with ``policy_type="backbone_only"``. Every current-step and history field
    of the batch is forwarded to the policy under its own key name. The test
    checks that the outputs ``action_mean`` and ``scene_tokens`` share a
    leading dimension with ``batch["images"]``.
    """
    policy_cfg = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_cfg.decoder.chunk_size)
    policy = build_policy(policy_cfg, tiny_trainer_config(policy_type="backbone_only"))

    # Batch keys double as the policy's keyword-argument names, so each
    # entry can be forwarded verbatim via ** unpacking.
    forwarded_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_camera_intrinsics",
        "history_camera_extrinsics",
        "history_proprio",
        "history_actions",
    )
    result = policy(**{key: sample[key] for key in forwarded_keys})

    leading_dim = sample["images"].shape[0]
    assert result["action_mean"].shape[0] == leading_dim
    assert result["scene_tokens"].shape[0] == leading_dim