File size: 890 Bytes
b14c4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from sim_reveal.dataset import collect_teacher_dataset
from sim_reveal.procedural_envs import available_proxy_names


def test_dataset_hard_negative_presence():
    """Check that a small teacher dataset contains the expected hard negatives.

    Collects one episode per available proxy environment, then verifies:

    * at least one candidate across all samples is flagged in
      ``candidate_is_hard_negative``;
    * the ``premature_retrieve`` and ``reveal_with_release`` families both
      appear among ``candidate_negative_families`` once the ``teacher`` and
      ``positive`` labels are excluded.
    """
    bundle = collect_teacher_dataset(
        proxy_names=available_proxy_names(),
        episodes_per_proxy=1,
        resolution=32,
        seed=3,
        chunk_horizon=4,
        rollout_horizon=4,
        planner_candidates=6,
    )

    # Labels attached to non-negative candidates; they are not negative families.
    non_negative_labels = {"teacher", "positive"}

    total_hard_negatives = 0
    seen_families = set()
    # Single pass so this also works if "samples" is a one-shot iterable.
    for entry in bundle["samples"]:
        # Per-sample flags are summed and cast to int; presumably they are
        # booleans or 0/1 values (TODO confirm against the dataset builder).
        total_hard_negatives += int(sum(entry["candidate_is_hard_negative"]))
        seen_families |= {
            label
            for label in entry["candidate_negative_families"]
            if label not in non_negative_labels
        }

    required_families = {"premature_retrieve", "reveal_with_release"}
    assert total_hard_negatives > 0
    assert required_families.issubset(seen_families)