File size: 2,574 Bytes
b14c4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

from eval.summarize_anybimanual_overlap_eval import (
    _best_overlap_step,
    _delta,
    _last_overlap_step,
    _merge_rows_by_step,
)


def _row(step: int, **values: str) -> dict[str, str]:
    base = {"step": str(step)}
    base.update(values)
    return base


def test_merge_rows_by_step_fills_missing_values() -> None:
    """Rows sharing a step, each holding one metric, merge into a single row."""
    expected = {
        "eval_envs/return/coordinated_push_box": "10",
        "eval_envs/return/coordinated_lift_ball": "4",
        "eval_envs/return/dual_push_buttons": "20",
    }
    # One partial row per metric, all logged at the same step.
    partial_rows = [_row(600, **{column: value}) for column, value in expected.items()]

    merged = _merge_rows_by_step(partial_rows)

    assert len(merged) == 1
    combined = merged[0]
    for column, value in expected.items():
        assert combined[column] == value


def test_overlap_summary_picks_last_local_and_best_public() -> None:
    """Check last-step vs best-step selection and the sign of their delta.

    The fixtures are built so that "last" and "best" disagree in both runs.
    Otherwise a selector that simply returned the final step would satisfy
    the ``_best_overlap_step`` assertion, and vice versa:

    * local: step 600 has the highest return of any local step on every task,
      but the final logged step is 1000, which ``_last_overlap_step`` must
      still return;
    * public: the final step 70000 has zero return on every task, so
      ``_best_overlap_step`` must return the earlier step 60000.
    """
    # Overlap task columns, in the order per-step returns are passed below.
    keys = (
        "eval_envs/return/coordinated_push_box",
        "eval_envs/return/coordinated_lift_ball",
        "eval_envs/return/dual_push_buttons",
    )

    def overlap_row(step: int, *returns: str) -> dict[str, str]:
        # Pair each return with its task column; `returns` follows `keys` order.
        assert len(returns) == len(keys)
        return _row(step, **dict(zip(keys, returns)))

    local_rows = _merge_rows_by_step(
        [
            overlap_row(200, "0", "0", "0"),
            # Best local step on every task, deliberately not the last one.
            overlap_row(600, "25", "25", "25"),
            overlap_row(1000, "15", "8", "20"),
        ]
    )
    public_rows = _merge_rows_by_step(
        [
            overlap_row(50000, "18", "6", "20"),
            overlap_row(60000, "20", "8", "24"),
            # Worst public step on every task, deliberately the last one.
            overlap_row(70000, "0", "0", "0"),
        ]
    )
    local_last = _last_overlap_step(local_rows, 25)
    public_best = _best_overlap_step(public_rows, 25)
    assert local_last["step"] == 1000
    assert public_best["step"] == 60000
    assert public_best["mean_success"] > local_last["mean_success"]
    delta = _delta(local_last, public_best)
    assert delta["mean_success_delta"] < 0.0