"""Tests for eval/runner.py — pure/testable functions only (no ChromaDB)."""

import json
from pathlib import Path

import pytest

from mediastorm.eval.runner import (
    _avg,
    _build_run_data,
    save_run,
    load_previous_run,
    load_all_runs,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_row(category: str, **overrides) -> dict:
    """Build one scored query row for ``category``.

    Any keyword in ``overrides`` replaces (or adds to) the default fields.
    """
    defaults = {
        "query": "test query",
        "category": category,
        "precision_at_1": 1.0,
        "recall_at_5": 0.8,
        "mrr": 0.9,
        "ndcg_at_5": 0.85,
        "retrieved": ["uid_a", "uid_b"],
        "expected": ["uid_a"],
        "missed": [],
        "duration": 0.1,
    }
    return {**defaults, **overrides}


def _make_edge_row(**overrides) -> dict:
    """Build one ``edge_no_match`` row; ``overrides`` replace the default fields."""
    defaults = {
        "query": "edge query",
        "category": "edge_no_match",
        "success": True,
        "num_returned": 0,
        "duration": 0.05,
    }
    return {**defaults, **overrides}


def _make_eval_result(details: list[dict]) -> dict:
    """Wrap ``details`` in an eval-result dict with fixed aggregate scores."""
    fixed_scores = {
        "semantic_precision_at_1": 0.8,
        "semantic_recall_at_5": 0.7,
        "semantic_mrr": 0.75,
        "semantic_ndcg_at_5": 0.72,
        "filter_precision_at_1": 0.6,
        "filter_recall_at_5": 0.65,
        "edge_pass_rate": 1.0,
    }
    return {"details": details, **fixed_scores}


# ---------------------------------------------------------------------------
# _avg
# ---------------------------------------------------------------------------


class TestAvg:
    """Checks for ``_avg(rows, key)`` over lists of dict rows."""

    def test_averages_key_values(self):
        mean = _avg([{"score": 0.5}, {"score": 1.0}], "score")
        assert mean == pytest.approx(0.75)

    def test_skips_rows_missing_key(self):
        # The second row has no "score" entry and must not affect the mean.
        mean = _avg([{"score": 1.0}, {"other": 0.0}], "score")
        assert mean == pytest.approx(1.0)

    def test_empty_list_returns_zero(self):
        assert _avg([], "score") == 0.0

    def test_all_rows_missing_key_returns_zero(self):
        keyless_rows = [{"other": 1.0}, {"other": 2.0}]
        assert _avg(keyless_rows, "score") == 0.0

    def test_single_row(self):
        mean = _avg([{"val": 0.42}], "val")
        assert mean == pytest.approx(0.42)


# ---------------------------------------------------------------------------
# _build_run_data
# ---------------------------------------------------------------------------


class TestBuildRunData:
    """Checks for the run dict that ``_build_run_data`` derives from an eval result."""

    def test_timestamp_is_iso_string(self):
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        ts = result["timestamp"]
        assert isinstance(ts, str)
        # Should parse back without error
        from datetime import datetime

        datetime.fromisoformat(ts)

    def test_aggregates_keys(self):
        result = _build_run_data(_make_eval_result([_make_row("thematic")]))
        agg = result["aggregates"]
        assert set(agg) == {
            "semantic_p1",
            "semantic_r5",
            "semantic_mrr",
            "semantic_ndcg5",
            "filter_p1",
            "filter_r5",
            "edge_pass_rate",
        }

    def test_aggregates_values_passthrough(self):
        eval_result = _make_eval_result([_make_row("geographic")])
        result = _build_run_data(eval_result)
        agg = result["aggregates"]
        assert agg["semantic_p1"] == pytest.approx(0.8)
        assert agg["filter_r5"] == pytest.approx(0.65)
        assert agg["edge_pass_rate"] == pytest.approx(1.0)

    def test_category_summary_for_normal_category(self):
        rows = [
            _make_row("geographic", precision_at_1=1.0, recall_at_5=0.6, mrr=0.8, ndcg_at_5=0.7),
            _make_row("geographic", precision_at_1=0.0, recall_at_5=0.4, mrr=0.5, ndcg_at_5=0.3),
        ]
        result = _build_run_data(_make_eval_result(rows))
        cat = result["categories"]["geographic"]
        assert cat["count"] == 2
        # Each summary value is the mean of the two rows above.
        assert cat["p1"] == pytest.approx(0.5)
        assert cat["r5"] == pytest.approx(0.5)
        assert cat["mrr"] == pytest.approx(0.65)
        assert cat["ndcg5"] == pytest.approx(0.5)

    def test_category_summary_for_edge_no_match(self):
        rows = [
            _make_edge_row(success=True),
            _make_edge_row(success=False),
            _make_edge_row(success=True),
        ]
        result = _build_run_data(_make_eval_result(rows))
        edge = result["categories"]["edge_no_match"]
        assert edge["passed"] == 2
        assert edge["total"] == 3

    def test_multiple_categories_separated(self):
        rows = [
            _make_row("geographic"),
            _make_row("thematic"),
            _make_row("geographic"),
        ]
        result = _build_run_data(_make_eval_result(rows))
        assert result["categories"]["geographic"]["count"] == 2
        assert result["categories"]["thematic"]["count"] == 1

    def test_queries_list_length_matches_details(self):
        rows = [_make_row("geographic"), _make_row("thematic"), _make_edge_row()]
        result = _build_run_data(_make_eval_result(rows))
        assert len(result["queries"]) == 3

    def test_normal_query_entry_has_expected_keys(self):
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        present = set(result["queries"][0])
        required = {
            "query",
            "p1",
            "r5",
            "mrr",
            "ndcg5",
            "retrieved_ids",
            "expected_ids",
            "missed",
            "duration",
        }
        assert required <= present, f"missing keys: {sorted(required - present)}"
        # Edge-only fields should not be present
        assert not ({"success", "num_returned"} & present)

    def test_edge_query_entry_has_expected_keys(self):
        result = _build_run_data(_make_eval_result([_make_edge_row()]))
        present = set(result["queries"][0])
        assert {"success", "num_returned"} <= present
        # Metric fields should not be present
        assert not ({"p1", "retrieved_ids"} & present)

    def test_returns_dict_with_required_top_level_keys(self):
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        assert set(result.keys()) >= {"timestamp", "aggregates", "categories", "queries"}


# ---------------------------------------------------------------------------
# save_run / load_previous_run / load_all_runs
# ---------------------------------------------------------------------------


class TestRunPersistence:
    """Round-trip checks for ``save_run``, ``load_previous_run`` and ``load_all_runs``.

    Every test writes into pytest's ``tmp_path`` so nothing touches real run history.
    """

    @staticmethod
    def _sample_run(timestamp: str = "2026-01-15T10:30:00") -> dict:
        """Return a minimal run dict carrying the given ``timestamp``."""
        return {
            "timestamp": timestamp,
            "aggregates": {"semantic_p1": 0.8},
            "categories": {},
            "queries": [],
        }

    def test_save_run_creates_json_file(self, tmp_path):
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        assert path.exists()
        assert path.suffix == ".json"

    def test_save_run_filename_matches_timestamp(self, tmp_path):
        run = self._sample_run("2026-01-15T10:30:00")
        path = save_run(run, runs_dir=tmp_path)
        assert path.name == "2026-01-15_10-30-00.json"

    def test_save_run_content_is_valid_json(self, tmp_path):
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        # Explicit encoding keeps the read independent of the platform locale.
        loaded = json.loads(path.read_text(encoding="utf-8"))
        assert loaded["aggregates"]["semantic_p1"] == pytest.approx(0.8)

    def test_save_run_creates_parent_dirs(self, tmp_path):
        nested = tmp_path / "deep" / "nested"
        run = self._sample_run()
        save_run(run, runs_dir=nested)
        assert nested.exists()

    def test_load_previous_run_returns_none_when_no_dir(self, tmp_path):
        missing = tmp_path / "nonexistent"
        assert load_previous_run(runs_dir=missing) is None

    def test_load_previous_run_returns_none_when_empty_dir(self, tmp_path):
        assert load_previous_run(runs_dir=tmp_path) is None

    def test_load_previous_run_returns_most_recent(self, tmp_path):
        run_a = self._sample_run("2026-01-15T10:00:00")
        run_b = self._sample_run("2026-01-15T11:00:00")
        save_run(run_a, runs_dir=tmp_path)
        save_run(run_b, runs_dir=tmp_path)
        loaded = load_previous_run(runs_dir=tmp_path)
        # Fail with a clear assertion rather than a TypeError on subscripting None.
        assert loaded is not None
        assert loaded["timestamp"] == "2026-01-15T11:00:00"

    def test_load_all_runs_returns_empty_when_no_dir(self, tmp_path):
        missing = tmp_path / "nonexistent"
        assert load_all_runs(runs_dir=missing) == []

    def test_load_all_runs_returns_empty_when_empty_dir(self, tmp_path):
        assert load_all_runs(runs_dir=tmp_path) == []

    def test_load_all_runs_returns_all_in_order(self, tmp_path):
        # Saved deliberately out of chronological order.
        timestamps = [
            "2026-01-10T09:00:00",
            "2026-01-15T11:00:00",
            "2026-01-12T14:00:00",
        ]
        for ts in timestamps:
            save_run(self._sample_run(ts), runs_dir=tmp_path)
        runs = load_all_runs(runs_dir=tmp_path)
        assert len(runs) == 3
        result_ts = [r["timestamp"] for r in runs]
        # Compare against the saved timestamps, not just the result's own ordering,
        # so duplicated or unrelated runs cannot satisfy the check.
        assert result_ts == sorted(timestamps)

    def test_save_run_returns_path_object(self, tmp_path):
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        assert isinstance(path, Path)