File size: 1,668 Bytes
16405f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7d8e79
 
 
16405f2
 
 
e7d8e79
 
 
16405f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from train.trainer import build_policy


def test_rgbd_forward_contract(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Run one planning forward pass with depth inputs and check the output contract.

    Builds an ``elastic_reveal`` policy from the tiny fixtures and feeds it a
    batch whose ``chunk_size`` matches the decoder's. It then checks:

    * the batch dimension of ``action_mean``;
    * that the depth, geometry and camera tokens are present;
    * the candidate count and the planner top-k counts;
    * the required interaction-state keys;
    * that the equivariance probe and target shapes agree.
    """
    policy_cfg = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_cfg.decoder.chunk_size)

    # Inputs forwarded verbatim from the batch to the policy, keyed by the
    # same names the policy accepts as keyword arguments.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    trainer_cfg = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(policy_cfg, trainer_cfg)
    result = policy(
        **{key: sample[key] for key in input_keys},
        plan=True,
        compute_equivariance_probe=True,
    )

    assert result["action_mean"].shape[0] == sample["images"].shape[0]
    for token_key in ("depth_tokens", "geometry_tokens", "camera_tokens"):
        assert result[token_key] is not None

    assert result["proposal_candidates"].shape[1] == policy_cfg.decoder.num_candidates
    top_k = policy_cfg.planner.top_k
    assert result["planner_topk_indices"].shape[1] == top_k
    assert result["planned_rollout"]["target_belief_field"].shape[1] == top_k

    interaction = result["interaction_state"]
    for state_key in ("opening_quality", "gap_width", "hold_quality"):
        assert state_key in interaction

    probe_shape = result["equivariance_probe_action_mean"].shape
    assert probe_shape == result["equivariance_target_action_mean"].shape