File size: 1,668 Bytes
16405f2 e7d8e79 16405f2 e7d8e79 16405f2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | from train.trainer import build_policy
def test_rgbd_forward_contract(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check the output contract of an "elastic_reveal" policy forward pass.

    A policy is built from the tiny policy/trainer configs and called once on
    a tiny batch with ``plan=True`` and ``compute_equivariance_probe=True``.
    The test then verifies that the expected entries exist in the returned
    mapping and that the sizes it reports agree with the batch and the config.
    """
    # Inputs that are forwarded from the batch to the policy unchanged,
    # listed in the order they are passed as keyword arguments.
    forwarded_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )

    policy_cfg = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_cfg.decoder.chunk_size)
    model = build_policy(
        policy_cfg,
        tiny_trainer_config(policy_type="elastic_reveal"),
    )

    forward_inputs = {key: sample[key] for key in forwarded_keys}
    result = model(
        **forward_inputs,
        plan=True,
        compute_equivariance_probe=True,
    )

    # Leading dimension of the action mean must match the batch's images.
    assert result["action_mean"].shape[0] == sample["images"].shape[0]

    # Token outputs must all be populated.
    for token_key in ("depth_tokens", "geometry_tokens", "camera_tokens"):
        assert result[token_key] is not None

    # Dimension 1 of the proposal / planner outputs follows the config.
    assert result["proposal_candidates"].shape[1] == policy_cfg.decoder.num_candidates
    assert result["planner_topk_indices"].shape[1] == policy_cfg.planner.top_k
    rollout_field = result["planned_rollout"]["target_belief_field"]
    assert rollout_field.shape[1] == policy_cfg.planner.top_k

    # The interaction state must expose these named entries.
    for state_key in ("opening_quality", "gap_width", "hold_quality"):
        assert state_key in result["interaction_state"]

    # Probe and target action means must have the same shape.
    probe_shape = result["equivariance_probe_action_mean"].shape
    target_shape = result["equivariance_target_action_mean"].shape
    assert probe_shape == target_shape
|