"""Planner-interface tests for the ``elastic_reveal`` policy."""

import torch

from train.trainer import build_policy

# Batch entries forwarded verbatim as keyword arguments to the policy forward
# call. Shared by every test below so the call sites cannot drift apart.
_POLICY_INPUT_KEYS = (
    "images",
    "depths",
    "depth_valid",
    "camera_intrinsics",
    "camera_extrinsics",
    "proprio",
    "texts",
    "history_images",
    "history_depths",
    "history_depth_valid",
    "history_proprio",
    "history_actions",
)


def _build_topk_setup(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Return the ``(config, batch, policy)`` triple used by the planner tests.

    The config is requested with ``num_candidates=4`` and ``top_k=2``; the batch
    is built with the decoder's ``chunk_size``; the policy is an
    ``elastic_reveal`` policy built from that config.
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    return config, batch, policy


def _policy_inputs(batch):
    """Select the policy forward-pass keyword arguments from *batch*."""
    return {key: batch[key] for key in _POLICY_INPUT_KEYS}


def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Planning keeps exactly ``top_k`` candidates and selects valid indices."""
    config, batch, policy = _build_topk_setup(
        tiny_policy_config, tiny_trainer_config, tiny_batch
    )

    output = policy(**_policy_inputs(batch), plan=True)

    assert output["planner_topk_indices"].shape[1] == config.planner.top_k
    assert output["planned_rollout"]["target_belief_field"].shape[1] == config.planner.top_k
    assert (output["best_candidate_indices"] < config.decoder.num_candidates).all()


def test_policy_null_rollout_ablation_keeps_planner_interface(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Disabling the world model keeps the planner outputs' interface intact.

    With ``use_world_model=False`` the rollout source must be ``"identity"``.
    The planner still reports ``top_k`` candidates and a
    ``(batch, top_k)`` utility tensor. Every rolled-out belief field and set of
    phase logits must equal the current interaction state, broadcast over the
    inserted axes.
    """
    config, batch, policy = _build_topk_setup(
        tiny_policy_config, tiny_trainer_config, tiny_batch
    )

    output = policy(
        **_policy_inputs(batch),
        plan=True,
        use_world_model=False,
        use_planner=True,
    )
    rollout = output["planned_rollout"]
    current_state = output["interaction_state"]

    assert output["rollout_source"] == "identity"
    assert output["planner_topk_indices"].shape[1] == config.planner.top_k
    assert rollout["target_belief_field"].shape[1] == config.planner.top_k

    # Insert axis 1 (the top-k candidate axis, per the shape check above) and
    # axis 2 (presumably the rollout horizon -- TODO confirm), then broadcast
    # the current state to the rollout's shape for an element-wise comparison.
    repeated_belief = (
        current_state["target_belief_field"]
        .detach()
        .unsqueeze(1)
        .unsqueeze(2)
        .expand_as(rollout["target_belief_field"])
    )
    repeated_phase = (
        current_state["phase_logits"]
        .detach()
        .unsqueeze(1)
        .unsqueeze(2)
        .expand_as(rollout["phase_logits"])
    )

    assert output["utility_total"].shape == (batch["images"].shape[0], config.planner.top_k)
    assert torch.allclose(rollout["target_belief_field"], repeated_belief)
    assert torch.allclose(rollout["phase_logits"], repeated_phase)