# File: mediastorm/tests/test_eval_runner.py
# Committed by: remdms
# Commit message: test: add unit tests for eval/runner.py
# Revision: f4eb869
"""Tests for eval/runner.py — pure/testable functions only (no ChromaDB)."""
import json
import pytest
from pathlib import Path
from mediastorm.eval.runner import _avg, _build_run_data, save_run, load_previous_run, load_all_runs
# ---------------------------------------------------------------------------
# Test-data builders (plain helper functions, not pytest fixtures)
# ---------------------------------------------------------------------------
def _make_row(category: str, **overrides) -> dict:
base = {
"query": "test query",
"category": category,
"precision_at_1": 1.0,
"recall_at_5": 0.8,
"mrr": 0.9,
"ndcg_at_5": 0.85,
"retrieved": ["uid_a", "uid_b"],
"expected": ["uid_a"],
"missed": [],
"duration": 0.1,
}
base.update(overrides)
return base
def _make_edge_row(**overrides) -> dict:
base = {
"query": "edge query",
"category": "edge_no_match",
"success": True,
"num_returned": 0,
"duration": 0.05,
}
base.update(overrides)
return base
def _make_eval_result(details: list[dict]) -> dict:
return {
"details": details,
"semantic_precision_at_1": 0.8,
"semantic_recall_at_5": 0.7,
"semantic_mrr": 0.75,
"semantic_ndcg_at_5": 0.72,
"filter_precision_at_1": 0.6,
"filter_recall_at_5": 0.65,
"edge_pass_rate": 1.0,
}
# ---------------------------------------------------------------------------
# _avg
# ---------------------------------------------------------------------------
class TestAvg:
    """Expectations for ``_avg(rows, key)`` over lists of dict rows."""

    def test_averages_key_values(self):
        """Two rows carrying the key yield their arithmetic mean."""
        mean = _avg([{"score": 0.5}, {"score": 1.0}], "score")
        assert mean == pytest.approx(0.75)

    def test_skips_rows_missing_key(self):
        """A row without the key does not contribute to the result."""
        mixed_rows = [{"score": 1.0}, {"other": 0.0}]
        mean = _avg(mixed_rows, "score")
        assert mean == pytest.approx(1.0)

    def test_empty_list_returns_zero(self):
        """No rows at all is expected to give exactly 0.0."""
        no_rows = []
        assert _avg(no_rows, "score") == 0.0

    def test_all_rows_missing_key_returns_zero(self):
        """Rows that all lack the key are expected to give exactly 0.0."""
        unrelated_rows = [{"other": 1.0}, {"other": 2.0}]
        assert _avg(unrelated_rows, "score") == 0.0

    def test_single_row(self):
        """A single row averages to its own value."""
        mean = _avg([{"val": 0.42}], "val")
        assert mean == pytest.approx(0.42)
# ---------------------------------------------------------------------------
# _build_run_data
# ---------------------------------------------------------------------------
class TestBuildRunData:
    """Expectations for the run dict produced by ``_build_run_data``."""

    def test_timestamp_is_iso_string(self):
        """``timestamp`` is a str that ``datetime.fromisoformat`` accepts."""
        from datetime import datetime

        run = _build_run_data(_make_eval_result([_make_row("geographic")]))
        stamp = run["timestamp"]
        assert isinstance(stamp, str)
        # Raises ValueError (failing the test) if the string is not ISO-8601.
        datetime.fromisoformat(stamp)

    def test_aggregates_keys(self):
        """``aggregates`` exposes exactly the seven short metric names."""
        expected_names = {
            "semantic_p1", "semantic_r5", "semantic_mrr", "semantic_ndcg5",
            "filter_p1", "filter_r5", "edge_pass_rate",
        }
        run = _build_run_data(_make_eval_result([_make_row("thematic")]))
        assert set(run["aggregates"]) == expected_names

    def test_aggregates_values_passthrough(self):
        """Aggregate values equal the ones supplied in the eval result."""
        source = _make_eval_result([_make_row("geographic")])
        aggregates = _build_run_data(source)["aggregates"]
        assert aggregates["semantic_p1"] == pytest.approx(0.8)
        assert aggregates["filter_r5"] == pytest.approx(0.65)
        assert aggregates["edge_pass_rate"] == pytest.approx(1.0)

    def test_category_summary_for_normal_category(self):
        """A metric category reports its row count and per-metric means."""
        first = _make_row(
            "geographic", precision_at_1=1.0, recall_at_5=0.6, mrr=0.8, ndcg_at_5=0.7
        )
        second = _make_row(
            "geographic", precision_at_1=0.0, recall_at_5=0.4, mrr=0.5, ndcg_at_5=0.3
        )
        run = _build_run_data(_make_eval_result([first, second]))
        summary = run["categories"]["geographic"]
        assert summary["count"] == 2
        assert summary["p1"] == pytest.approx(0.5)
        assert summary["r5"] == pytest.approx(0.5)
        assert summary["mrr"] == pytest.approx(0.65)
        assert summary["ndcg5"] == pytest.approx(0.5)

    def test_category_summary_for_edge_no_match(self):
        """The edge category reports passed/total instead of metric means."""
        outcomes = (True, False, True)
        details = [_make_edge_row(success=flag) for flag in outcomes]
        run = _build_run_data(_make_eval_result(details))
        summary = run["categories"]["edge_no_match"]
        assert summary["passed"] == 2
        assert summary["total"] == 3

    def test_multiple_categories_separated(self):
        """Rows are counted under their own category."""
        details = [
            _make_row(name) for name in ("geographic", "thematic", "geographic")
        ]
        categories = _build_run_data(_make_eval_result(details))["categories"]
        assert categories["geographic"]["count"] == 2
        assert categories["thematic"]["count"] == 1

    def test_queries_list_length_matches_details(self):
        """``queries`` holds one entry per detail row, edge rows included."""
        details = [_make_row("geographic"), _make_row("thematic"), _make_edge_row()]
        run = _build_run_data(_make_eval_result(details))
        assert len(run["queries"]) == 3

    def test_normal_query_entry_has_expected_keys(self):
        """A metric row's query entry has metric keys and no edge keys."""
        run = _build_run_data(_make_eval_result([_make_row("geographic")]))
        entry = run["queries"][0]
        required = (
            "query", "p1", "r5", "mrr", "ndcg5",
            "retrieved_ids", "expected_ids", "missed", "duration",
        )
        for name in required:
            assert name in entry
        # Edge-only fields should not be present
        for name in ("success", "num_returned"):
            assert name not in entry

    def test_edge_query_entry_has_expected_keys(self):
        """An edge row's query entry has edge keys and no metric keys."""
        run = _build_run_data(_make_eval_result([_make_edge_row()]))
        entry = run["queries"][0]
        for name in ("success", "num_returned"):
            assert name in entry
        # Metric fields should not be present
        for name in ("p1", "retrieved_ids"):
            assert name not in entry

    def test_returns_dict_with_required_top_level_keys(self):
        """The four required top-level keys exist; extra keys are tolerated."""
        required = {"timestamp", "aggregates", "categories", "queries"}
        run = _build_run_data(_make_eval_result([_make_row("geographic")]))
        assert required <= set(run)
# ---------------------------------------------------------------------------
# save_run / load_previous_run / load_all_runs
# ---------------------------------------------------------------------------
class TestRunPersistence:
    """Round-trip tests for ``save_run``, ``load_previous_run`` and
    ``load_all_runs``, each run against pytest's ``tmp_path`` directory."""

    def _sample_run(self, timestamp: str = "2026-01-15T10:30:00") -> dict:
        """Return a minimal run dict carrying the given ``timestamp``."""
        return dict(
            timestamp=timestamp,
            aggregates={"semantic_p1": 0.8},
            categories={},
            queries=[],
        )

    def test_save_run_creates_json_file(self, tmp_path):
        """Saving writes a file that exists and ends in ``.json``."""
        written = save_run(self._sample_run(), runs_dir=tmp_path)
        assert written.exists()
        assert written.suffix == ".json"

    def test_save_run_filename_matches_timestamp(self, tmp_path):
        """The file name is the timestamp with ``T``/``:`` made path-safe."""
        written = save_run(self._sample_run("2026-01-15T10:30:00"), runs_dir=tmp_path)
        assert written.name == "2026-01-15_10-30-00.json"

    def test_save_run_content_is_valid_json(self, tmp_path):
        """The saved file parses as JSON and keeps the aggregate value."""
        written = save_run(self._sample_run(), runs_dir=tmp_path)
        raw_text = written.read_text()
        payload = json.loads(raw_text)
        assert payload["aggregates"]["semantic_p1"] == pytest.approx(0.8)

    def test_save_run_creates_parent_dirs(self, tmp_path):
        """A missing, nested ``runs_dir`` is created on save."""
        target_dir = tmp_path / "deep" / "nested"
        save_run(self._sample_run(), runs_dir=target_dir)
        assert target_dir.exists()

    def test_load_previous_run_returns_none_when_no_dir(self, tmp_path):
        """A non-existent directory yields ``None``."""
        absent_dir = tmp_path / "nonexistent"
        assert load_previous_run(runs_dir=absent_dir) is None

    def test_load_previous_run_returns_none_when_empty_dir(self, tmp_path):
        """An existing but empty directory yields ``None``."""
        latest = load_previous_run(runs_dir=tmp_path)
        assert latest is None

    def test_load_previous_run_returns_most_recent(self, tmp_path):
        """With two saved runs, the later timestamp is the one returned."""
        earlier = self._sample_run("2026-01-15T10:00:00")
        later = self._sample_run("2026-01-15T11:00:00")
        save_run(earlier, runs_dir=tmp_path)
        save_run(later, runs_dir=tmp_path)
        latest = load_previous_run(runs_dir=tmp_path)
        assert latest["timestamp"] == "2026-01-15T11:00:00"

    def test_load_all_runs_returns_empty_when_no_dir(self, tmp_path):
        """A non-existent directory yields an empty list."""
        absent_dir = tmp_path / "nonexistent"
        assert load_all_runs(runs_dir=absent_dir) == []

    def test_load_all_runs_returns_empty_when_empty_dir(self, tmp_path):
        """An existing but empty directory yields an empty list."""
        loaded = load_all_runs(runs_dir=tmp_path)
        assert loaded == []

    def test_load_all_runs_returns_all_in_order(self, tmp_path):
        """Runs saved out of order come back sorted by ascending timestamp."""
        # Deliberately not chronological, so ordering cannot come from
        # the order in which the files were written.
        shuffled_stamps = [
            "2026-01-10T09:00:00",
            "2026-01-15T11:00:00",
            "2026-01-12T14:00:00",
        ]
        for stamp in shuffled_stamps:
            save_run(self._sample_run(stamp), runs_dir=tmp_path)
        loaded = load_all_runs(runs_dir=tmp_path)
        assert len(loaded) == 3
        loaded_stamps = [run["timestamp"] for run in loaded]
        assert loaded_stamps == sorted(loaded_stamps)

    def test_save_run_returns_path_object(self, tmp_path):
        """``save_run`` hands back a ``pathlib.Path``."""
        written = save_run(self._sample_run(), runs_dir=tmp_path)
        assert isinstance(written, Path)