from train.trainer import build_policy


def test_rgbd_forward_contract(tiny_policy_config, tiny_trainer_config, tiny_batch):
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))

    output = policy(
        images=batch["images"],
        depths=batch["depths"],
        depth_valid=batch["depth_valid"],
        camera_intrinsics=batch["camera_intrinsics"],
        camera_extrinsics=batch["camera_extrinsics"],
        proprio=batch["proprio"],
        texts=batch["texts"],
        history_images=batch["history_images"],
        history_depths=batch["history_depths"],
        history_depth_valid=batch["history_depth_valid"],
        history_proprio=batch["history_proprio"],
        history_actions=batch["history_actions"],
        plan=True,
        compute_equivariance_probe=True,
    )

    # Predicted actions must preserve the input batch dimension.
    assert output["action_mean"].shape[0] == batch["images"].shape[0]

    # RGB-D encoding should expose depth, geometry, and camera token streams.
    assert output["depth_tokens"] is not None
    assert output["geometry_tokens"] is not None
    assert output["camera_tokens"] is not None

    # Planner contract: candidate count and top-k selection shapes follow the config.
    assert output["proposal_candidates"].shape[1] == config.decoder.num_candidates
    assert output["planner_topk_indices"].shape[1] == config.planner.top_k
    assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k

    # Interaction-state diagnostics expected from the elastic_reveal policy.
    assert "opening_quality" in output["interaction_state"]
    assert "gap_width" in output["interaction_state"]
    assert "hold_quality" in output["interaction_state"]

    # Equivariance probe output must be shape-compatible with its target.
    assert output["equivariance_probe_action_mean"].shape == output["equivariance_target_action_mean"].shape
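
# --- Illustrative companion sketch (not part of the original test) ---
# A minimal helper, assuming `output` is a plain dict: it collects the output
# keys already exercised by test_rgbd_forward_contract into one table so
# related forward-pass tests could reuse the same key contract and report all
# missing keys at once. Only key names asserted above are assumed to exist;
# the helper and constant names themselves are hypothetical.
EXPECTED_FORWARD_OUTPUT_KEYS = (
    "action_mean",
    "depth_tokens",
    "geometry_tokens",
    "camera_tokens",
    "proposal_candidates",
    "planner_topk_indices",
    "planned_rollout",
    "interaction_state",
    "equivariance_probe_action_mean",
    "equivariance_target_action_mean",
)


def assert_forward_output_keys(output):
    """Hypothetical helper: fail with the full list of missing output keys."""
    missing = [key for key in EXPECTED_FORWARD_OUTPUT_KEYS if key not in output]
    assert not missing, f"policy output missing keys: {missing}"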