import torch
from eval.run_reveal_benchmark import _policy_outputs
from train.trainer import build_policy
def _options(**overrides):
base = {
"disable_planner": False,
"disable_memory": False,
"disable_task_conditioning": False,
"disable_geometry": False,
"disable_camera_pose": False,
"disable_shortlist": False,
"ignore_proposal_logits": False,
"ignore_proposal_logits_in_shortlist": False,
"ignore_proposal_logits_in_planner": False,
"disable_support_mode_conditioning": False,
"disable_world_model": False,
"disable_depth": False,
"disable_role_tokens": False,
"short_history": False,
"rollout_mode_override": None,
"use_proposal_candidates": True,
}
base.update(overrides)
return base
def test_eval_toggle_paths_work(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Smoke-test every eval-time ablation toggle and its observable effect on outputs."""
    policy_config = tiny_policy_config()
    eval_batch = tiny_batch(chunk_size=policy_config.decoder.chunk_size)
    policy = build_policy(policy_config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    def run(options, *, planner_off):
        # Thin wrapper so each ablation below reads as a one-liner.
        return _policy_outputs(policy, eval_batch, options, disable_planner_for_selection=planner_off)

    reference = run(_options(), planner_off=False)
    planner_ablated = run(_options(disable_planner=True), planner_off=False)
    memory_ablated = run(_options(disable_memory=True), planner_off=True)
    task_ablated = run(_options(disable_task_conditioning=True), planner_off=True)
    geometry_ablated = run(_options(disable_geometry=True, disable_camera_pose=True), planner_off=True)
    shortlist_ablated = run(_options(disable_shortlist=True), planner_off=False)
    shortlist_logits_ignored = run(_options(ignore_proposal_logits_in_shortlist=True), planner_off=False)
    planner_logits_ignored = run(_options(ignore_proposal_logits_in_planner=True), planner_off=False)
    all_logits_ignored = run(_options(ignore_proposal_logits=True), planner_off=False)

    # Disabling the planner zeroes its scores entirely.
    assert torch.count_nonzero(planner_ablated["planner_scores"]) == 0
    # Disabling memory produces all-zero memory tokens.
    memory_tokens = memory_ablated["memory_output"]["memory_tokens"]
    assert torch.allclose(memory_tokens, torch.zeros_like(memory_tokens))
    # Task and geometry ablations must actually change the downstream tensors.
    assert not torch.allclose(
        reference["interaction_state"]["support_mode_logits"],
        task_ablated["interaction_state"]["support_mode_logits"],
    )
    assert not torch.allclose(reference["scene_tokens"], geometry_ablated["scene_tokens"])
    # With the shortlist off, planner top-k covers every candidate chunk.
    assert shortlist_ablated["planner_topk_indices"].shape[1] == reference["candidate_chunks"].shape[1]
    # Bookkeeping flags: where were proposal logits actually consumed?
    assert reference["proposal_logits_used_for_shortlist"]
    assert reference["proposal_logits_used_for_planner"]
    assert not shortlist_logits_ignored["proposal_logits_used_for_shortlist"]
    assert shortlist_logits_ignored["proposal_logits_used_for_planner"]
    assert planner_logits_ignored["proposal_logits_used_for_shortlist"]
    assert not planner_logits_ignored["proposal_logits_used_for_planner"]
    # Ignoring planner logits keeps the selected indices but changes the scores.
    assert torch.equal(reference["planner_topk_indices"], planner_logits_ignored["planner_topk_indices"])
    assert not torch.allclose(reference["planner_scores"], planner_logits_ignored["planner_scores"])
    assert not all_logits_ignored["proposal_logits_used_for_shortlist"]
    assert not all_logits_ignored["proposal_logits_used_for_planner"]
    assert not torch.allclose(reference["planner_scores"], all_logits_ignored["planner_scores"])