Spaces:
Running on Zero
Running on Zero
| import unittest | |
| from zsgdp.config import load_config | |
| from zsgdp.gpu import plan_gpu_tasks | |
| from zsgdp.routing import RouteDecision | |
| from zsgdp.routing.budgets import Budget | |
| from zsgdp.schema import DocumentProfile, FigureObject, PageProfile, ParsedDocument, TableObject | |
class GPUTaskTests(unittest.TestCase):
    """Checks on the task list that ``plan_gpu_tasks`` produces for small hand-built documents."""

    def test_plan_gpu_tasks_includes_route_ocr_table_and_figure(self):
        # One-page PDF whose profile has scanned_score=0.8 and no digital text,
        # carrying one table, one figure and one route decision. The test expects
        # all four task types, with the route repair task planned first.
        settings = load_config(overrides={"chunking": {"vision_guided": True}})

        page_profile = PageProfile(
            page_num=1,
            scanned_score=0.8,
            digital_text_chars=0,
            digital_text_quality=0.0,
        )
        doc_profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[page_profile],
        )

        parsed_page = {
            "page_num": 1,
            "parser_pages": [
                {"rendered_page": {"image_path": "/tmp/page.png"}},
            ],
        }
        document = ParsedDocument(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            pages=[parsed_page],
        )

        table = TableObject(
            table_id="t1",
            page_nums=[1],
            bbox=[(1.0, 2.0, 3.0, 4.0)],
            markdown="| A | B |\n| --- | --- |\n| 1 | 2 |",
            provenance={"crop_path": "/tmp/table.png"},
        )
        document.tables.append(table)

        figure = FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")
        document.figures.append(figure)

        decision = RouteDecision(
            page_id=1,
            experts=["pymupdf", "vlm_figure_repair"],
            reason="figure-heavy page",
            budget=Budget(),
            labels=["figure_heavy"],
        )
        route_decisions = [decision]

        planned = plan_gpu_tasks(doc_profile, document, settings, route_decisions)
        planned_types = [job["task_type"] for job in planned]

        # Every expected task type must be present.
        for expected_type in (
            "vlm_route_repair",
            "ocr_page",
            "table_vlm_repair",
            "figure_description",
        ):
            self.assertIn(expected_type, planned_types)

        # The route repair task is expected at the head of the plan.
        self.assertEqual(planned[0]["task_type"], "vlm_route_repair")

        # Fields that must hold the same value on every planned task.
        for field, expected_value in (
            ("provider", "huggingface_spaces"),
            ("space_name", "zeroshotGPU"),
        ):
            self.assertTrue(all(job[field] == expected_value for job in planned))
        self.assertTrue(all(job["model_id"] for job in planned))

        # Each task type is expected to carry its own model role.
        for kind, expected_role in (
            ("ocr_page", "ocr"),
            ("table_vlm_repair", "table"),
            ("figure_description", "vlm"),
        ):
            self.assertEqual(_task_by_type(planned, kind)["model_role"], expected_role)

        figure_task = _task_by_type(planned, "figure_description")
        self.assertEqual(figure_task["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct")

    def test_plan_gpu_tasks_respects_max_vlm_calls(self):
        # With gpu.max_vlm_calls_per_doc overridden to 1, the test expects exactly
        # one planned task, of type "ocr_page", even though a figure is attached.
        overrides = {
            "gpu": {"max_vlm_calls_per_doc": 1},
            "chunking": {"vision_guided": True},
        }
        settings = load_config(overrides=overrides)

        page_profile = PageProfile(page_num=1, scanned_score=0.8)
        doc_profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[page_profile],
        )

        document = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf")
        figure = FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")
        document.figures.append(figure)

        planned = plan_gpu_tasks(doc_profile, document, settings)

        self.assertEqual(len(planned), 1)
        self.assertEqual(planned[0]["task_type"], "ocr_page")
| def _task_by_type(tasks, task_type): | |
| for task in tasks: | |
| if task["task_type"] == task_type: | |
| return task | |
| raise AssertionError(f"Missing task type: {task_type}") | |
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()