| import torch |
|
|
| from eval.run_reveal_benchmark import _policy_outputs |
| from train.trainer import build_policy |
|
|
|
|
| def _options(**overrides): |
| base = { |
| "disable_planner": False, |
| "disable_memory": False, |
| "disable_task_conditioning": False, |
| "disable_geometry": False, |
| "disable_camera_pose": False, |
| "disable_shortlist": False, |
| "ignore_proposal_logits": False, |
| "ignore_proposal_logits_in_shortlist": False, |
| "ignore_proposal_logits_in_planner": False, |
| "disable_support_mode_conditioning": False, |
| "disable_world_model": False, |
| "disable_depth": False, |
| "disable_role_tokens": False, |
| "short_history": False, |
| "rollout_mode_override": None, |
| "use_proposal_candidates": True, |
| } |
| base.update(overrides) |
| return base |
|
|
|
|
def test_eval_toggle_paths_work(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Every evaluation toggle must leave a visible trace in the policy outputs.

    Builds a tiny ``elastic_reveal`` policy in eval mode and runs
    ``_policy_outputs`` once per ablation variant. Each variant's outputs are
    then checked on their own or against the unablated ``base`` run.
    """
    cfg = tiny_policy_config()
    batch = tiny_batch(chunk_size=cfg.decoder.chunk_size)
    trainer_cfg = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(cfg, trainer_cfg)
    policy.eval()

    # Each entry is (variant name, _options overrides, disable_planner_for_selection).
    # The dict comprehension below runs the variants in exactly this order.
    variants = (
        ("base", {}, False),
        ("no_planner", {"disable_planner": True}, False),
        ("no_memory", {"disable_memory": True}, True),
        ("no_task", {"disable_task_conditioning": True}, True),
        ("no_geometry", {"disable_geometry": True, "disable_camera_pose": True}, True),
        ("no_shortlist", {"disable_shortlist": True}, False),
        ("ignore_shortlist_logits", {"ignore_proposal_logits_in_shortlist": True}, False),
        ("ignore_planner_logits", {"ignore_proposal_logits_in_planner": True}, False),
        ("ignore_proposal_logits", {"ignore_proposal_logits": True}, False),
    )
    runs = {
        name: _policy_outputs(
            policy,
            batch,
            _options(**overrides),
            disable_planner_for_selection=skip_planner,
        )
        for name, overrides, skip_planner in variants
    }
    base = runs["base"]

    # Planner disabled: every planner score is exactly zero.
    assert torch.count_nonzero(runs["no_planner"]["planner_scores"]).item() == 0

    # Memory disabled: memory tokens are all zeros (within allclose tolerance).
    memory_tokens = runs["no_memory"]["memory_output"]["memory_tokens"]
    assert torch.allclose(memory_tokens, torch.zeros_like(memory_tokens))

    # NOTE(review): no_task and no_geometry ran with disable_planner_for_selection=True,
    # while base ran with False. These comparisons therefore also differ in that flag,
    # not only in the toggle under test.
    assert not torch.allclose(
        base["interaction_state"]["support_mode_logits"],
        runs["no_task"]["interaction_state"]["support_mode_logits"],
    )
    assert not torch.allclose(base["scene_tokens"], runs["no_geometry"]["scene_tokens"])

    # Shortlist disabled: dim 1 of planner_topk_indices matches dim 1 of candidate_chunks.
    shortlist_width = runs["no_shortlist"]["planner_topk_indices"].shape[1]
    assert shortlist_width == base["candidate_chunks"].shape[1]

    # Expected (used_for_shortlist, used_for_planner) flags for each variant.
    expected_logit_usage = {
        "base": (True, True),
        "ignore_shortlist_logits": (False, True),
        "ignore_planner_logits": (True, False),
        "ignore_proposal_logits": (False, False),
    }
    for name, (uses_shortlist, uses_planner) in expected_logit_usage.items():
        run = runs[name]
        assert bool(run["proposal_logits_used_for_shortlist"]) == uses_shortlist, name
        assert bool(run["proposal_logits_used_for_planner"]) == uses_planner, name

    # Ignoring logits only in the planner keeps planner_topk_indices identical to base,
    # but must still change planner_scores.
    ignore_planner = runs["ignore_planner_logits"]
    assert torch.equal(base["planner_topk_indices"], ignore_planner["planner_topk_indices"])
    assert not torch.allclose(base["planner_scores"], ignore_planner["planner_scores"])

    # Ignoring proposal logits everywhere must also change planner_scores.
    assert not torch.allclose(
        base["planner_scores"],
        runs["ignore_proposal_logits"]["planner_scores"],
    )
|
|