# zeroshotGPU/tests/test_ablation_runner.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""Tests for parser-contribution metrics and the ablation runner."""
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from zsgdp.benchmarks.ablation_runner import ABLATION_METRIC_KEYS, run_parser_ablations
from zsgdp.benchmarks.parser_quality import run_parser_benchmark
class TestParserContribution(unittest.TestCase):
    """Parser-contribution metrics reported by ``run_parser_benchmark``."""

    def _benchmark_markdown(self, text: str):
        """Benchmark a folder holding a single ``doc.md`` with *text*.

        The scratch directory is removed after the test finishes, so the
        returned summary can be inspected freely inside the test body.
        """
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        root = Path(scratch.name)
        input_dir = root / "in"
        input_dir.mkdir()
        (input_dir / "doc.md").write_text(text, encoding="utf-8")
        return run_parser_benchmark(input_dir, root / "out", dataset_name="custom_folder")

    def test_contribution_counts_appear_in_summary(self):
        summary = self._benchmark_markdown("# Doc\n\nA paragraph.\n")
        first_doc = summary["documents"][0]
        for field in ("parser_contribution_counts", "parser_contribution_fractions"):
            self.assertIn(field, first_doc)
        counts = first_doc["parser_contribution_counts"]
        fractions = first_doc["parser_contribution_fractions"]
        self.assertGreater(sum(counts.values()), 0)
        # Across all parsers the per-document fractions should add up to ~1.0.
        self.assertAlmostEqual(sum(fractions.values()), 1.0, places=6)
        overall = summary["parser_contribution_summary"]
        self.assertGreater(overall["total"], 0)
        # Counts and fractions must be keyed by the same set of parsers.
        self.assertEqual(set(overall["counts"]), set(overall["fractions"]))

    def test_text_parser_dominates_markdown_doc(self):
        summary = self._benchmark_markdown("# Doc\n\nPara one.\n\nPara two.\n")
        counts_by_parser = summary["parser_contribution_summary"]["counts"]
        self.assertIn("text", counts_by_parser)
        non_text_total = sum(
            count for name, count in counts_by_parser.items() if name != "text"
        )
        # For plain Markdown the text parser should contribute at least as much
        # as every other parser combined.
        self.assertGreaterEqual(counts_by_parser["text"], non_text_total)
class TestRunParserAblations(unittest.TestCase):
    """Tests for ``run_parser_ablations``: arm selection and emitted artifacts."""

    def _make_source(self, tmp: Path, content: str) -> Path:
        """Create ``tmp/in/doc.md`` containing *content* and return ``tmp/in``."""
        src = tmp / "in"
        src.mkdir()
        (src / "doc.md").write_text(content, encoding="utf-8")
        return src

    def test_two_arms_plus_merged(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp = Path(tmp_name)
            src = self._make_source(tmp, "# Doc\n\nPara one.\n\nPara two.\n")
            out = tmp / "out"
            comparison = run_parser_ablations(
                src,
                out,
                parsers=["text", "pymupdf"],
                dataset_name="custom_folder",
            )
            self.assertEqual(comparison["arm_count"], 3)
            arms = sorted(row["arm"] for row in comparison["rows"])
            self.assertEqual(arms, ["merged", "pymupdf", "text"])
            # One output sub-directory per arm.
            for arm in ("text", "pymupdf", "merged"):
                with self.subTest(arm_dir=arm):
                    self.assertTrue((out / f"arm_{arm}").exists())
            self.assertTrue((out / "ablation_comparison.csv").exists())
            summary_path = out / "ablation_summary.json"
            self.assertTrue(summary_path.exists())
            # The summary must be well-formed JSON, not merely present on disk.
            json.loads(summary_path.read_text(encoding="utf-8"))
            # Every arm row reports at least the headline quality metric.
            for row in comparison["rows"]:
                with self.subTest(arm=row["arm"]):
                    self.assertIn("mean_quality_score", row)

    def test_no_merged_when_disabled(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp = Path(tmp_name)
            src = self._make_source(tmp, "# Doc\n\nPara.\n")
            comparison = run_parser_ablations(
                src,
                tmp / "out",
                parsers=["text", "pymupdf"],
                dataset_name="custom_folder",
                include_merged=False,
            )
            self.assertEqual(comparison["arm_count"], 2)
            self.assertNotIn("merged", {row["arm"] for row in comparison["rows"]})

    def test_single_parser_ablation_skips_merged_arm(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp = Path(tmp_name)
            src = self._make_source(tmp, "# Doc\n\nPara.\n")
            comparison = run_parser_ablations(
                src,
                tmp / "out",
                parsers=["text"],
                dataset_name="custom_folder",
            )
            # include_merged is left at its default, but with a single parser a
            # merged arm would be redundant, so only the "text" arm is expected.
            self.assertEqual(comparison["arm_count"], 1)
            self.assertEqual(comparison["rows"][0]["arm"], "text")

    def test_empty_parsers_raises(self):
        # Use throwaway paths so that a regression which skips the validation
        # cannot scan the working directory or create ``./out`` in the repo.
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp = Path(tmp_name)
            with self.assertRaises(ValueError):
                run_parser_ablations(tmp, tmp / "out", parsers=[])

    def test_metric_keys_constant_matches_summary_shape(self):
        with tempfile.TemporaryDirectory() as tmp_name:
            tmp = Path(tmp_name)
            src = self._make_source(tmp, "# Doc\n\nPara.\n")
            summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder")
            # Every key the ablation runner compares must exist in a plain
            # benchmark summary; subTest reports each missing key separately.
            for key in ABLATION_METRIC_KEYS:
                with self.subTest(key=key):
                    self.assertIn(key, summary, f"benchmark summary missing key {key}")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()