"""Unit tests for the evaluation harness (build-plan phase 5). Fully offline: no model calls, no dataset downloads. The comparison function is tested directly, and the score-phase computation is tested on hand-built cached entries with known, hand-computed metrics. The threshold sweep is tested for the hard-failure override and the precision/recall trade-off. """ from __future__ import annotations from pathlib import Path from typing import Any import pytest from eval.cache import read_entries, report_from_dict, write_entry from eval.metrics import ( THRESHOLDS, compute_field_metrics, confidence_histogram, smallest_threshold_meeting, sweep_thresholds, ) from eval.normalize import is_present, normalize, values_match from eval.score import build_report # --------------------------------------------------------------------------- # Comparison function (normalize / values_match) # --------------------------------------------------------------------------- def test_money_matches_cent_exact() -> None: """Monetary values match only when equal at cent precision.""" assert values_match("total", 193.0, "193.00") assert values_match("total", "1,234.56", "1234.56") assert values_match("total", 100.004, 100.0) # sub-cent noise rounds away assert not values_match("total", 100.0, 100.01) # a genuine 1-cent difference assert not values_match("total", 100.0, 105.0) def test_money_rejects_relative_tolerance_error() -> None: """Regression: a materially-wrong total within 0.5% is NOT scored correct. The eval comparator must not reuse the reconciliation relative epsilon, or a $2 error on a $500 total (and $10 on $2000) would inflate critical precision. 
""" assert not values_match("total", 502.0, "500.00") # within money_close's 0.5% band assert not values_match("total", 2010.0, 2000.0) # +/-$10 window at $2000 assert not values_match("tax", 9.05, 9.0) # 5-cent tax error def test_money_handles_currency_symbols_and_separators() -> None: """Currency symbols and thousands separators normalize away.""" assert normalize("total", "$1,000.00") == pytest.approx(1000.0) assert values_match("total", "RM 193.00", "193.0") def test_date_matches_day_first_format() -> None: """SROIE-style day-first dates match the ISO-cached prediction.""" # gold "15/01/2019" (D/M/Y) vs predicted cached ISO "2019-01-15". assert values_match("document_date", "2019-01-15", "15/01/2019") assert not values_match("document_date", "2019-01-16", "15/01/2019") def test_text_matches_case_and_whitespace_insensitive() -> None: """Text matches after lower-casing and whitespace collapsing.""" assert values_match("vendor_name", "OJC Marketing SDN BHD", "ojc marketing sdn bhd") assert not values_match("vendor_name", "Acme Corp", "Beta LLC") def test_absent_values_never_match() -> None: """A present prediction against absent gold (or vice versa) is not a match.""" assert not values_match("total", 100.0, None) assert not values_match("total", None, 100.0) assert not values_match("vendor_name", "", "acme") assert not is_present("vendor_name", " ") assert not is_present("total", "N/A") def test_unparseable_money_is_absent() -> None: """An unparseable monetary string normalizes to None (absent), not a crash.""" assert normalize("total", "not a number") is None assert not is_present("total", "abc") # --------------------------------------------------------------------------- # Synthetic cached entries with known metrics # --------------------------------------------------------------------------- def _validation(hard_failed: bool, *, hard_codes: list[str] | None = None) -> dict[str, Any]: """Build a minimal validation dict as ValidationReport.to_dict would.""" 
results = [] for code in hard_codes or []: results.append( {"code": code, "severity": "hard", "status": "fail", "message": "synthetic"} ) return { "hard_failed": hard_failed, "results": results, "hard_failures": list(hard_codes or []), "soft_failures": [], } def _entry( example_id: str, *, predicted: dict[str, Any], gold: dict[str, Any], confidence: float, hard_failed: bool = False, hard_codes: list[str] | None = None, labeled: tuple[str, ...] = ("vendor_name", "vendor_address", "document_date", "total"), ) -> dict[str, Any]: return { "id": example_id, "dataset": "synthetic", "gold": gold, "labeled_fields": list(labeled), "predicted": predicted, "confidence": confidence, "decision": "review", "modality": "image", "backend": "stub", "validation": _validation(hard_failed, hard_codes=hard_codes), "error": None, } @pytest.fixture def synthetic_entries() -> list[dict[str, Any]]: """Four entries with hand-computable per-field metrics. total field: preds present on all 4; golds present on all 4; - e1 correct, e2 correct, e3 wrong value, e4 correct -> 3/4 match. vendor_name: preds present on 3 (e4 missing); golds present on 4; - e1 correct, e2 correct, e3 correct -> 3 match. => precision 3/3 = 1.0, recall 3/4 = 0.75. 
""" return [ _entry( "e1", predicted={"vendor_name": "Acme", "total": 100.0}, gold={"vendor_name": "acme", "total": "100.00", "vendor_address": None, "document_date": None}, confidence=0.90, ), _entry( "e2", predicted={"vendor_name": "Beta", "total": 50.0}, gold={"vendor_name": "beta", "total": "50.00", "vendor_address": None, "document_date": None}, confidence=0.80, ), _entry( "e3", predicted={"vendor_name": "Gamma", "total": 999.0}, # total wrong gold={"vendor_name": "gamma", "total": "10.00", "vendor_address": None, "document_date": None}, confidence=0.70, ), _entry( "e4", predicted={"vendor_name": None, "total": 25.0}, # vendor missing gold={"vendor_name": "delta", "total": "25.00", "vendor_address": None, "document_date": None}, confidence=0.60, ), ] def test_field_metrics_match_hand_computed(synthetic_entries: list[dict[str, Any]]) -> None: """Per-field precision/recall/F1 equal the hand-computed values.""" metrics = {m.field: m for m in compute_field_metrics( synthetic_entries, ("vendor_name", "total"))} total = metrics["total"] assert (total.n_pred, total.n_gold, total.n_match) == (4, 4, 3) assert total.precision == pytest.approx(0.75) assert total.recall == pytest.approx(0.75) assert total.f1 == pytest.approx(0.75) vendor = metrics["vendor_name"] assert (vendor.n_pred, vendor.n_gold, vendor.n_match) == (3, 4, 3) assert vendor.precision == pytest.approx(1.0) assert vendor.recall == pytest.approx(0.75) assert vendor.f1 == pytest.approx(2 * 1.0 * 0.75 / (1.0 + 0.75)) def test_field_precision_none_when_no_prediction() -> None: """Precision is None (undefined) when nothing was predicted for a field.""" entries = [ _entry("e1", predicted={"total": None}, gold={"total": "5.00"}, confidence=0.9), ] (metric,) = compute_field_metrics(entries, ("total",)) assert metric.n_pred == 0 assert metric.precision is None assert metric.recall == pytest.approx(0.0) # --------------------------------------------------------------------------- # Threshold sweep # 
# ---------------------------------------------------------------------------


def _row_at(rows: list[Any], threshold: float) -> Any:
    """Return the single sweep row whose threshold approximately equals ``threshold``.

    Rows are looked up by value rather than by position, so a test still checks
    the intended threshold if THRESHOLDS gains a lower entry. The comparison
    uses ``pytest.approx`` so float representation noise cannot cause a miss.
    """
    matches = [row for row in rows if row.threshold == pytest.approx(threshold)]
    assert len(matches) == 1, f"expected one sweep row at {threshold}, found {len(matches)}"
    return matches[0]


def test_sweep_accept_count_falls_as_threshold_rises(
    synthetic_entries: list[dict[str, Any]],
) -> None:
    """Higher thresholds auto-accept no more documents than lower ones."""
    rows = sweep_thresholds(synthetic_entries, ("total",), THRESHOLDS)
    thresholds = [row.threshold for row in rows]
    # The monotonicity check below only means something if rows ascend by threshold.
    assert thresholds == sorted(thresholds)
    accept_counts = [row.n_accepted for row in rows]
    assert accept_counts == sorted(accept_counts, reverse=True)
    # At 0.50 every clean doc (confidence >= 0.50) is accepted; all 4 here.
    assert rows[0].threshold == pytest.approx(0.50)
    assert rows[0].n_accepted == 4


def test_sweep_hard_failure_never_accepted() -> None:
    """A hard-failed document is forced to review at every threshold."""
    entries = [
        _entry(
            "hard",
            predicted={"total": 100.0},
            gold={"total": "100.00"},
            confidence=0.99,  # very confident...
            hard_failed=True,
            hard_codes=["H2"],  # ...but a hard rule failed.
        ),
    ]
    rows = sweep_thresholds(entries, ("total",), THRESHOLDS)
    assert all(row.n_accepted == 0 for row in rows)
    # And the reconstructed report reports the hard failure.
    report = report_from_dict(entries[0]["validation"])
    assert report.hard_failed is True


def test_sweep_critical_precision_and_recall(
    synthetic_entries: list[dict[str, Any]],
) -> None:
    """At threshold 0.50 the critical (total) precision/recall match hand calc.

    All 4 accepted; total correct on 3/4 => precision 0.75; gold present on 4
    => recall 3/4 = 0.75.
    """
    rows = sweep_thresholds(synthetic_entries, ("total",), THRESHOLDS)
    # Look the row up by value instead of assuming it sits at index 0.
    row_050 = _row_at(rows, 0.50)
    assert row_050.crit_pred == 4
    assert row_050.crit_match == 3
    assert row_050.crit_precision == pytest.approx(0.75)
    assert row_050.crit_recall == pytest.approx(0.75)


def test_smallest_threshold_meeting_target() -> None:
    """The lowest qualifying threshold is found when a clean high-conf doc exists."""
    entries = [
        _entry("ok", predicted={"total": 10.0}, gold={"total": "10.00"}, confidence=0.95),
    ]
    rows = sweep_thresholds(entries, ("total",), THRESHOLDS)
    target = smallest_threshold_meeting(rows, 0.98)
    assert target is not None
    # confidence 0.95 accepts for thresholds <= 0.95; precision is 1.0 (perfect).
    assert target.threshold == pytest.approx(0.50)
    assert target.crit_precision == pytest.approx(1.0)


def test_smallest_threshold_meeting_none_when_unreachable() -> None:
    """Returns None when no threshold reaches the target precision."""
    entries = [
        _entry("bad", predicted={"total": 999.0}, gold={"total": "10.00"}, confidence=0.95),
    ]
    rows = sweep_thresholds(entries, ("total",), THRESHOLDS)
    assert smallest_threshold_meeting(rows, 0.98) is None


def test_confidence_histogram_counts() -> None:
    """The histogram buckets rounded confidences."""
    entries = [
        _entry("a", predicted={}, gold={}, confidence=0.50),
        _entry("b", predicted={}, gold={}, confidence=0.50),
        _entry("c", predicted={}, gold={}, confidence=0.40),
    ]
    hist = confidence_histogram(entries)
    assert hist == {0.40: 1, 0.50: 2}


# ---------------------------------------------------------------------------
# End-to-end score phase over a written cache (still offline)
# ---------------------------------------------------------------------------


def test_build_report_from_written_cache(
    tmp_path: Path, synthetic_entries: list[dict[str, Any]]
) -> None:
    """Writing entries then building a report round-trips and computes metrics."""
    for entry in synthetic_entries:
        write_entry(tmp_path, "synthetic", entry)
    assert len(read_entries(tmp_path, "synthetic")) == 4

    report = build_report("synthetic", cache_base=tmp_path)
    assert report.n == 4
    assert "total" in report.labeled_fields
    # SROIE-like labeling: total is the only critical field labeled here.
    assert report.critical_labeled == ("total",)
    total = next(m for m in report.field_metrics if m.field == "total")
    assert total.n_match == 3


def test_build_report_raises_without_cache(tmp_path: Path) -> None:
    """Scoring a dataset with no cache raises a clear error."""
    with pytest.raises(FileNotFoundError):
        build_report("missing", cache_base=tmp_path)