"""Dataset loader tests.""" from __future__ import annotations import json import tempfile import unittest from pathlib import Path from zsgdp.benchmarks.datasets import ( DatasetDocument, get_dataset_loader, iter_dataset, list_dataset_loaders, register_dataset_loader, ) class TestDatasetRegistry(unittest.TestCase): def test_built_in_loaders_registered(self): loaders = list_dataset_loaders() self.assertIn("custom_folder", loaders) self.assertIn("omnidocbench", loaders) self.assertIn("doclaynet", loaders) def test_custom_alias_resolves_to_custom_folder(self): loader_default = get_dataset_loader("default") loader_alias = get_dataset_loader("custom") loader_canonical = get_dataset_loader("custom_folder") self.assertIs(loader_default, loader_canonical) self.assertIs(loader_alias, loader_canonical) def test_unknown_loader_raises(self): with self.assertRaises(KeyError): get_dataset_loader("not_a_real_dataset") class TestCustomFolderLoader(unittest.TestCase): def test_yields_files_with_no_ground_truth(self): with tempfile.TemporaryDirectory() as tmp: root = Path(tmp) (root / "a.md").write_text("# A\n") (root / "b.md").write_text("# B\n") (root / "subdir").mkdir() (root / "subdir" / "ignored.md").write_text("# nope\n") documents = list(iter_dataset("custom_folder", root)) ids = sorted(document.doc_id for document in documents) self.assertEqual(ids, ["a", "b"]) for document in documents: self.assertIsNone(document.ground_truth) self.assertEqual(document.dataset_id, "custom_folder") self.assertTrue(document.path.exists()) def test_missing_root_raises(self): with self.assertRaises(FileNotFoundError): list(iter_dataset("custom_folder", "/tmp/this-path-should-not-exist-zsgdp")) class TestOmniDocBenchLoader(unittest.TestCase): def test_pairs_pdf_with_sibling_json(self): with tempfile.TemporaryDirectory() as tmp: root = Path(tmp) (root / "doc1.pdf").write_bytes(b"%PDF-1.4\n%%EOF\n") (root / "doc1.json").write_text(json.dumps({"reading_order": ["e1", "e2"]})) (root / "doc2.pdf").write_bytes(b"%PDF-1.4\n%%EOF\n") # no GT documents = list(iter_dataset("omnidocbench", root)) by_id = {document.doc_id: document for document in documents} self.assertEqual(set(by_id), {"doc1", "doc2"}) self.assertIsNotNone(by_id["doc1"].ground_truth) self.assertEqual(by_id["doc1"].ground_truth["reading_order"], ["e1", "e2"]) self.assertTrue(by_id["doc1"].metadata["has_ground_truth"]) self.assertIsNone(by_id["doc2"].ground_truth) self.assertFalse(by_id["doc2"].metadata["has_ground_truth"]) def test_no_pdfs_raises(self): with tempfile.TemporaryDirectory() as tmp: with self.assertRaises(FileNotFoundError): list(iter_dataset("omnidocbench", tmp)) class TestDocLayNetLoader(unittest.TestCase): def test_yields_one_document_per_image_with_filtered_annotations(self): with tempfile.TemporaryDirectory() as tmp: root = Path(tmp) (root / "page1.png").write_bytes(b"\x89PNG\r\n\x1a\n") (root / "page2.png").write_bytes(b"\x89PNG\r\n\x1a\n") (root / "annotations.json").write_text( json.dumps( { "images": [ {"id": 1, "file_name": "page1.png", "width": 800, "height": 1100}, {"id": 2, "file_name": "page2.png", "width": 800, "height": 1100}, ], "annotations": [ {"id": 10, "image_id": 1, "category_id": 1, "bbox": [0, 0, 100, 50]}, {"id": 11, "image_id": 1, "category_id": 2, "bbox": [0, 60, 100, 50]}, {"id": 12, "image_id": 2, "category_id": 1, "bbox": [0, 0, 100, 50]}, ], "categories": [ {"id": 1, "name": "Title"}, {"id": 2, "name": "Text"}, ], } ) ) documents = list(iter_dataset("doclaynet", root)) by_id = {document.doc_id: document for document in documents} self.assertEqual(set(by_id), {"page1.png", "page2.png"}) self.assertEqual(len(by_id["page1.png"].ground_truth["annotations"]), 2) self.assertEqual(len(by_id["page2.png"].ground_truth["annotations"]), 1) self.assertEqual(by_id["page1.png"].ground_truth["categories"][1]["name"], "Title") def test_missing_annotations_raises(self): with tempfile.TemporaryDirectory() as tmp: root = Path(tmp) (root / "page1.png").write_bytes(b"\x89PNG\r\n\x1a\n") with self.assertRaises(FileNotFoundError): list(iter_dataset("doclaynet", root)) class TestRegisterDatasetLoader(unittest.TestCase): def test_register_and_use_custom_loader(self): marker = [] def fake_loader(root: Path): marker.append(root) yield DatasetDocument(dataset_id="fake", doc_id="x", path=root) register_dataset_loader("zsgdp_test_fake", fake_loader) try: documents = list(iter_dataset("zsgdp_test_fake", Path("/tmp/whatever"))) finally: from zsgdp.benchmarks.datasets import _LOADERS _LOADERS.pop("zsgdp_test_fake", None) self.assertEqual(len(documents), 1) self.assertEqual(documents[0].dataset_id, "fake") self.assertEqual(marker, [Path("/tmp/whatever")]) if __name__ == "__main__": unittest.main()