Spaces:
Sleeping
Sleeping
"""Tests for eval/runner.py — pure/testable functions only (no ChromaDB)."""
import json
from datetime import datetime
from pathlib import Path

import pytest

from mediastorm.eval.runner import _avg, _build_run_data, save_run, load_previous_run, load_all_runs
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
| def _make_row(category: str, **overrides) -> dict: | |
| base = { | |
| "query": "test query", | |
| "category": category, | |
| "precision_at_1": 1.0, | |
| "recall_at_5": 0.8, | |
| "mrr": 0.9, | |
| "ndcg_at_5": 0.85, | |
| "retrieved": ["uid_a", "uid_b"], | |
| "expected": ["uid_a"], | |
| "missed": [], | |
| "duration": 0.1, | |
| } | |
| base.update(overrides) | |
| return base | |
| def _make_edge_row(**overrides) -> dict: | |
| base = { | |
| "query": "edge query", | |
| "category": "edge_no_match", | |
| "success": True, | |
| "num_returned": 0, | |
| "duration": 0.05, | |
| } | |
| base.update(overrides) | |
| return base | |
| def _make_eval_result(details: list[dict]) -> dict: | |
| return { | |
| "details": details, | |
| "semantic_precision_at_1": 0.8, | |
| "semantic_recall_at_5": 0.7, | |
| "semantic_mrr": 0.75, | |
| "semantic_ndcg_at_5": 0.72, | |
| "filter_precision_at_1": 0.6, | |
| "filter_recall_at_5": 0.65, | |
| "edge_pass_rate": 1.0, | |
| } | |
# ---------------------------------------------------------------------------
# _avg
# ---------------------------------------------------------------------------
class TestAvg:
    """Expectations for ``_avg(rows, key)`` over lists of dict rows."""

    def test_averages_key_values(self):
        """Rows that all carry the key yield the mean of their values."""
        rows = [{"score": 0.5}, {"score": 1.0}]
        assert _avg(rows, "score") == pytest.approx(0.75)

    def test_skips_rows_missing_key(self):
        """A row without the key must not pull the average down."""
        rows = [{"score": 1.0}, {"other": 0.0}]
        assert _avg(rows, "score") == pytest.approx(1.0)

    def test_empty_list_returns_zero(self):
        """No rows at all is expected to give 0.0 rather than raise."""
        assert _avg([], "score") == 0.0

    def test_all_rows_missing_key_returns_zero(self):
        """Rows exist but none has the key: still 0.0."""
        rows = [{"other": 1.0}, {"other": 2.0}]
        assert _avg(rows, "score") == 0.0

    def test_single_row(self):
        """A single row averages to its own value."""
        rows = [{"val": 0.42}]
        assert _avg(rows, "val") == pytest.approx(0.42)
# ---------------------------------------------------------------------------
# _build_run_data
# ---------------------------------------------------------------------------
class TestBuildRunData:
    """Expectations for the dict produced by ``_build_run_data(eval_result)``."""

    def test_timestamp_is_iso_string(self):
        """The timestamp is a string that ``datetime.fromisoformat`` accepts."""
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        ts = result["timestamp"]
        assert isinstance(ts, str)
        # Should parse back without error
        datetime.fromisoformat(ts)

    def test_aggregates_keys(self):
        """The aggregates section exposes exactly the seven shortened keys."""
        result = _build_run_data(_make_eval_result([_make_row("thematic")]))
        agg = result["aggregates"]
        assert set(agg.keys()) == {
            "semantic_p1", "semantic_r5", "semantic_mrr", "semantic_ndcg5",
            "filter_p1", "filter_r5", "edge_pass_rate",
        }

    def test_aggregates_values_passthrough(self):
        """Aggregate values match the ones supplied in the eval result."""
        eval_result = _make_eval_result([_make_row("geographic")])
        result = _build_run_data(eval_result)
        agg = result["aggregates"]
        assert agg["semantic_p1"] == pytest.approx(0.8)
        assert agg["filter_r5"] == pytest.approx(0.65)
        assert agg["edge_pass_rate"] == pytest.approx(1.0)

    def test_category_summary_for_normal_category(self):
        """A regular category reports its row count and per-metric means."""
        rows = [
            _make_row("geographic", precision_at_1=1.0, recall_at_5=0.6, mrr=0.8, ndcg_at_5=0.7),
            _make_row("geographic", precision_at_1=0.0, recall_at_5=0.4, mrr=0.5, ndcg_at_5=0.3),
        ]
        result = _build_run_data(_make_eval_result(rows))
        cat = result["categories"]["geographic"]
        assert cat["count"] == 2
        assert cat["p1"] == pytest.approx(0.5)
        assert cat["r5"] == pytest.approx(0.5)
        assert cat["mrr"] == pytest.approx(0.65)
        assert cat["ndcg5"] == pytest.approx(0.5)

    def test_category_summary_for_edge_no_match(self):
        """The edge category reports passed/total counts instead of metrics."""
        rows = [
            _make_edge_row(success=True),
            _make_edge_row(success=False),
            _make_edge_row(success=True),
        ]
        result = _build_run_data(_make_eval_result(rows))
        edge = result["categories"]["edge_no_match"]
        assert edge["passed"] == 2
        assert edge["total"] == 3

    def test_multiple_categories_separated(self):
        """Rows are grouped by category even when interleaved."""
        rows = [
            _make_row("geographic"),
            _make_row("thematic"),
            _make_row("geographic"),
        ]
        result = _build_run_data(_make_eval_result(rows))
        assert result["categories"]["geographic"]["count"] == 2
        assert result["categories"]["thematic"]["count"] == 1

    def test_queries_list_length_matches_details(self):
        """Every detail row, edge or not, yields one entry in ``queries``."""
        rows = [_make_row("geographic"), _make_row("thematic"), _make_edge_row()]
        result = _build_run_data(_make_eval_result(rows))
        assert len(result["queries"]) == 3

    def test_normal_query_entry_has_expected_keys(self):
        """A regular query entry carries metric fields and no edge fields."""
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        q = result["queries"][0]
        assert "query" in q
        assert "p1" in q
        assert "r5" in q
        assert "mrr" in q
        assert "ndcg5" in q
        assert "retrieved_ids" in q
        assert "expected_ids" in q
        assert "missed" in q
        assert "duration" in q
        # Edge-only fields should not be present
        assert "success" not in q
        assert "num_returned" not in q

    def test_edge_query_entry_has_expected_keys(self):
        """An edge query entry carries edge fields and no metric fields."""
        result = _build_run_data(_make_eval_result([_make_edge_row()]))
        q = result["queries"][0]
        assert "success" in q
        assert "num_returned" in q
        # Metric fields should not be present
        assert "p1" not in q
        assert "retrieved_ids" not in q

    def test_returns_dict_with_required_top_level_keys(self):
        """The four top-level sections are present; extra keys are tolerated."""
        result = _build_run_data(_make_eval_result([_make_row("geographic")]))
        assert set(result.keys()) >= {"timestamp", "aggregates", "categories", "queries"}
# ---------------------------------------------------------------------------
# save_run / load_previous_run / load_all_runs
# ---------------------------------------------------------------------------
class TestRunPersistence:
    """Expectations for ``save_run``, ``load_previous_run`` and ``load_all_runs``.

    Every test writes into pytest's ``tmp_path`` so nothing touches a real
    runs directory.
    """

    def _sample_run(self, timestamp: str = "2026-01-15T10:30:00") -> dict:
        """Return a minimal run dict carrying the given timestamp."""
        return {
            "timestamp": timestamp,
            "aggregates": {"semantic_p1": 0.8},
            "categories": {},
            "queries": [],
        }

    def test_save_run_creates_json_file(self, tmp_path):
        """Saving produces an existing file with a ``.json`` suffix."""
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        assert path.exists()
        assert path.suffix == ".json"

    def test_save_run_filename_matches_timestamp(self, tmp_path):
        """The file name is the timestamp with ``T`` and ``:`` made filename-safe."""
        run = self._sample_run("2026-01-15T10:30:00")
        path = save_run(run, runs_dir=tmp_path)
        assert path.name == "2026-01-15_10-30-00.json"

    def test_save_run_content_is_valid_json(self, tmp_path):
        """The saved file parses as JSON and keeps the aggregate value."""
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        loaded = json.loads(path.read_text())
        assert loaded["aggregates"]["semantic_p1"] == pytest.approx(0.8)

    def test_save_run_creates_parent_dirs(self, tmp_path):
        """A missing nested runs directory is created on save."""
        nested = tmp_path / "deep" / "nested"
        run = self._sample_run()
        save_run(run, runs_dir=nested)
        assert nested.exists()

    def test_load_previous_run_returns_none_when_no_dir(self, tmp_path):
        """A nonexistent runs directory yields ``None``."""
        missing = tmp_path / "nonexistent"
        assert load_previous_run(runs_dir=missing) is None

    def test_load_previous_run_returns_none_when_empty_dir(self, tmp_path):
        """An existing but empty runs directory yields ``None``."""
        assert load_previous_run(runs_dir=tmp_path) is None

    def test_load_previous_run_returns_most_recent(self, tmp_path):
        """With two saved runs, the one with the later timestamp is returned."""
        run_a = self._sample_run("2026-01-15T10:00:00")
        run_b = self._sample_run("2026-01-15T11:00:00")
        save_run(run_a, runs_dir=tmp_path)
        save_run(run_b, runs_dir=tmp_path)
        loaded = load_previous_run(runs_dir=tmp_path)
        assert loaded["timestamp"] == "2026-01-15T11:00:00"

    def test_load_all_runs_returns_empty_when_no_dir(self, tmp_path):
        """A nonexistent runs directory yields an empty list."""
        missing = tmp_path / "nonexistent"
        assert load_all_runs(runs_dir=missing) == []

    def test_load_all_runs_returns_empty_when_empty_dir(self, tmp_path):
        """An existing but empty runs directory yields an empty list."""
        assert load_all_runs(runs_dir=tmp_path) == []

    def test_load_all_runs_returns_all_in_order(self, tmp_path):
        """Runs saved out of order come back sorted by timestamp."""
        # Deliberately not chronological, so ordering is really exercised.
        timestamps = [
            "2026-01-10T09:00:00",
            "2026-01-15T11:00:00",
            "2026-01-12T14:00:00",
        ]
        for ts in timestamps:
            save_run(self._sample_run(ts), runs_dir=tmp_path)
        runs = load_all_runs(runs_dir=tmp_path)
        assert len(runs) == 3
        result_ts = [r["timestamp"] for r in runs]
        assert result_ts == sorted(result_ts)

    def test_save_run_returns_path_object(self, tmp_path):
        """The return value of ``save_run`` is a ``pathlib.Path``."""
        run = self._sample_run()
        path = save_run(run, runs_dir=tmp_path)
        assert isinstance(path, Path)