import unittest

from zsgdp.config import load_config
from zsgdp.gpu import plan_gpu_tasks
from zsgdp.routing import RouteDecision
from zsgdp.routing.budgets import Budget
from zsgdp.schema import DocumentProfile, FigureObject, PageProfile, ParsedDocument, TableObject


class GPUTaskTests(unittest.TestCase):
    """Tests for ``plan_gpu_tasks``: which tasks are planned, their metadata, and budget caps."""

    def test_plan_gpu_tasks_includes_route_ocr_table_and_figure(self):
        """A scanned page with a table, a figure, and a VLM route yields all four task types."""
        config = load_config(overrides={"chunking": {"vision_guided": True}})
        profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[
                # High scanned score and no digital text: the page should need OCR.
                PageProfile(page_num=1, scanned_score=0.8, digital_text_chars=0, digital_text_quality=0.0),
            ],
        )
        parsed = ParsedDocument(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            pages=[
                {
                    "page_num": 1,
                    "parser_pages": [
                        {"rendered_page": {"image_path": "/tmp/page.png"}},
                    ],
                }
            ],
        )
        parsed.tables.append(
            TableObject(
                table_id="t1",
                page_nums=[1],
                bbox=[(1.0, 2.0, 3.0, 4.0)],
                markdown="| A | B |\n| --- | --- |\n| 1 | 2 |",
                provenance={"crop_path": "/tmp/table.png"},
            )
        )
        parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png"))
        routes = [
            RouteDecision(
                page_id=1,
                experts=["pymupdf", "vlm_figure_repair"],
                reason="figure-heavy page",
                budget=Budget(),
                labels=["figure_heavy"],
            )
        ]

        tasks = plan_gpu_tasks(profile, parsed, config, routes)

        task_types = [task["task_type"] for task in tasks]
        self.assertIn("vlm_route_repair", task_types)
        self.assertIn("ocr_page", task_types)
        self.assertIn("table_vlm_repair", task_types)
        self.assertIn("figure_description", task_types)
        # Route-driven repair is expected to be planned ahead of every other task.
        self.assertEqual(task_types[0], "vlm_route_repair")

        # Check the shared provider metadata task by task so a failure names the
        # offending task and field instead of reporting a bare `all(...)` False.
        for index, task in enumerate(tasks):
            with self.subTest(index=index, task_type=task["task_type"]):
                self.assertEqual(task["provider"], "huggingface_spaces")
                self.assertEqual(task["space_name"], "zeroshotGPU")
                self.assertTrue(task["model_id"], "model_id must be non-empty")

        self.assertEqual(_task_by_type(tasks, "ocr_page")["model_role"], "ocr")
        self.assertEqual(_task_by_type(tasks, "table_vlm_repair")["model_role"], "table")
        figure_task = _task_by_type(tasks, "figure_description")
        self.assertEqual(figure_task["model_role"], "vlm")
        self.assertEqual(figure_task["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct")

    def test_plan_gpu_tasks_respects_max_vlm_calls(self):
        """With ``max_vlm_calls_per_doc`` set to 1, exactly one task (page OCR) is planned."""
        config = load_config(
            overrides={"gpu": {"max_vlm_calls_per_doc": 1}, "chunking": {"vision_guided": True}}
        )
        profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[PageProfile(page_num=1, scanned_score=0.8)],
        )
        parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf")
        parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png"))

        tasks = plan_gpu_tasks(profile, parsed, config)

        # Presumably both the scanned page (OCR) and the figure (description) are
        # candidates; with a cap of one, only the OCR task is expected to remain.
        # NOTE(review): this relies on OCR being planned before figure tasks —
        # confirm against plan_gpu_tasks. Comparing the whole list gives a readable
        # diff of what was actually planned when this fails.
        self.assertEqual([task["task_type"] for task in tasks], ["ocr_page"])


def _task_by_type(tasks, task_type):
    """Return the first planned task whose ``task_type`` equals *task_type*.

    Raises:
        AssertionError: if no task has that type. Raising the assertion type
            makes unittest report a test failure rather than an error.
    """
    for task in tasks:
        if task["task_type"] == task_type:
            return task
    raise AssertionError(f"Missing task type: {task_type}")


if __name__ == "__main__":
    unittest.main()