"""Tests that the evaluation ablation toggles passed to ``_policy_outputs`` take effect."""

import torch

from eval.run_reveal_benchmark import _policy_outputs
from train.trainer import build_policy


def _options(**overrides):
    """Return a complete evaluation-options dict, with optional overrides.

    Every ``disable_*`` / ``ignore_*`` flag and ``short_history`` default to
    ``False``, ``rollout_mode_override`` defaults to ``None``, and
    ``use_proposal_candidates`` defaults to ``True``. Keyword arguments replace
    individual entries.

    Raises:
        TypeError: if an override names an option that is not in the defaults.
            Without this check a typo (e.g. ``disable_planer=True``) would be
            added as an extra key while the intended toggle stayed at its
            default, silently weakening the test that relies on it.
    """
    options = {
        "disable_planner": False,
        "disable_memory": False,
        "disable_task_conditioning": False,
        "disable_geometry": False,
        "disable_camera_pose": False,
        "disable_shortlist": False,
        "ignore_proposal_logits": False,
        "ignore_proposal_logits_in_shortlist": False,
        "ignore_proposal_logits_in_planner": False,
        "disable_support_mode_conditioning": False,
        "disable_world_model": False,
        "disable_depth": False,
        "disable_role_tokens": False,
        "short_history": False,
        "rollout_mode_override": None,
        "use_proposal_candidates": True,
    }
    unknown = sorted(set(overrides) - options.keys())
    if unknown:
        raise TypeError(f"unknown evaluation option(s): {unknown}")
    options.update(overrides)
    return options


def test_eval_toggle_paths_work(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Each evaluation toggle must change the policy output it targets.

    Builds one small ``elastic_reveal`` policy and runs it on the same batch
    under several option sets. Each ablated run is checked against the
    un-ablated ``base`` run or against an expected constant output.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    model = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    model.eval()

    def run(disable_planner_for_selection, **overrides):
        """Run ``_policy_outputs`` on the shared model/batch with the given overrides."""
        return _policy_outputs(
            model,
            batch,
            _options(**overrides),
            disable_planner_for_selection=disable_planner_for_selection,
        )

    # Call order matches the original test, in case any run consumes shared state.
    base = run(False)
    no_planner = run(False, disable_planner=True)
    # NOTE(review): these three runs use disable_planner_for_selection=True while
    # ``base`` uses False. The comparisons against ``base`` below assume the
    # compared tensors do not depend on that flag — confirm.
    no_memory = run(True, disable_memory=True)
    no_task = run(True, disable_task_conditioning=True)
    no_geometry = run(True, disable_geometry=True, disable_camera_pose=True)
    no_shortlist = run(False, disable_shortlist=True)
    ignore_shortlist_logits = run(False, ignore_proposal_logits_in_shortlist=True)
    ignore_planner_logits = run(False, ignore_proposal_logits_in_planner=True)
    ignore_proposal_logits = run(False, ignore_proposal_logits=True)

    # Disabling the planner zeroes all planner scores.
    assert torch.count_nonzero(no_planner["planner_scores"]) == 0

    # Disabling memory zeroes the memory tokens.
    memory_tokens = no_memory["memory_output"]["memory_tokens"]
    assert torch.allclose(memory_tokens, torch.zeros_like(memory_tokens))

    # Task conditioning and geometry each influence their downstream tensors.
    assert not torch.allclose(
        base["interaction_state"]["support_mode_logits"],
        no_task["interaction_state"]["support_mode_logits"],
    )
    assert not torch.allclose(base["scene_tokens"], no_geometry["scene_tokens"])

    # Without a shortlist, the planner's top-k covers every candidate chunk.
    assert no_shortlist["planner_topk_indices"].shape[1] == base["candidate_chunks"].shape[1]

    # By default, proposal logits feed both the shortlist and the planner.
    assert base["proposal_logits_used_for_shortlist"]
    assert base["proposal_logits_used_for_planner"]

    # Ignoring logits in the shortlist only affects the shortlist.
    assert not ignore_shortlist_logits["proposal_logits_used_for_shortlist"]
    assert ignore_shortlist_logits["proposal_logits_used_for_planner"]

    # Ignoring logits in the planner keeps the same shortlist but changes scores.
    assert ignore_planner_logits["proposal_logits_used_for_shortlist"]
    assert not ignore_planner_logits["proposal_logits_used_for_planner"]
    assert torch.equal(base["planner_topk_indices"], ignore_planner_logits["planner_topk_indices"])
    assert not torch.allclose(base["planner_scores"], ignore_planner_logits["planner_scores"])

    # The global switch disables both uses and changes planner scores.
    assert not ignore_proposal_logits["proposal_logits_used_for_shortlist"]
    assert not ignore_proposal_logits["proposal_logits_used_for_planner"]
    assert not torch.allclose(base["planner_scores"], ignore_proposal_logits["planner_scores"])