File size: 3,704 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch

from eval.run_reveal_benchmark import _policy_outputs
from train.trainer import build_policy


def _options(**overrides):
    """Build an eval-options dict with every ablation switched off by default.

    Keyword arguments override individual entries, e.g.
    ``_options(disable_planner=True)``. A fresh dict is built on every call,
    so callers may mutate the result without affecting later calls.

    Raises:
        TypeError: if an override names an option that has no default below.
            Without this check a misspelled toggle would be accepted silently,
            and the option it was meant to set would keep its default value,
            so the ablation under test would not actually be applied.
    """
    base = {
        "disable_planner": False,
        "disable_memory": False,
        "disable_task_conditioning": False,
        "disable_geometry": False,
        "disable_camera_pose": False,
        "disable_shortlist": False,
        "ignore_proposal_logits": False,
        "ignore_proposal_logits_in_shortlist": False,
        "ignore_proposal_logits_in_planner": False,
        "disable_support_mode_conditioning": False,
        "disable_world_model": False,
        "disable_depth": False,
        "disable_role_tokens": False,
        "short_history": False,
        "rollout_mode_override": None,
        "use_proposal_candidates": True,
    }
    unknown = sorted(overrides.keys() - base.keys())
    if unknown:
        # Fail loudly on typos instead of quietly testing the default path.
        raise TypeError(f"_options() got unexpected option(s): {', '.join(unknown)}")
    base.update(overrides)
    return base


def test_eval_toggle_paths_work(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Each evaluation ablation toggle must leave its expected mark on the outputs.

    A tiny ``elastic_reveal`` policy is run through ``_policy_outputs`` once
    with default options and once per ablation: planner, memory, task
    conditioning, geometry + camera pose, shortlist, and the three
    proposal-logit switches. The assertions then compare each ablated run
    against the reference run.
    """
    policy_config = tiny_policy_config()
    # Batch is created before the policy, as before. Construction may draw from
    # the global torch RNG, so swapping the two could change generated values.
    sample = tiny_batch(chunk_size=policy_config.decoder.chunk_size)
    policy = build_policy(policy_config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    def run(*, planner_free_selection, **overrides):
        # One forward pass with the given option overrides applied.
        return _policy_outputs(
            policy,
            sample,
            _options(**overrides),
            disable_planner_for_selection=planner_free_selection,
        )

    def expect_logit_usage(outputs, *, shortlist, planner):
        # Check which stages report having consumed the proposal logits.
        assert bool(outputs["proposal_logits_used_for_shortlist"]) is shortlist
        assert bool(outputs["proposal_logits_used_for_planner"]) is planner

    reference = run(planner_free_selection=False)
    planner_off = run(planner_free_selection=False, disable_planner=True)
    memory_off = run(planner_free_selection=True, disable_memory=True)
    task_off = run(planner_free_selection=True, disable_task_conditioning=True)
    geometry_off = run(
        planner_free_selection=True,
        disable_geometry=True,
        disable_camera_pose=True,
    )
    shortlist_off = run(planner_free_selection=False, disable_shortlist=True)
    shortlist_logits_off = run(
        planner_free_selection=False,
        ignore_proposal_logits_in_shortlist=True,
    )
    planner_logits_off = run(
        planner_free_selection=False,
        ignore_proposal_logits_in_planner=True,
    )
    all_logits_off = run(planner_free_selection=False, ignore_proposal_logits=True)

    # Planner ablation: every planner score is zero.
    assert torch.count_nonzero(planner_off["planner_scores"]) == 0

    # Memory ablation: memory tokens are (numerically) all zero.
    memory_tokens = memory_off["memory_output"]["memory_tokens"]
    assert torch.allclose(memory_tokens, torch.zeros_like(memory_tokens))

    # Task-conditioning and geometry ablations must perturb their downstream tensors.
    assert not torch.allclose(
        reference["interaction_state"]["support_mode_logits"],
        task_off["interaction_state"]["support_mode_logits"],
    )
    assert not torch.allclose(reference["scene_tokens"], geometry_off["scene_tokens"])

    # Without a shortlist, top-k indices are as wide (dim 1) as the candidate set.
    assert shortlist_off["planner_topk_indices"].shape[1] == reference["candidate_chunks"].shape[1]

    # Proposal-logit switches: the flags report exactly which stage still uses them.
    expect_logit_usage(reference, shortlist=True, planner=True)
    expect_logit_usage(shortlist_logits_off, shortlist=False, planner=True)
    expect_logit_usage(planner_logits_off, shortlist=True, planner=False)
    # Ignoring logits only in the planner keeps the top-k indices but changes scores.
    assert torch.equal(reference["planner_topk_indices"], planner_logits_off["planner_topk_indices"])
    assert not torch.allclose(reference["planner_scores"], planner_logits_off["planner_scores"])
    expect_logit_usage(all_logits_off, shortlist=False, planner=False)
    assert not torch.allclose(reference["planner_scores"], all_logits_off["planner_scores"])