# zeroshotGPU / tests/test_layout_f1.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""Tests for layout F1 metric and ground-truth adapters."""
from __future__ import annotations
import unittest
from zsgdp.benchmarks.ground_truth import (
canonical_category,
doclaynet_layout_truths,
omnidocbench_layout_truths,
parsed_layout_predictions,
)
from zsgdp.schema import Element, FigureObject, ParsedDocument, QualityReport, TableObject
from zsgdp.verify.layout_f1 import compute_layout_f1
def _item(bbox, category="paragraph", page_num=1):
return {"bbox": bbox, "category": category, "page_num": page_num}
class TestComputeLayoutF1(unittest.TestCase):
    """Behavioral tests for ``compute_layout_f1``.

    Each test builds prediction/truth lists with ``_item`` and inspects the
    ``class_aware``, ``class_agnostic`` and ``per_category`` sections of the
    returned report.
    """

    def test_perfect_match_yields_f1_1(self):
        """Identical boxes and categories give F1 = 1 in both matching modes."""
        predictions = [_item((0, 0, 100, 50)), _item((0, 60, 100, 110), "table")]
        truths = [_item((0, 0, 100, 50)), _item((0, 60, 100, 110), "table")]
        result = compute_layout_f1(predictions, truths)
        self.assertEqual(result["class_aware"]["f1"], 1.0)
        self.assertEqual(result["class_agnostic"]["f1"], 1.0)
        self.assertEqual(result["class_aware"]["tp"], 2)

    def test_zero_match_yields_f1_0(self):
        """Disjoint boxes count as one false positive and one false negative."""
        predictions = [_item((0, 0, 50, 50))]
        truths = [_item((1000, 1000, 1100, 1100))]
        result = compute_layout_f1(predictions, truths)
        self.assertEqual(result["class_aware"]["f1"], 0.0)
        self.assertEqual(result["class_aware"]["fp"], 1)
        self.assertEqual(result["class_aware"]["fn"], 1)

    def test_iou_below_threshold_misses(self):
        """A partial overlap whose IoU is below the threshold is not a match."""
        # The boxes share x in [60, 100] across the full height, which is a
        # 40% horizontal overlap: intersection = 40 * 100 = 4000,
        # union = 10000 + 10000 - 4000 = 16000, so IoU = 0.25 < 0.5.
        predictions = [_item((0, 0, 100, 100))]
        truths = [_item((60, 0, 160, 100))]
        result = compute_layout_f1(predictions, truths, iou_threshold=0.5)
        self.assertEqual(result["class_aware"]["tp"], 0)
        self.assertEqual(result["class_aware"]["fp"], 1)
        self.assertEqual(result["class_aware"]["fn"], 1)
        # Categories agree, so the miss is purely geometric and the
        # class-agnostic pass must reject the pair as well.
        self.assertEqual(result["class_agnostic"]["tp"], 0)

    def test_class_aware_vs_agnostic(self):
        """Same bbox, different category: agnostic matches, class-aware does not."""
        predictions = [_item((0, 0, 100, 100), "paragraph")]
        truths = [_item((0, 0, 100, 100), "title")]
        result = compute_layout_f1(predictions, truths)
        self.assertEqual(result["class_aware"]["tp"], 0)
        self.assertEqual(result["class_agnostic"]["tp"], 1)

    def test_per_category_breakdown(self):
        """Per-category counts attribute matches and extras to their category."""
        predictions = [_item((0, 0, 100, 100), "title"), _item((0, 200, 100, 300), "table")]
        truths = [_item((0, 0, 100, 100), "title")]
        result = compute_layout_f1(predictions, truths)
        per_category = result["per_category"]
        self.assertEqual(per_category["title"]["tp"], 1)
        self.assertEqual(per_category["table"]["fp"], 1)

    def test_empty_inputs_are_vacuously_correct(self):
        """No predictions and no truths is reported as a perfect score."""
        self.assertEqual(compute_layout_f1([], [])["class_aware"]["f1"], 1.0)

    def test_predictions_only_yields_zero(self):
        """Predictions with no truths are all false positives and F1 = 0."""
        result = compute_layout_f1([_item((0, 0, 10, 10))], [])
        self.assertEqual(result["class_aware"]["fp"], 1)
        self.assertEqual(result["class_aware"]["f1"], 0.0)

    def test_page_num_must_match(self):
        """Identical boxes on different pages are not matched."""
        predictions = [_item((0, 0, 100, 100), "table", page_num=1)]
        truths = [_item((0, 0, 100, 100), "table", page_num=2)]
        result = compute_layout_f1(predictions, truths)
        self.assertEqual(result["class_aware"]["tp"], 0)
class TestDocLayNetAdapter(unittest.TestCase):
    """Tests for ``doclaynet_layout_truths``.

    Inputs are per-page records with ``image``, ``annotations`` and
    ``categories`` keys; annotation boxes are given as ``[x, y, w, h]``.
    """

    def test_xywh_converted_and_categories_normalized(self):
        """``[x, y, w, h]`` becomes ``(x0, y0, x1, y1)``; names are canonicalized."""
        title_ann = {"image_id": 5, "category_id": 1, "bbox": [10, 20, 50, 60]}
        header_ann = {"image_id": 5, "category_id": 2, "bbox": [100, 0, 40, 30]}
        record = {
            "image": {"id": 5, "file_name": "p.png", "page_no": 5},
            "annotations": [title_ann, header_ann],
            "categories": {1: {"name": "Title"}, 2: {"name": "Section-header"}},
        }

        converted = doclaynet_layout_truths(record)

        self.assertEqual(len(converted), 2)
        first, second = converted
        self.assertEqual(first["bbox"], (10.0, 20.0, 60.0, 80.0))
        self.assertEqual(first["category"], "title")
        self.assertEqual(first["page_num"], 5)
        self.assertEqual(second["category"], "heading")

    def test_invalid_annotations_dropped(self):
        """Zero-area boxes and annotations without a bbox yield no truths."""
        zero_area = {"image_id": 1, "category_id": 1, "bbox": [0, 0, 0, 0]}
        no_bbox = {"image_id": 1, "category_id": 1}
        record = {
            "image": {"id": 1, "file_name": "p.png"},
            "annotations": [zero_area, no_bbox],
            "categories": {1: {"name": "Text"}},
        }
        self.assertEqual(doclaynet_layout_truths(record), [])
class TestOmniDocBenchAdapter(unittest.TestCase):
    """Tests for ``omnidocbench_layout_truths`` across the input shapes it accepts."""

    def test_picks_layout_dets_first(self):
        """Flat ``layout_dets`` records are read and their categories normalized."""
        detections = [
            {"bbox": [0, 0, 100, 50], "category": "title", "page_num": 1},
            {"bbox": [0, 100, 100, 150], "category": "Table", "page": 1},
        ]
        converted = omnidocbench_layout_truths({"layout_dets": detections})
        self.assertEqual(len(converted), 2)
        self.assertEqual([truth["category"] for truth in converted], ["title", "table"])

    def test_pages_nested_records(self):
        """Elements nested under ``pages`` inherit their page's ``page_num``."""
        pages = [
            {"page_num": number, "elements": [{"bbox": [0, 0, 10, 10], "category": label}]}
            for number, label in ((1, "paragraph"), (2, "table"))
        ]
        converted = omnidocbench_layout_truths({"pages": pages})
        self.assertEqual(len(converted), 2)
        self.assertEqual([truth["page_num"] for truth in converted], [1, 2])

    def test_unknown_shape_returns_empty(self):
        """Unrecognized payloads and ``None`` produce an empty list."""
        for payload in ({"weird": "shape"}, None):
            self.assertEqual(omnidocbench_layout_truths(payload), [])
class TestParsedPredictions(unittest.TestCase):
    """Tests for ``parsed_layout_predictions`` on a hand-built ``ParsedDocument``."""

    @staticmethod
    def _document_with_mixed_objects():
        """Return a document with a boxed title, an unboxed paragraph, a table and a figure."""
        boxed_title = Element(
            element_id="e1",
            doc_id="d1",
            page_num=1,
            type="title",
            text="Title",
            bbox=(0.0, 0.0, 100.0, 30.0),
        )
        # No bbox: this element must not appear among the predictions.
        unboxed_paragraph = Element(
            element_id="e2",
            doc_id="d1",
            page_num=1,
            type="paragraph",
            text="No bbox",
        )
        table = TableObject(
            table_id="t1",
            page_nums=[1],
            bbox=[(0.0, 100.0, 200.0, 200.0)],
        )
        figure = FigureObject(
            figure_id="f1",
            page_num=2,
            bbox=(50.0, 50.0, 150.0, 250.0),
        )
        return ParsedDocument(
            doc_id="d1",
            source_path="/tmp/d1.pdf",
            file_type="pdf",
            elements=[boxed_title, unboxed_paragraph],
            tables=[table],
            figures=[figure],
            quality_report=QualityReport(),
        )

    def test_extracts_bboxes_from_elements_tables_figures(self):
        """Only bbox-carrying objects become predictions, one per object."""
        predictions = parsed_layout_predictions(self._document_with_mixed_objects())
        self.assertEqual(
            sorted(pred["category"] for pred in predictions),
            ["figure", "table", "title"],
        )
        self.assertEqual(len(predictions), 3)
class TestCanonicalCategory(unittest.TestCase):
    """Tests for ``canonical_category`` label normalization."""

    # (raw label, expected canonical label). Covers source-dataset names that
    # map to a canonical alias, an already-canonical name, and an unknown label
    # that is only lower-cased.
    _CASES = (
        ("Picture", "figure"),
        ("Section-header", "heading"),
        ("Page-footer", "footer"),
        ("formula", "formula"),
        ("Mystery", "mystery"),
    )

    def test_canonical_mapping(self):
        """Each raw label maps to its expected canonical category."""
        for raw, expected in self._CASES:
            # subTest reports every failing label instead of stopping at the first.
            with self.subTest(raw=raw):
                self.assertEqual(canonical_category(raw), expected)
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in this module.
    unittest.main()