# File size: 3,888 Bytes
# Revision: db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import unittest

from zsgdp.config import load_config
from zsgdp.gpu import plan_gpu_tasks
from zsgdp.routing import RouteDecision
from zsgdp.routing.budgets import Budget
from zsgdp.schema import DocumentProfile, FigureObject, PageProfile, ParsedDocument, TableObject


class GPUTaskTests(unittest.TestCase):
    """Tests for the task list produced by ``plan_gpu_tasks``."""

    @staticmethod
    def _one_page_pdf_profile(**page_fields):
        """Return a one-page PDF ``DocumentProfile``.

        ``page_fields`` are forwarded verbatim to the single ``PageProfile``
        (whose ``page_num`` is always 1), so each test controls exactly which
        page attributes are set explicitly.
        """
        return DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[PageProfile(page_num=1, **page_fields)],
        )

    def test_plan_gpu_tasks_includes_route_ocr_table_and_figure(self):
        # One page with a high scanned_score, a rendered page image, a table
        # crop, a figure image and a route that requests "vlm_figure_repair":
        # the plan is expected to contain all four task types.
        config = load_config(overrides={"chunking": {"vision_guided": True}})
        profile = self._one_page_pdf_profile(
            scanned_score=0.8,
            digital_text_chars=0,
            digital_text_quality=0.0,
        )

        rendered_entry = {"rendered_page": {"image_path": "/tmp/page.png"}}
        parsed = ParsedDocument(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            pages=[{"page_num": 1, "parser_pages": [rendered_entry]}],
        )

        table = TableObject(
            table_id="t1",
            page_nums=[1],
            bbox=[(1.0, 2.0, 3.0, 4.0)],
            markdown="| A | B |\n| --- | --- |\n| 1 | 2 |",
            provenance={"crop_path": "/tmp/table.png"},
        )
        parsed.tables.append(table)

        figure = FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")
        parsed.figures.append(figure)

        figure_route = RouteDecision(
            page_id=1,
            experts=["pymupdf", "vlm_figure_repair"],
            reason="figure-heavy page",
            budget=Budget(),
            labels=["figure_heavy"],
        )
        routes = [figure_route]

        planned = plan_gpu_tasks(profile, parsed, config, routes)

        # Every expected task type must be present, with the route repair first.
        seen_types = [entry["task_type"] for entry in planned]
        for expected_type in (
            "vlm_route_repair",
            "ocr_page",
            "table_vlm_repair",
            "figure_description",
        ):
            self.assertIn(expected_type, seen_types)
        self.assertEqual(planned[0]["task_type"], "vlm_route_repair")

        # All tasks share the same provider and space, and name a model.
        for field, wanted in (
            ("provider", "huggingface_spaces"),
            ("space_name", "zeroshotGPU"),
        ):
            self.assertTrue(all(entry[field] == wanted for entry in planned))
        self.assertTrue(all(entry["model_id"] for entry in planned))

        # Each task type is bound to its own model role.
        for kind, role in (
            ("ocr_page", "ocr"),
            ("table_vlm_repair", "table"),
            ("figure_description", "vlm"),
        ):
            self.assertEqual(_task_by_type(planned, kind)["model_role"], role)
        self.assertEqual(
            _task_by_type(planned, "figure_description")["model_id"],
            "Qwen/Qwen2.5-VL-3B-Instruct",
        )

    def test_plan_gpu_tasks_respects_max_vlm_calls(self):
        # With gpu.max_vlm_calls_per_doc set to 1, a document offering both a
        # high-scanned_score page and a figure must yield a single "ocr_page" task.
        overrides = {
            "gpu": {"max_vlm_calls_per_doc": 1},
            "chunking": {"vision_guided": True},
        }
        config = load_config(overrides=overrides)
        profile = self._one_page_pdf_profile(scanned_score=0.8)

        parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf")
        figure = FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")
        parsed.figures.append(figure)

        planned = plan_gpu_tasks(profile, parsed, config)

        self.assertEqual(len(planned), 1)
        self.assertEqual(planned[0]["task_type"], "ocr_page")


def _task_by_type(tasks, task_type):
    for task in tasks:
        if task["task_type"] == task_type:
            return task
    raise AssertionError(f"Missing task type: {task_type}")


# When this file is executed directly (rather than imported), hand control
# to unittest's command-line entry point.
if __name__ == "__main__":
    unittest.main()