zeroshotGPU / tests /regression /test_regression.py
Arjunvir Singh
Initial commit: zeroshotGPU MVP with full eval surface
db06ffa
"""Snapshot regression tests against fixtures in this directory.
Discovery: every <name>.expected.json under fixtures/ pairs with a sibling
<name>.input.<ext>. The runner parses the input, then asserts each tolerance
in the expected file. Tolerance keys are documented in fixtures/README.md.
Performance baselines are opt-in per fixture via a `performance` block in
the expected file. They run only when ZSGDP_REGRESSION_PERF=1 (or when the
performance block has `always_enforce: true`) so a slow CI runner does not
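fail on transient noise. When enabled, the parse is run `repeats` times
(default 2) and the median elapsed time is checked against the optional
`max_elapsed_seconds` ceiling and `min_pages_per_second` floor. With only
two runs the median equals their mean, so use `repeats` >= 3 if a cold
first run should be discarded.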
"""
from __future__ import annotations
import json
import os
import statistics
import tempfile
import time
import unittest
import unittest.mock
from pathlib import Path
from typing import Any
from zsgdp.pipeline import parse_document
# Directory holding the ``<name>.expected.json`` / ``<name>.input.<ext>`` pairs,
# located next to this test file.
FIXTURE_DIR = Path(__file__).parent / "fixtures"


def _discover_fixtures(fixture_dir: Path | None = None) -> list[tuple[str, Path, Path]]:
    """Pair every ``<name>.expected.json`` with its sibling ``<name>.input.<ext>``.

    Args:
        fixture_dir: Directory to scan. Defaults to ``FIXTURE_DIR``.

    Returns:
        ``(name, input_path, expected_path)`` tuples ordered by expected-file
        path. Expected files with no matching input file are skipped; when
        several inputs match, the lexicographically first one is used. A
        missing directory yields an empty list.
    """
    import glob

    root = FIXTURE_DIR if fixture_dir is None else Path(fixture_dir)
    if not root.is_dir():
        return []
    suffix = ".expected.json"
    pairs: list[tuple[str, Path, Path]] = []
    for expected in sorted(root.glob(f"*{suffix}")):
        name = expected.name[: -len(suffix)]
        # Escape the stem so names containing glob metacharacters ("[", "*", "?")
        # are matched literally instead of being interpreted as a pattern.
        candidates = sorted(
            path for path in root.glob(f"{glob.escape(name)}.input.*") if path.is_file()
        )
        if not candidates:
            continue
        pairs.append((name, candidates[0], expected))
    return pairs
def _check_int_or_range(actual: int, exact: Any, range_value: Any, label: str) -> str | None:
    """Compare an observed count with an exact value and/or an inclusive range.

    The exact value (when not ``None``) is checked first. The range only
    applies when it is a two-element list or tuple; any other shape is
    ignored. Returns a failure message, or ``None`` when the count is
    acceptable.
    """
    if exact is not None:
        wanted = int(exact)
        if wanted != actual:
            return f"{label}: expected {exact}, got {actual}"

    usable_range = isinstance(range_value, (list, tuple)) and len(range_value) == 2
    if not usable_range:
        return None

    low, high = (int(bound) for bound in range_value)
    if low <= actual <= high:
        return None
    return f"{label}: expected in [{low}, {high}], got {actual}"
def _tolerance_items(tolerances: dict[str, Any], key: str) -> list[Any]:
    """Return the list-valued tolerance stored under ``key`` as a list.

    Missing or empty values yield ``[]``. A bare string is treated as a single
    item: iterating it directly would check each character separately and
    silently weaken the assertion.
    """
    value = tolerances.get(key)
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _evaluate(parsed, tolerances: dict[str, Any]) -> list[str]:
    """Check a parsed document against a fixture's ``tolerances`` block.

    Args:
        parsed: Result of ``parse_document``. Reads ``quality_report``
            (``score``, ``has_blocking_failures``, ``metrics``), ``elements``,
            ``tables``, ``figures``, ``chunks`` and ``to_markdown()``.
        tolerances: Tolerance mapping from the expected file (keys documented
            in fixtures/README.md). Absent keys are not checked.

    Returns:
        Human-readable failure messages; empty when every tolerance holds.
    """
    failures: list[str] = []

    score = float(parsed.quality_report.score)
    if "quality_score_min" in tolerances and score < float(tolerances["quality_score_min"]):
        failures.append(f"quality_score: {score:.3f} < {tolerances['quality_score_min']}")
    if "quality_score_max" in tolerances and score > float(tolerances["quality_score_max"]):
        failures.append(f"quality_score: {score:.3f} > {tolerances['quality_score_max']}")

    for label, count, exact_key, range_key in (
        ("element_count", len(parsed.elements), "element_count", "element_count_range"),
        ("table_count", len(parsed.tables), "table_count", "table_count_range"),
        ("figure_count", len(parsed.figures), "figure_count", "figure_count_range"),
    ):
        message = _check_int_or_range(
            count, tolerances.get(exact_key), tolerances.get(range_key), label
        )
        if message:
            failures.append(message)

    chunk_count = len(parsed.chunks)
    if "chunk_count_min" in tolerances and chunk_count < int(tolerances["chunk_count_min"]):
        failures.append(f"chunk_count: {chunk_count} < {tolerances['chunk_count_min']}")
    if "chunk_count_max" in tolerances and chunk_count > int(tolerances["chunk_count_max"]):
        failures.append(f"chunk_count: {chunk_count} > {tolerances['chunk_count_max']}")

    if "blocking_failures" in tolerances:
        actual = parsed.quality_report.has_blocking_failures
        expected = bool(tolerances["blocking_failures"])
        if actual != expected:
            failures.append(f"blocking_failures: expected {expected}, got {actual}")

    md = parsed.to_markdown()
    for needle in _tolerance_items(tolerances, "must_contain_markdown"):
        if str(needle) not in md:
            failures.append(f"must_contain_markdown: {needle!r} not found")
    for needle in _tolerance_items(tolerances, "must_not_contain_markdown"):
        if str(needle) in md:
            failures.append(f"must_not_contain_markdown: {needle!r} present")

    metrics = parsed.quality_report.metrics
    for key in _tolerance_items(tolerances, "must_contain_quality_metrics"):
        if key not in metrics:
            failures.append(f"must_contain_quality_metrics: {key!r} missing")

    # A missing metric is treated as the "best case" value (no disagreement,
    # full resolution), so these bounds only fire when the metric is reported.
    if "parser_disagreement_rate_max" in tolerances:
        rate = float(metrics.get("parser_disagreement_rate", 0.0))
        if rate > float(tolerances["parser_disagreement_rate_max"]):
            failures.append(
                f"parser_disagreement_rate: {rate:.3f} > {tolerances['parser_disagreement_rate_max']}"
            )
    if "repair_resolution_rate_min" in tolerances:
        rate = float(metrics.get("repair_resolution_rate", 1.0))
        if rate < float(tolerances["repair_resolution_rate_min"]):
            failures.append(
                f"repair_resolution_rate: {rate:.3f} < {tolerances['repair_resolution_rate_min']}"
            )
    return failures
def _perf_enforcement_enabled(performance: dict[str, Any]) -> bool:
    """Decide whether a fixture's performance block should be enforced.

    A truthy ``always_enforce`` in the block forces enforcement. Otherwise the
    ``ZSGDP_REGRESSION_PERF`` environment variable must be ``1``, ``true`` or
    ``yes`` (case-insensitive, surrounding whitespace ignored).
    """
    if not performance.get("always_enforce"):
        raw_flag = os.environ.get("ZSGDP_REGRESSION_PERF", "")
        normalized = raw_flag.strip().lower()
        return normalized in ("1", "true", "yes")
    return True
def _measure_parse(input_path: Path, *, config_path: Path | None, selected_parsers, repeats: int) -> tuple[Any, list[float]]:
    """Parse ``input_path`` ``max(1, repeats)`` times and time each run.

    Every run writes into its own fresh temporary output directory, so
    on-disk caching effects are roughly equal across runs.

    Returns:
        ``(last_parsed, timings)`` where ``last_parsed`` is the document from
        the final run (used for tolerance checks) and ``timings`` holds the
        wall-clock seconds of each run (used for the perf assertion).

    NOTE(review): each temporary directory is deleted as soon as its run
    finishes. If the returned document lazily reads from its output
    directory, callers would fail — confirm against ``parse_document``.
    """
    timings: list[float] = []
    last_parsed = None
    run_total = max(1, repeats)
    for _run_index in range(run_total):
        with tempfile.TemporaryDirectory() as scratch_dir:
            t_start = time.perf_counter()
            last_parsed = parse_document(
                input_path,
                Path(scratch_dir) / "out",
                config_path=config_path or None,
                selected_parsers=selected_parsers,
            )
            timings.append(time.perf_counter() - t_start)
    return last_parsed, timings
def _evaluate_performance(parsed, performance: dict[str, Any], elapsed_seconds: list[float]) -> list[str]:
    """Compare the median run time against a fixture's performance limits.

    Checks two optional limits from ``performance``:
    ``max_elapsed_seconds`` (a ceiling on the median elapsed time) and
    ``min_pages_per_second`` (a floor on page throughput at the median time).
    The page count is clamped to at least 1. A non-positive median counts as
    infinite throughput. With no timings, nothing is checked.
    """
    if not elapsed_seconds:
        return []

    typical = statistics.median(elapsed_seconds)
    pages = len(parsed.pages) or 1
    throughput = float("inf") if typical <= 0 else pages / typical
    runs = len(elapsed_seconds)

    problems: list[str] = []
    ceiling = performance.get("max_elapsed_seconds")
    if ceiling is not None and typical > float(ceiling):
        problems.append(
            f"performance.max_elapsed_seconds: median {typical:.2f}s > {ceiling}s "
            f"(runs={runs})"
        )
    floor = performance.get("min_pages_per_second")
    if floor is not None and throughput < float(floor):
        problems.append(
            f"performance.min_pages_per_second: median {throughput:.2f} < {floor} "
            f"(runs={runs})"
        )
    return problems
class RegressionFixturesTest(unittest.TestCase):
    """Run every discovered fixture and assert its snapshot tolerances.

    Failures from all fixtures are collected and reported together in a single
    ``self.fail``, so one run shows every regression. ``subTest`` is a no-op
    under result objects that lack ``addSubTest`` (e.g. pytest without subtest
    support). For that reason, an exception in one fixture is caught and
    recorded as that fixture's failure rather than allowed to abort the loop
    and hide the remaining fixtures.
    """

    @staticmethod
    def _resolve_config_path(config_rel: Any) -> Path | None:
        """Return the fixture's config path, or ``None`` when none is set.

        Relative paths are resolved against ``Path(__file__).resolve().parents[2]``
        (two directories above this file's own directory).
        """
        if not config_rel:
            return None
        config_path = Path(config_rel)
        if not config_path.is_absolute():
            config_path = Path(__file__).resolve().parents[2] / config_path
        return config_path

    def _run_fixture(self, input_path: Path, expected_path: Path) -> list[str]:
        """Parse one fixture and return its tolerance (and, if enabled, perf) failures."""
        expected = json.loads(expected_path.read_text(encoding="utf-8"))
        tolerances = expected.get("tolerances") or {}
        performance = expected.get("performance") or {}
        config_path = self._resolve_config_path(expected.get("config"))
        perf_enabled = bool(performance) and _perf_enforcement_enabled(performance)
        # Perf checks need several timed runs for a median; otherwise parse once.
        repeats = int(performance.get("repeats", 2)) if perf_enabled else 1
        parsed, elapsed = _measure_parse(
            input_path,
            config_path=config_path,
            selected_parsers=expected.get("selected_parsers"),
            repeats=repeats,
        )
        failures = _evaluate(parsed, tolerances)
        if perf_enabled:
            failures.extend(_evaluate_performance(parsed, performance, elapsed))
        return failures

    def test_regression_fixtures_match_snapshots(self):
        import traceback

        fixtures = _discover_fixtures()
        if not fixtures:
            self.skipTest("No regression fixtures present.")
        all_failures: list[str] = []
        for name, input_path, expected_path in fixtures:
            with self.subTest(fixture=name):
                try:
                    failures = self._run_fixture(input_path, expected_path)
                except Exception as exc:  # record and keep checking the other fixtures
                    failures = [
                        f"raised {type(exc).__name__}: {exc}\n{traceback.format_exc()}"
                    ]
                if failures:
                    all_failures.append(f"[{name}] " + "; ".join(failures))
        if all_failures:
            self.fail("\n".join(all_failures))
class PerformanceEvaluatorTests(unittest.TestCase):
    """Unit tests for the perf-evaluation helpers, kept apart from fixture discovery."""

    @staticmethod
    def _doc_with_pages(total: int):
        """Return a minimal stand-in for a parsed document exposing ``total`` pages."""
        from types import SimpleNamespace

        return SimpleNamespace(pages=[{"page_num": number} for number in range(1, total + 1)])

    def test_max_elapsed_floor_fires_when_too_slow(self):
        outcome = _evaluate_performance(
            self._doc_with_pages(1), {"max_elapsed_seconds": 0.1}, [0.5, 0.5]
        )
        self.assertEqual(len(outcome), 1)
        self.assertIn("max_elapsed_seconds", outcome[0])

    def test_min_pages_per_second_fires_when_too_slow(self):
        # One page in 10 s is 0.1 pages/s, below the 1.0 floor.
        outcome = _evaluate_performance(
            self._doc_with_pages(1), {"min_pages_per_second": 1.0}, [10.0, 10.0]
        )
        self.assertEqual(len(outcome), 1)
        self.assertIn("min_pages_per_second", outcome[0])

    def test_passing_floors_yield_no_failures(self):
        # Two pages in 0.5 s is 4 pages/s: clears the 1.0 pps floor and the 2 s ceiling.
        limits = {"max_elapsed_seconds": 2.0, "min_pages_per_second": 1.0}
        outcome = _evaluate_performance(self._doc_with_pages(2), limits, [0.5, 0.5, 0.5])
        self.assertEqual(outcome, [])

    def test_median_strips_cold_outlier(self):
        # One cold 5 s run plus two warm 0.1 s runs: median is 0.1 s, under the 1 s ceiling.
        outcome = _evaluate_performance(
            self._doc_with_pages(1), {"max_elapsed_seconds": 1.0}, [5.0, 0.1, 0.1]
        )
        self.assertEqual(outcome, [])

    def test_perf_enforcement_gating(self):
        gated = {"max_elapsed_seconds": 1.0}
        disabled_env = {"ZSGDP_REGRESSION_PERF": "0"}
        enabled_env = {"ZSGDP_REGRESSION_PERF": "1"}
        with unittest.mock.patch.dict("os.environ", disabled_env, clear=False):
            self.assertFalse(_perf_enforcement_enabled(gated))
            self.assertTrue(_perf_enforcement_enabled({"always_enforce": True}))
        with unittest.mock.patch.dict("os.environ", enabled_env, clear=False):
            self.assertTrue(_perf_enforcement_enabled(gated))
if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()