# zeroshotGPU / tests /test_gpu_tasks.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
import unittest
from zsgdp.config import load_config
from zsgdp.gpu import plan_gpu_tasks
from zsgdp.routing import RouteDecision
from zsgdp.routing.budgets import Budget
from zsgdp.schema import DocumentProfile, FigureObject, PageProfile, ParsedDocument, TableObject
class GPUTaskTests(unittest.TestCase):
    """Tests for ``plan_gpu_tasks`` using small, hand-built document fixtures."""

    def test_plan_gpu_tasks_includes_route_ocr_table_and_figure(self):
        """Plan tasks for a one-page PDF carrying a table, a figure and a VLM route.

        Checks that all four task types are planned, that the route repair
        task comes first, and that every task carries the expected provider,
        space name and a non-empty model id.
        """
        config = load_config(overrides={"chunking": {"vision_guided": True}})

        # Single page profiled with a scanned_score of 0.8 and no digital text.
        # NOTE(review): presumably this is what makes the planner emit
        # "ocr_page" -- confirm against plan_gpu_tasks.
        profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[
                PageProfile(page_num=1, scanned_score=0.8, digital_text_chars=0, digital_text_quality=0.0),
            ],
        )

        # Parsed document whose only page references a rendered page image.
        parsed = ParsedDocument(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            pages=[
                {
                    "page_num": 1,
                    "parser_pages": [
                        {"rendered_page": {"image_path": "/tmp/page.png"}},
                    ],
                }
            ],
        )
        parsed.tables.append(
            TableObject(
                table_id="t1",
                page_nums=[1],
                bbox=[(1.0, 2.0, 3.0, 4.0)],
                markdown="| A | B |\n| --- | --- |\n| 1 | 2 |",
                provenance={"crop_path": "/tmp/table.png"},
            )
        )
        parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png"))

        # One route decision for page 1 that lists a VLM repair expert.
        routes = [
            RouteDecision(
                page_id=1,
                experts=["pymupdf", "vlm_figure_repair"],
                reason="figure-heavy page",
                budget=Budget(),
                labels=["figure_heavy"],
            )
        ]

        tasks = plan_gpu_tasks(profile, parsed, config, routes)
        task_types = [task["task_type"] for task in tasks]

        # Every expected task type must be present in the plan.
        for expected_type in ("vlm_route_repair", "ocr_page", "table_vlm_repair", "figure_description"):
            self.assertIn(expected_type, task_types)

        # The route repair task is expected at the head of the plan.
        self.assertEqual(tasks[0]["task_type"], "vlm_route_repair")

        # Per-task checks run under subTest so a failure names the offending
        # task instead of reporting a bare "False is not true".
        for task in tasks:
            with self.subTest(task_type=task["task_type"]):
                self.assertEqual(task["provider"], "huggingface_spaces")
                self.assertEqual(task["space_name"], "zeroshotGPU")
                self.assertTrue(task["model_id"])

        # Each task type is tied to a specific model role.
        self.assertEqual(_task_by_type(tasks, "ocr_page")["model_role"], "ocr")
        self.assertEqual(_task_by_type(tasks, "table_vlm_repair")["model_role"], "table")
        self.assertEqual(_task_by_type(tasks, "figure_description")["model_role"], "vlm")
        self.assertEqual(_task_by_type(tasks, "figure_description")["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct")

    def test_plan_gpu_tasks_respects_max_vlm_calls(self):
        """With ``max_vlm_calls_per_doc`` set to 1, exactly one task is planned.

        The document has one page with a scanned_score of 0.8 plus one figure;
        the single planned task must be ``ocr_page``.
        """
        config = load_config(overrides={"gpu": {"max_vlm_calls_per_doc": 1}, "chunking": {"vision_guided": True}})
        profile = DocumentProfile(
            doc_id="d1",
            source_path="sample.pdf",
            file_type="pdf",
            page_count=1,
            extension=".pdf",
            pages=[PageProfile(page_num=1, scanned_score=0.8)],
        )
        parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf")
        parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png"))

        # No route decisions are passed here; the planner is called with
        # three arguments only.
        tasks = plan_gpu_tasks(profile, parsed, config)

        self.assertEqual(len(tasks), 1)
        self.assertEqual(tasks[0]["task_type"], "ocr_page")
def _task_by_type(tasks, task_type):
for task in tasks:
if task["task_type"] == task_type:
return task
raise AssertionError(f"Missing task type: {task_type}")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()