File size: 10,956 Bytes
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""Snapshot regression tests against fixtures in this directory.

Discovery: every <name>.expected.json under fixtures/ pairs with a sibling
<name>.input.<ext>. The runner parses the input, then asserts each tolerance
in the expected file. Tolerance keys are documented in fixtures/README.md.

Performance baselines are opt-in per fixture via a `performance` block in
the expected file. They run only when ZSGDP_REGRESSION_PERF=1 (or when the
performance block has `always_enforce: true`) so a slow CI runner does not
fail on transient noise. When enabled, the parse is run `repeats` times
(default 2) and the median elapsed time and median pages-per-second are
compared against the `max_elapsed_seconds` / `min_pages_per_second` limits.
"""

from __future__ import annotations

import json
import os
import statistics
import tempfile
import time
import unittest
import unittest.mock
from pathlib import Path
from typing import Any

from zsgdp.pipeline import parse_document

# Directory, next to this file, that holds the <name>.expected.json and
# <name>.input.<ext> fixture pairs.
FIXTURE_DIR = Path(__file__).parent / "fixtures"


def _discover_fixtures(fixture_dir: Path | None = None) -> list[tuple[str, Path, Path]]:
    """Pair each ``<name>.expected.json`` with its ``<name>.input.<ext>`` sibling.

    Args:
        fixture_dir: Directory to scan. Defaults to the module-level
            ``FIXTURE_DIR`` when omitted.

    Returns:
        ``(name, input_path, expected_path)`` tuples ordered by the expected
        file's path. An expected file with no matching input is skipped; when
        several inputs match, the first in sorted order is used. A missing
        directory yields an empty list.
    """
    # Local import keeps the module's top-level import block untouched.
    import glob

    root = FIXTURE_DIR if fixture_dir is None else Path(fixture_dir)
    pairs: list[tuple[str, Path, Path]] = []
    if not root.is_dir():
        return pairs

    suffix = ".expected.json"
    for expected in sorted(root.glob(f"*{suffix}")):
        name = expected.name[: -len(suffix)]
        # Escape the name so glob metacharacters in it ("[", "?", "*") are
        # matched literally instead of silently selecting a different file.
        candidates = sorted(root.glob(f"{glob.escape(name)}.input.*"))
        if not candidates:
            continue
        pairs.append((name, candidates[0], expected))
    return pairs


def _check_int_or_range(actual: int, exact: Any, range_value: Any, label: str) -> str | None:
    """Check ``actual`` against an optional exact value and an optional range.

    Args:
        actual: The observed count.
        exact: Expected exact value, or ``None`` to skip the exact check.
        range_value: Inclusive ``[lo, hi]`` pair (list or tuple), or ``None``
            to skip the range check.
        label: Prefix used in any failure message.

    Returns:
        A message describing the first violated check, or ``None`` when every
        configured check passes. A non-``None`` ``range_value`` that is not a
        two-item list/tuple is reported as a failure rather than ignored, so a
        malformed fixture cannot silently disable the range check.
    """
    if exact is not None and int(exact) != actual:
        return f"{label}: expected {exact}, got {actual}"
    if range_value is None:
        return None
    if not isinstance(range_value, (list, tuple)) or len(range_value) != 2:
        return f"{label}_range: expected [lo, hi], got {range_value!r}"
    lo, hi = int(range_value[0]), int(range_value[1])
    if not (lo <= actual <= hi):
        return f"{label}: expected in [{lo}, {hi}], got {actual}"
    return None


def _evaluate(parsed, tolerances: dict[str, Any]) -> list[str]:
    """Compare a parsed document against a fixture's tolerance block.

    Only keys present in ``tolerances`` are checked. Returns one message per
    violated tolerance, in the order the checks run; an empty list means
    everything held.
    """
    problems: list[str] = []

    # Quality score bounds.
    quality = float(parsed.quality_report.score)
    if "quality_score_min" in tolerances:
        lowest = tolerances["quality_score_min"]
        if quality < float(lowest):
            problems.append(f"quality_score: {quality:.3f} < {lowest}")
    if "quality_score_max" in tolerances:
        highest = tolerances["quality_score_max"]
        if quality > float(highest):
            problems.append(f"quality_score: {quality:.3f} > {highest}")

    # Exact / ranged counts. The exact key equals the label and the range key
    # is the label with a "_range" suffix.
    observed_counts = {
        "element_count": len(parsed.elements),
        "table_count": len(parsed.tables),
        "figure_count": len(parsed.figures),
    }
    for label, observed in observed_counts.items():
        verdict = _check_int_or_range(
            observed,
            tolerances.get(label),
            tolerances.get(f"{label}_range"),
            label,
        )
        if verdict:
            problems.append(verdict)

    # Chunk count bounds.
    n_chunks = len(parsed.chunks)
    if "chunk_count_min" in tolerances:
        fewest = tolerances["chunk_count_min"]
        if n_chunks < int(fewest):
            problems.append(f"chunk_count: {n_chunks} < {fewest}")
    if "chunk_count_max" in tolerances:
        most = tolerances["chunk_count_max"]
        if n_chunks > int(most):
            problems.append(f"chunk_count: {n_chunks} > {most}")

    # Blocking-failure flag must match exactly.
    if "blocking_failures" in tolerances:
        got_blocking = parsed.quality_report.has_blocking_failures
        want_blocking = bool(tolerances["blocking_failures"])
        if got_blocking != want_blocking:
            problems.append(f"blocking_failures: expected {want_blocking}, got {got_blocking}")

    # Required / forbidden substrings in the rendered markdown.
    rendered = parsed.to_markdown()
    problems.extend(
        f"must_contain_markdown: {needle!r} not found"
        for needle in (tolerances.get("must_contain_markdown", []) or [])
        if str(needle) not in rendered
    )
    problems.extend(
        f"must_not_contain_markdown: {needle!r} present"
        for needle in (tolerances.get("must_not_contain_markdown", []) or [])
        if str(needle) in rendered
    )

    # Quality metrics that must be reported.
    metrics = parsed.quality_report.metrics
    problems.extend(
        f"must_contain_quality_metrics: {key!r} missing"
        for key in (tolerances.get("must_contain_quality_metrics", []) or [])
        if key not in metrics
    )

    # Rate limits. A missing metric falls back to the passing extreme
    # (0.0 disagreement, 1.0 resolution).
    if "parser_disagreement_rate_max" in tolerances:
        disagreement = float(metrics.get("parser_disagreement_rate", 0.0))
        cap = tolerances["parser_disagreement_rate_max"]
        if disagreement > float(cap):
            problems.append(f"parser_disagreement_rate: {disagreement:.3f} > {cap}")
    if "repair_resolution_rate_min" in tolerances:
        resolution = float(metrics.get("repair_resolution_rate", 1.0))
        need = tolerances["repair_resolution_rate_min"]
        if resolution < float(need):
            problems.append(f"repair_resolution_rate: {resolution:.3f} < {need}")

    return problems


def _perf_enforcement_enabled(performance: dict[str, Any]) -> bool:
    """Report whether performance limits should be enforced for a fixture.

    True when the block sets a truthy ``always_enforce``, or when the
    ``ZSGDP_REGRESSION_PERF`` environment variable is "1", "true" or "yes"
    (case-insensitive, surrounding whitespace ignored).
    """
    return bool(performance.get("always_enforce")) or (
        os.environ.get("ZSGDP_REGRESSION_PERF", "").strip().lower() in ("1", "true", "yes")
    )


def _measure_parse(
    input_path: Path,
    *,
    config_path: Path | None,
    selected_parsers,
    repeats: int,
) -> tuple[Any, list[float]]:
    """Time ``parse_document`` over ``repeats`` runs (at least one).

    Every run writes into an ``out`` directory inside its own freshly created
    temporary directory. Returns the document from the final run together
    with the wall-clock duration, in seconds, of each run.
    """
    durations: list[float] = []
    document = None
    run_total = repeats if repeats > 1 else 1
    for _run in range(run_total):
        with tempfile.TemporaryDirectory() as scratch:
            t_start = time.perf_counter()
            document = parse_document(
                input_path,
                Path(scratch) / "out",
                config_path=config_path or None,
                selected_parsers=selected_parsers,
            )
            t_stop = time.perf_counter()
            durations.append(t_stop - t_start)
    return document, durations


def _evaluate_performance(parsed, performance: dict[str, Any], elapsed_seconds: list[float]) -> list[str]:
    """Check median timing against a fixture's ``performance`` limits.

    ``max_elapsed_seconds`` caps the median elapsed time and
    ``min_pages_per_second`` sets the minimum throughput at the median, with
    the page count taken from ``parsed.pages`` (treated as at least one).
    Returns a message per violated limit; no timings means no checks.
    """
    if not elapsed_seconds:
        return []

    run_total = len(elapsed_seconds)
    typical = statistics.median(elapsed_seconds)
    pages = max(len(parsed.pages), 1)
    # A zero median would divide by zero; treat it as unbounded throughput.
    if typical > 0:
        throughput = pages / typical
    else:
        throughput = float("inf")

    problems: list[str] = []

    time_cap = performance.get("max_elapsed_seconds")
    if time_cap is not None and typical > float(time_cap):
        problems.append(
            f"performance.max_elapsed_seconds: median {typical:.2f}s > {time_cap}s "
            f"(runs={run_total})"
        )

    rate_min = performance.get("min_pages_per_second")
    if rate_min is not None and throughput < float(rate_min):
        problems.append(
            f"performance.min_pages_per_second: median {throughput:.2f} < {rate_min} "
            f"(runs={run_total})"
        )

    return problems


class RegressionFixturesTest(unittest.TestCase):
    """Parse every discovered fixture and compare it with its snapshot file."""

    def test_regression_fixtures_match_snapshots(self):
        # Violations are gathered across all fixtures and reported in a single
        # failure at the end; the test is skipped when no fixtures exist.
        discovered = _discover_fixtures()
        if not discovered:
            self.skipTest("No regression fixtures present.")

        report: list[str] = []
        for name, input_path, expected_path in discovered:
            with self.subTest(fixture=name):
                spec = json.loads(expected_path.read_text(encoding="utf-8"))
                tolerances = spec.get("tolerances") or {}
                performance = spec.get("performance") or {}

                # A relative config path is anchored at parents[2] of this
                # file's resolved path.
                config_path = None
                raw_config = spec.get("config")
                if raw_config:
                    config_path = Path(raw_config)
                    if not config_path.is_absolute():
                        config_path = Path(__file__).resolve().parents[2] / config_path

                selected_parsers = spec.get("selected_parsers")

                # Repeat the parse only when timing limits are enforced.
                enforce_perf = bool(performance) and _perf_enforcement_enabled(performance)
                run_total = 1
                if enforce_perf:
                    run_total = int(performance.get("repeats", 2))

                parsed, elapsed = _measure_parse(
                    input_path,
                    config_path=config_path,
                    selected_parsers=selected_parsers,
                    repeats=run_total,
                )

                problems = _evaluate(parsed, tolerances)
                if enforce_perf:
                    problems.extend(_evaluate_performance(parsed, performance, elapsed))
                if problems:
                    report.append(f"[{name}] " + "; ".join(problems))

        if report:
            self.fail("\n".join(report))


class PerformanceEvaluatorTests(unittest.TestCase):
    """Exercise the performance helpers directly, independent of any fixture files."""

    @staticmethod
    def _fake_parsed(page_total):
        # Stand-in document exposing only the ``pages`` attribute the
        # evaluator reads.
        from types import SimpleNamespace

        return SimpleNamespace(pages=[{"page_num": number} for number in range(1, page_total + 1)])

    def test_max_elapsed_floor_fires_when_too_slow(self):
        # Median of 0.5s is above the 0.1s limit.
        problems = _evaluate_performance(self._fake_parsed(1), {"max_elapsed_seconds": 0.1}, [0.5, 0.5])
        self.assertEqual(len(problems), 1)
        self.assertIn("max_elapsed_seconds", problems[0])

    def test_min_pages_per_second_fires_when_too_slow(self):
        # One page in 10s is 0.1 pages/s, below the 1.0 minimum.
        problems = _evaluate_performance(self._fake_parsed(1), {"min_pages_per_second": 1.0}, [10.0, 10.0])
        self.assertEqual(len(problems), 1)
        self.assertIn("min_pages_per_second", problems[0])

    def test_passing_floors_yield_no_failures(self):
        # Two pages in 0.5s is 4 pages/s; both the 1.0 pages/s minimum and
        # the 2s maximum are satisfied.
        limits = {"max_elapsed_seconds": 2.0, "min_pages_per_second": 1.0}
        problems = _evaluate_performance(self._fake_parsed(2), limits, [0.5, 0.5, 0.5])
        self.assertEqual(problems, [])

    def test_median_strips_cold_outlier(self):
        # A cold 5s first run followed by two warm 0.1s runs has a 0.1s
        # median, which is within the 1s limit.
        problems = _evaluate_performance(self._fake_parsed(1), {"max_elapsed_seconds": 1.0}, [5.0, 0.1, 0.1])
        self.assertEqual(problems, [])

    def test_perf_enforcement_gating(self):
        def env_flag(value):
            return unittest.mock.patch.dict("os.environ", {"ZSGDP_REGRESSION_PERF": value}, clear=False)

        # Flag off: only always_enforce turns enforcement on.
        with env_flag("0"):
            self.assertFalse(_perf_enforcement_enabled({"max_elapsed_seconds": 1.0}))
            self.assertTrue(_perf_enforcement_enabled({"always_enforce": True}))

        # Flag on: enforcement applies without always_enforce.
        with env_flag("1"):
            self.assertTrue(_perf_enforcement_enabled({"max_elapsed_seconds": 1.0}))


# Run the tests through unittest's command-line runner when this file is
# executed directly.
if __name__ == "__main__":
    unittest.main()