File size: 5,972 Bytes
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Dataset loader tests."""

from __future__ import annotations

import json
import tempfile
import unittest
from pathlib import Path

from zsgdp.benchmarks.datasets import (
    DatasetDocument,
    get_dataset_loader,
    iter_dataset,
    list_dataset_loaders,
    register_dataset_loader,
)


class TestDatasetRegistry(unittest.TestCase):
    """Registry lookups: built-in names, aliases, and unknown names."""

    def test_built_in_loaders_registered(self):
        registered = list_dataset_loaders()
        for expected_name in ("custom_folder", "omnidocbench", "doclaynet"):
            self.assertIn(expected_name, registered)

    def test_custom_alias_resolves_to_custom_folder(self):
        # Resolve every name first (same call order as before), then compare
        # each alias against the canonical loader object by identity.
        resolved = {
            name: get_dataset_loader(name)
            for name in ("default", "custom", "custom_folder")
        }
        canonical = resolved["custom_folder"]
        self.assertIs(resolved["default"], canonical)
        self.assertIs(resolved["custom"], canonical)

    def test_unknown_loader_raises(self):
        self.assertRaises(KeyError, get_dataset_loader, "not_a_real_dataset")


class TestCustomFolderLoader(unittest.TestCase):
    """The ``custom_folder`` loader: a flat directory of files with no ground truth."""

    def test_yields_files_with_no_ground_truth(self):
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            (root / "a.md").write_text("# A\n", encoding="utf-8")
            (root / "b.md").write_text("# B\n", encoding="utf-8")
            # Files inside subdirectories are expected NOT to be yielded.
            (root / "subdir").mkdir()
            (root / "subdir" / "ignored.md").write_text("# nope\n", encoding="utf-8")

            documents = list(iter_dataset("custom_folder", root))

            ids = sorted(document.doc_id for document in documents)
            # doc_id is expected to be the file stem ("a.md" -> "a").
            self.assertEqual(ids, ["a", "b"])
            for document in documents:
                self.assertIsNone(document.ground_truth)
                self.assertEqual(document.dataset_id, "custom_folder")
                # Checked inside the context, while the temp files still exist.
                self.assertTrue(document.path.exists())

    def test_missing_root_raises(self):
        # Derive a path that cannot exist from a fresh temp dir rather than
        # hard-coding "/tmp/...", which is POSIX-only and could collide with a
        # real path on the machine running the tests.
        with tempfile.TemporaryDirectory() as tmp:
            missing = Path(tmp) / "this-path-should-not-exist-zsgdp"
            with self.assertRaises(FileNotFoundError):
                list(iter_dataset("custom_folder", missing))


class TestOmniDocBenchLoader(unittest.TestCase):
    """OmniDocBench layout: each PDF optionally paired with a same-stem JSON file."""

    def test_pairs_pdf_with_sibling_json(self):
        pdf_bytes = b"%PDF-1.4\n%%EOF\n"
        with tempfile.TemporaryDirectory() as tmp:
            folder = Path(tmp)
            (folder / "doc1.pdf").write_bytes(pdf_bytes)
            (folder / "doc1.json").write_text(json.dumps({"reading_order": ["e1", "e2"]}))
            # doc2 deliberately has no sibling JSON, i.e. no ground truth.
            (folder / "doc2.pdf").write_bytes(pdf_bytes)

            # Consume the iterator while the files still exist on disk.
            loaded = {doc.doc_id: doc for doc in iter_dataset("omnidocbench", folder)}

        self.assertEqual(set(loaded), {"doc1", "doc2"})

        with_gt = loaded["doc1"]
        without_gt = loaded["doc2"]

        self.assertIsNotNone(with_gt.ground_truth)
        self.assertEqual(with_gt.ground_truth["reading_order"], ["e1", "e2"])
        self.assertTrue(with_gt.metadata["has_ground_truth"])

        self.assertIsNone(without_gt.ground_truth)
        self.assertFalse(without_gt.metadata["has_ground_truth"])

    def test_no_pdfs_raises(self):
        with tempfile.TemporaryDirectory() as empty_dir, self.assertRaises(FileNotFoundError):
            list(iter_dataset("omnidocbench", empty_dir))


class TestDocLayNetLoader(unittest.TestCase):
    """DocLayNet layout: page images plus a COCO-style ``annotations.json``."""

    # Just the 8-byte PNG signature; the loader is not expected to decode it.
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

    def test_yields_one_document_per_image_with_filtered_annotations(self):
        coco = {
            "images": [
                {"id": 1, "file_name": "page1.png", "width": 800, "height": 1100},
                {"id": 2, "file_name": "page2.png", "width": 800, "height": 1100},
            ],
            "annotations": [
                # Two boxes on image 1, one box on image 2.
                {"id": 10, "image_id": 1, "category_id": 1, "bbox": [0, 0, 100, 50]},
                {"id": 11, "image_id": 1, "category_id": 2, "bbox": [0, 60, 100, 50]},
                {"id": 12, "image_id": 2, "category_id": 1, "bbox": [0, 0, 100, 50]},
            ],
            "categories": [
                {"id": 1, "name": "Title"},
                {"id": 2, "name": "Text"},
            ],
        }

        with tempfile.TemporaryDirectory() as tmp:
            folder = Path(tmp)
            for page_name in ("page1.png", "page2.png"):
                (folder / page_name).write_bytes(self._PNG_SIGNATURE)
            (folder / "annotations.json").write_text(json.dumps(coco))

            pages = {doc.doc_id: doc for doc in iter_dataset("doclaynet", folder)}

        self.assertEqual(set(pages), {"page1.png", "page2.png"})

        first_page = pages["page1.png"]
        second_page = pages["page2.png"]
        self.assertEqual(len(first_page.ground_truth["annotations"]), 2)
        self.assertEqual(len(second_page.ground_truth["annotations"]), 1)
        # Key 1 must resolve to "Title" (COCO id 1, list position 0), so this
        # presumes categories are keyed by COCO id rather than list index.
        self.assertEqual(first_page.ground_truth["categories"][1]["name"], "Title")

    def test_missing_annotations_raises(self):
        with tempfile.TemporaryDirectory() as tmp:
            folder = Path(tmp)
            # An image with no annotations.json next to it.
            (folder / "page1.png").write_bytes(self._PNG_SIGNATURE)
            with self.assertRaises(FileNotFoundError):
                list(iter_dataset("doclaynet", folder))


class TestRegisterDatasetLoader(unittest.TestCase):
    """Third-party loaders can be registered and are then reachable by name."""

    def test_register_and_use_custom_loader(self):
        marker = []

        def fake_loader(root: Path):
            # Record the root the registry forwarded, then yield one document.
            marker.append(root)
            yield DatasetDocument(dataset_id="fake", doc_id="x", path=root)

        def _unregister():
            # Private registry access; imported lazily at cleanup time. Any
            # error here is reported as a cleanup error by unittest instead of
            # masking an exception raised by the test body (which a `finally`
            # clause would do).
            from zsgdp.benchmarks.datasets import _LOADERS

            _LOADERS.pop("zsgdp_test_fake", None)

        register_dataset_loader("zsgdp_test_fake", fake_loader)
        # Keep the global registry clean for other tests, pass or fail.
        self.addCleanup(_unregister)

        documents = list(iter_dataset("zsgdp_test_fake", Path("/tmp/whatever")))

        self.assertEqual(len(documents), 1)
        self.assertEqual(documents[0].dataset_id, "fake")
        self.assertEqual(documents[0].doc_id, "x")
        self.assertEqual(marker, [Path("/tmp/whatever")])


if __name__ == "__main__":
    # Running this file directly discovers and runs every TestCase above.
    unittest.main()