"""Tests for GPU task batching, the GPU worker, and the GPU task runner/CLI."""

import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from zsgdp.cli import main
from zsgdp.config import load_config
from zsgdp.gpu.batching import batch_gpu_tasks
from zsgdp.gpu.runner import dry_run_gpu_tasks, load_gpu_tasks, run_gpu_task_manifest
from zsgdp.gpu.worker import GPUWorker
from zsgdp.utils import write_jsonl

# Extra task fields used by the tests that exercise real (non-dry-run) execution.
_TRANSFORMERS_FIELDS = {
    "backend": "transformers",
    "model_role": "vlm",
    "model_id": "local-test-model",
}


def _figure_task(image_path, **extra) -> dict:
    """Build the figure-description task record shared by these tests.

    Args:
        image_path: Path (or string) stored, stringified, under ``image_path``.
        **extra: Additional fields appended after the base fields, e.g.
            ``priority=60`` or the ``_TRANSFORMERS_FIELDS`` entries.

    Returns:
        A fresh dict, so tests never share mutable task state.
    """
    task = {
        "task_id": "gt1",
        "task_type": "figure_description",
        "doc_id": "d1",
        "page_nums": [1],
        "image_path": str(image_path),
    }
    task.update(extra)
    return task


class GPURunnerTests(unittest.TestCase):
    """Exercise batching, readiness reporting, manifest runs, and the CLI."""

    def test_batch_gpu_tasks_groups_by_task_type_and_batch_size(self):
        """Three tasks with ``max_batch_size=1`` produce three batches."""
        tasks = [
            {"task_id": "a", "task_type": "figure_description", "priority": 1},
            {"task_id": "b", "task_type": "figure_description", "priority": 2},
            {"task_id": "c", "task_type": "table_vlm_repair", "priority": 3},
        ]

        batches = batch_gpu_tasks(tasks, max_batch_size=1)

        self.assertEqual(len(batches), 3)
        self.assertEqual(batches[0]["task_count"], 1)
        self.assertEqual(
            {batch["task_type"] for batch in batches},
            {"figure_description", "table_vlm_repair"},
        )

    def test_worker_reports_missing_image_path(self):
        """A task pointing at a nonexistent image is blocked on ``image_path``."""
        worker = GPUWorker(load_config())

        # A file inside a fresh temporary directory is guaranteed not to exist,
        # unlike a hard-coded /tmp path, and works on every platform.
        with tempfile.TemporaryDirectory() as tmp:
            missing_image = Path(tmp) / "does-not-exist.png"
            result = worker.run(_figure_task(missing_image))

        self.assertEqual(result["status"], "blocked_missing_inputs")
        self.assertIn("image_path", result["readiness"]["missing_inputs"])

    def test_run_gpu_task_manifest_writes_report(self):
        """Running a manifest directory returns a report and writes it to disk."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            image_path = tmp_path / "figure.png"
            image_path.write_bytes(b"fake")
            tasks_path = tmp_path / "gpu_tasks.jsonl"
            report_path = tmp_path / "report.json"
            write_jsonl(tasks_path, [_figure_task(image_path, priority=60)])

            report = run_gpu_task_manifest(
                tmp_path,
                config=load_config(),
                output_path=report_path,
            )

            self.assertEqual(report["task_count"], 1)
            self.assertEqual(report["ready_count"], 1)
            self.assertTrue(report_path.exists())
            written = json.loads(report_path.read_text(encoding="utf-8"))
            self.assertEqual(written["batch_count"], 1)

    def test_dry_run_gpu_tasks_accepts_in_memory_tasks(self):
        """``dry_run_gpu_tasks`` accepts a list of task dicts directly."""
        with tempfile.TemporaryDirectory() as tmp:
            image_path = Path(tmp) / "figure.png"
            image_path.write_bytes(b"fake")

            report = dry_run_gpu_tasks(
                [_figure_task(image_path, priority=60)],
                config=load_config(),
            )

            self.assertEqual(report["ready_count"], 1)
            self.assertEqual(report["blocked_count"], 0)

    def test_execute_gpu_tasks_dispatches_transformers_client(self):
        """With ``dry_run=False`` the patched client's ``execute_task`` runs once."""
        with tempfile.TemporaryDirectory() as tmp:
            image_path = Path(tmp) / "figure.png"
            image_path.write_bytes(b"fake")
            task = _figure_task(image_path, priority=60, **_TRANSFORMERS_FIELDS)

            # Patch the client where the worker module looks it up so no real
            # model is loaded.
            with patch("zsgdp.gpu.worker.TransformersClient") as client_class:
                client_class.return_value.execute_task.return_value = {
                    "status": "executed",
                    "text": "Figure description.",
                }
                report = dry_run_gpu_tasks([task], config=load_config(), dry_run=False)

                self.assertFalse(report["dry_run"])
                self.assertEqual(report["executed_count"], 1)
                self.assertEqual(report["failed_count"], 0)
                self.assertEqual(report["batches"][0]["status"], "execute_complete")
                client_class.return_value.execute_task.assert_called_once()

    def test_load_gpu_tasks_accepts_file_path(self):
        """``load_gpu_tasks`` reads tasks from a JSONL file path."""
        with tempfile.TemporaryDirectory() as tmp:
            tasks_path = Path(tmp) / "tasks.jsonl"
            write_jsonl(tasks_path, [{"task_id": "gt1"}])

            self.assertEqual(load_gpu_tasks(tasks_path)[0]["task_id"], "gt1")

    def test_run_gpu_tasks_cli(self):
        """The ``run-gpu-tasks`` command exits 0 and writes the report file."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            tasks_path = tmp_path / "gpu_tasks.jsonl"
            report_path = tmp_path / "report.json"
            # The referenced image is deliberately never created.
            write_jsonl(
                tasks_path,
                [_figure_task(tmp_path / "missing.png", priority=60)],
            )

            code = main(
                ["run-gpu-tasks", "--input", str(tasks_path), "--output", str(report_path)]
            )

            self.assertEqual(code, 0)
            self.assertTrue(report_path.exists())

    def test_run_gpu_tasks_cli_execute(self):
        """``run-gpu-tasks --execute`` records one executed task in the report."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            image_path = tmp_path / "figure.png"
            image_path.write_bytes(b"fake")
            tasks_path = tmp_path / "gpu_tasks.jsonl"
            report_path = tmp_path / "report.json"
            write_jsonl(
                tasks_path,
                [_figure_task(image_path, priority=60, **_TRANSFORMERS_FIELDS)],
            )

            with patch("zsgdp.gpu.worker.TransformersClient") as client_class:
                client_class.return_value.execute_task.return_value = {
                    "status": "executed",
                    "text": "done",
                }
                code = main(
                    [
                        "run-gpu-tasks",
                        "--input",
                        str(tasks_path),
                        "--output",
                        str(report_path),
                        "--execute",
                    ]
                )

            self.assertEqual(code, 0)
            written = json.loads(report_path.read_text(encoding="utf-8"))
            self.assertEqual(written["executed_count"], 1)


if __name__ == "__main__":
    unittest.main()