zeroshotGPU / tests /test_datasets.py
Arjunvir Singh
Initial commit: zeroshotGPU MVP with full eval surface
db06ffa
"""Dataset loader tests."""
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from zsgdp.benchmarks.datasets import (
DatasetDocument,
get_dataset_loader,
iter_dataset,
list_dataset_loaders,
register_dataset_loader,
)
class TestDatasetRegistry(unittest.TestCase):
    """Checks on the dataset-loader registry lookup API."""

    def test_built_in_loaders_registered(self):
        """The built-in loader names all appear in ``list_dataset_loaders()``."""
        registered = list_dataset_loaders()
        for expected_name in ("custom_folder", "omnidocbench", "doclaynet"):
            self.assertIn(expected_name, registered)

    def test_custom_alias_resolves_to_custom_folder(self):
        """``default`` and ``custom`` resolve to the same object as ``custom_folder``."""
        # Resolved in the same order as the names are listed.
        via_default, via_custom, canonical = (
            get_dataset_loader(name) for name in ("default", "custom", "custom_folder")
        )
        self.assertIs(via_default, canonical)
        self.assertIs(via_custom, canonical)

    def test_unknown_loader_raises(self):
        """Looking up a name that was never registered raises ``KeyError``."""
        self.assertRaises(KeyError, get_dataset_loader, "not_a_real_dataset")
class TestCustomFolderLoader(unittest.TestCase):
    """Tests for the ``custom_folder`` loader, which yields files without ground truth."""

    def test_yields_files_with_no_ground_truth(self):
        """Top-level files are yielded with no ground truth; nested files are skipped."""
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            # Explicit encoding keeps the fixtures byte-identical across platforms.
            (root / "a.md").write_text("# A\n", encoding="utf-8")
            (root / "b.md").write_text("# B\n", encoding="utf-8")
            # A file in a subdirectory, which must NOT be yielded.
            (root / "subdir").mkdir()
            (root / "subdir" / "ignored.md").write_text("# nope\n", encoding="utf-8")

            documents = list(iter_dataset("custom_folder", root))

            ids = sorted(document.doc_id for document in documents)
            self.assertEqual(ids, ["a", "b"])
            for document in documents:
                self.assertIsNone(document.ground_truth)
                self.assertEqual(document.dataset_id, "custom_folder")
                self.assertTrue(document.path.exists())

    def test_missing_root_raises(self):
        """A root directory that does not exist raises ``FileNotFoundError``."""
        # Derive the missing path from a fresh temp dir so it is guaranteed absent
        # on every platform; a fixed "/tmp/..." path is neither portable nor
        # guaranteed not to exist on a shared machine.
        with tempfile.TemporaryDirectory() as tmp:
            missing_root = Path(tmp) / "this-path-should-not-exist-zsgdp"
            with self.assertRaises(FileNotFoundError):
                list(iter_dataset("custom_folder", missing_root))
class TestOmniDocBenchLoader(unittest.TestCase):
    """Tests for the ``omnidocbench`` loader: PDFs plus optional same-stem JSON."""

    def test_pairs_pdf_with_sibling_json(self):
        """A PDF with a sibling ``.json`` gets it as ground truth; a lone PDF gets ``None``."""
        minimal_pdf = b"%PDF-1.4\n%%EOF\n"
        with tempfile.TemporaryDirectory() as tmp:
            base = Path(tmp)
            (base / "doc1.pdf").write_bytes(minimal_pdf)
            (base / "doc1.json").write_text(json.dumps({"reading_order": ["e1", "e2"]}))
            # doc2 deliberately has no ground-truth JSON beside it.
            (base / "doc2.pdf").write_bytes(minimal_pdf)

            loaded = {doc.doc_id: doc for doc in iter_dataset("omnidocbench", base)}
            self.assertEqual(set(loaded), {"doc1", "doc2"})

            with_gt = loaded["doc1"]
            self.assertIsNotNone(with_gt.ground_truth)
            self.assertEqual(with_gt.ground_truth["reading_order"], ["e1", "e2"])
            self.assertTrue(with_gt.metadata["has_ground_truth"])

            without_gt = loaded["doc2"]
            self.assertIsNone(without_gt.ground_truth)
            self.assertFalse(without_gt.metadata["has_ground_truth"])

    def test_no_pdfs_raises(self):
        """A directory containing no PDFs at all raises ``FileNotFoundError``."""
        with tempfile.TemporaryDirectory() as empty_dir, self.assertRaises(FileNotFoundError):
            list(iter_dataset("omnidocbench", empty_dir))
class TestDocLayNetLoader(unittest.TestCase):
    """Tests for the ``doclaynet`` loader: page images plus a COCO-style annotations file."""

    def test_yields_one_document_per_image_with_filtered_annotations(self):
        """Each listed image becomes one document carrying only its own annotations."""
        png_header = b"\x89PNG\r\n\x1a\n"
        annotations_payload = {
            "images": [
                {"id": 1, "file_name": "page1.png", "width": 800, "height": 1100},
                {"id": 2, "file_name": "page2.png", "width": 800, "height": 1100},
            ],
            "annotations": [
                # Two boxes on image 1, one box on image 2.
                {"id": 10, "image_id": 1, "category_id": 1, "bbox": [0, 0, 100, 50]},
                {"id": 11, "image_id": 1, "category_id": 2, "bbox": [0, 60, 100, 50]},
                {"id": 12, "image_id": 2, "category_id": 1, "bbox": [0, 0, 100, 50]},
            ],
            "categories": [
                {"id": 1, "name": "Title"},
                {"id": 2, "name": "Text"},
            ],
        }
        with tempfile.TemporaryDirectory() as tmp:
            base = Path(tmp)
            for page_name in ("page1.png", "page2.png"):
                (base / page_name).write_bytes(png_header)
            (base / "annotations.json").write_text(json.dumps(annotations_payload))

            pages = {doc.doc_id: doc for doc in iter_dataset("doclaynet", base)}
            self.assertEqual(set(pages), {"page1.png", "page2.png"})

            first_gt = pages["page1.png"].ground_truth
            second_gt = pages["page2.png"].ground_truth
            self.assertEqual(len(first_gt["annotations"]), 2)
            self.assertEqual(len(second_gt["annotations"]), 1)
            self.assertEqual(first_gt["categories"][1]["name"], "Title")

    def test_missing_annotations_raises(self):
        """A folder with an image but no annotations file raises ``FileNotFoundError``."""
        with tempfile.TemporaryDirectory() as tmp:
            base = Path(tmp)
            (base / "page1.png").write_bytes(b"\x89PNG\r\n\x1a\n")
            with self.assertRaises(FileNotFoundError):
                list(iter_dataset("doclaynet", base))
class TestRegisterDatasetLoader(unittest.TestCase):
    """Tests for plugging a user-supplied loader in via ``register_dataset_loader``."""

    def test_register_and_use_custom_loader(self):
        """A registered loader is called with the given root and its documents are yielded."""
        loader_name = "zsgdp_test_fake"
        fake_root = Path("/tmp/whatever")  # never touched on disk by the fake loader
        marker = []

        def fake_loader(root: Path):
            # Record the root we are handed so the call can be verified below.
            marker.append(root)
            yield DatasetDocument(dataset_id="fake", doc_id="x", path=root)

        def unregister() -> None:
            # Look the private registry up at cleanup time, then drop our entry.
            from zsgdp.benchmarks.datasets import _LOADERS

            _LOADERS.pop(loader_name, None)

        register_dataset_loader(loader_name, fake_loader)
        # addCleanup (instead of try/finally) always unregisters the fake loader,
        # and an error raised while unregistering is reported alongside -- rather
        # than replacing -- any failure from the test body.
        self.addCleanup(unregister)

        documents = list(iter_dataset(loader_name, fake_root))

        self.assertEqual(len(documents), 1)
        self.assertEqual(documents[0].dataset_id, "fake")
        self.assertEqual(marker, [fake_root])
if __name__ == "__main__":
    # Allow running this module directly, in addition to test discovery.
    unittest.main()