File size: 2,218 Bytes
31ade1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from sim_reveal.dataset import RGBD_PROXY_DATASET_VERSION, collect_teacher_dataset, dataset_from_bundle


def test_dataset_v6_keys():
    """Check that an RGB-D proxy dataset item exposes every expected field.

    A small bundle is collected at the current ``RGBD_PROXY_DATASET_VERSION``
    and converted to a dataset. The first item must contain all listed keys.
    Missing keys are gathered and reported together, so one failing run shows
    the full extent of a schema regression rather than only the first absent
    key.
    """
    # Shared by collection and dataset construction so the two cannot drift apart.
    resolution = 16
    bundle = collect_teacher_dataset(
        episodes_per_proxy=1,
        resolution=resolution,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION,
    )
    dataset = dataset_from_bundle(bundle, resolution=resolution)
    item = dataset[0]
    expected_keys = (
        "images",
        "depths",
        "depth_valid",
        "belief_map",
        "visibility_map",
        "clearance_map",
        "support_stability",
        "reocclusion_target",
        "candidate_rollout_belief_map",
    )
    missing = [key for key in expected_keys if key not in item]
    assert not missing, f"dataset item is missing keys: {missing}"


def test_phase_dataset_version_keeps_rgbd_path():
    """A ``_phase``-suffixed dataset version must still carry the RGB-D fields.

    The test collects a single ``bag_proxy`` episode under
    ``RGBD_PROXY_DATASET_VERSION + "_phase"``. It then checks two things on
    the first dataset item: ``depth_valid`` sums to a positive value, and the
    phase-related keys are present.
    """
    collection_kwargs = dict(
        proxy_names=["bag_proxy"],
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION + "_phase",
    )
    first_item = dataset_from_bundle(
        collect_teacher_dataset(**collection_kwargs),
        resolution=16,
    )[0]

    # Some depth must be marked valid, i.e. the depth path was not dropped.
    assert float(first_item["depth_valid"].sum()) > 0.0
    for phase_key in ("phase", "rollout_phase", "candidate_rollout_phase"):
        assert phase_key in first_item


def test_dataset_proposal_target_keys_roundtrip():
    """Outputs of a custom proposal-target builder must reach dataset items.

    The builder used here returns ``.copy()`` of four candidate entries from
    each sample, renamed to ``proposal_target_*``. The test then checks that
    every one of those names is present on the first dataset item.
    """
    # Each proposal target mirrors the candidate entry with the same suffix.
    target_to_candidate = (
        ("proposal_target_action_chunks", "candidate_action_chunks"),
        ("proposal_target_retrieval_success", "candidate_retrieval_success"),
        ("proposal_target_risk", "candidate_risk"),
        ("proposal_target_utility", "candidate_utility"),
    )

    def mirror_candidates(env, observation, sample):
        # ``env`` and ``observation`` belong to the builder signature but are unused here.
        return {target: sample[source].copy() for target, source in target_to_candidate}

    bundle = collect_teacher_dataset(
        proxy_names=["cloth_proxy"],
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION + "_phase",
        proposal_target_builder=mirror_candidates,
    )
    first_item = dataset_from_bundle(bundle, resolution=16)[0]
    for target, _candidate in target_to_candidate:
        assert target in first_item