# File size: 2,574 Bytes (revision b14c4b7)
from __future__ import annotations
from eval.summarize_anybimanual_overlap_eval import (
_best_overlap_step,
_delta,
_last_overlap_step,
_merge_rows_by_step,
)
def _row(step: int, **values: str) -> dict[str, str]:
base = {"step": str(step)}
base.update(values)
return base
def test_merge_rows_by_step_fills_missing_values() -> None:
    """Partial rows logged at the same step collapse into one complete row."""
    expected = {
        "eval_envs/return/coordinated_push_box": "10",
        "eval_envs/return/coordinated_lift_ball": "4",
        "eval_envs/return/dual_push_buttons": "20",
    }
    # One row per metric, each carrying only its own column, all at step 600.
    partial_rows = [_row(600, **{column: value}) for column, value in expected.items()]

    merged = _merge_rows_by_step(partial_rows)

    assert len(merged) == 1
    only_row = merged[0]
    for column, value in expected.items():
        assert only_row[column] == value
def test_overlap_summary_picks_last_local_and_best_public() -> None:
    """Local runs are summarised by their last step, public runs by their best.

    The fixtures are arranged so that "last" and "best" disagree in both
    series. The local run's returns are highest at step 600, but its last
    step is 1000. The public run's returns are highest at step 60000, but its
    last step is 70000. A selector that confused "last" with "best" (or vice
    versa) therefore fails this test.
    """

    def overlap_row(
        step: int, push_box: str, lift_ball: str, push_buttons: str
    ) -> dict[str, str]:
        # One row carrying the return of each of the three overlap tasks.
        return _row(
            step,
            **{
                "eval_envs/return/coordinated_push_box": push_box,
                "eval_envs/return/coordinated_lift_ball": lift_ball,
                "eval_envs/return/dual_push_buttons": push_buttons,
            },
        )

    local_rows = _merge_rows_by_step(
        [
            overlap_row(200, "0", "0", "0"),
            # Higher returns than the final step, so best != last locally.
            overlap_row(600, "18", "9", "21"),
            overlap_row(1000, "15", "8", "20"),
        ]
    )
    public_rows = _merge_rows_by_step(
        [
            overlap_row(50000, "18", "6", "20"),
            overlap_row(60000, "20", "8", "24"),
            # Lower returns than step 60000, so best != last publicly.
            overlap_row(70000, "16", "7", "19"),
        ]
    )

    # NOTE(review): 25 is passed straight through to the summary helpers;
    # presumably the per-task episode count used to turn returns into
    # success rates -- confirm against eval.summarize_anybimanual_overlap_eval.
    local_last = _last_overlap_step(local_rows, 25)
    public_best = _best_overlap_step(public_rows, 25)

    assert local_last["step"] == 1000
    assert public_best["step"] == 60000
    assert public_best["mean_success"] > local_last["mean_success"]

    delta = _delta(local_last, public_best)
    assert delta["mean_success_delta"] < 0.0
|