"""Tests for zsgdp GPU task batching, worker readiness, execution, and the CLI."""
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from zsgdp.cli import main
from zsgdp.config import load_config
from zsgdp.gpu.batching import batch_gpu_tasks
from zsgdp.gpu.runner import dry_run_gpu_tasks, load_gpu_tasks, run_gpu_task_manifest
from zsgdp.gpu.worker import GPUWorker
from zsgdp.utils import write_jsonl
class GPURunnerTests(unittest.TestCase):
    """Exercise GPU task batching, worker readiness checks, manifest runs, and the CLI."""

    def test_batch_gpu_tasks_groups_by_task_type_and_batch_size(self):
        """A max batch size of one yields one batch per task, keyed by task type."""
        pending = [
            {"task_id": "a", "task_type": "figure_description", "priority": 1},
            {"task_id": "b", "task_type": "figure_description", "priority": 2},
            {"task_id": "c", "task_type": "table_vlm_repair", "priority": 3},
        ]
        batches = batch_gpu_tasks(pending, max_batch_size=1)
        self.assertEqual(len(batches), 3)
        self.assertEqual(batches[0]["task_count"], 1)
        observed_types = {entry["task_type"] for entry in batches}
        self.assertEqual(observed_types, {"figure_description", "table_vlm_repair"})

    def test_worker_reports_missing_image_path(self):
        """A task whose image file does not exist is reported as blocked, not run."""
        worker = GPUWorker(load_config())
        outcome = worker.run(
            {
                "task_id": "gt1",
                "task_type": "figure_description",
                "doc_id": "d1",
                "page_nums": [1],
                "image_path": "/tmp/does-not-exist.png",
            }
        )
        self.assertEqual(outcome["status"], "blocked_missing_inputs")
        self.assertIn("image_path", outcome["readiness"]["missing_inputs"])

    def test_run_gpu_task_manifest_writes_report(self):
        """Running a manifest directory counts ready tasks and persists a JSON report."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            figure = root / "figure.png"
            figure.write_bytes(b"fake")
            manifest = root / "gpu_tasks.jsonl"
            report_file = root / "report.json"
            write_jsonl(
                manifest,
                [
                    {
                        "task_id": "gt1",
                        "task_type": "figure_description",
                        "doc_id": "d1",
                        "page_nums": [1],
                        "image_path": str(figure),
                        "priority": 60,
                    }
                ],
            )
            report = run_gpu_task_manifest(root, config=load_config(), output_path=report_file)
            self.assertEqual(report["task_count"], 1)
            self.assertEqual(report["ready_count"], 1)
            self.assertTrue(report_file.exists())
            persisted = json.loads(report_file.read_text(encoding="utf-8"))
            self.assertEqual(persisted["batch_count"], 1)

    def test_dry_run_gpu_tasks_accepts_in_memory_tasks(self):
        """A list of task dicts (no manifest file) can be dry-run directly."""
        with tempfile.TemporaryDirectory() as workdir:
            figure = Path(workdir) / "figure.png"
            figure.write_bytes(b"fake")
            report = dry_run_gpu_tasks(
                [
                    {
                        "task_id": "gt1",
                        "task_type": "figure_description",
                        "doc_id": "d1",
                        "page_nums": [1],
                        "image_path": str(figure),
                        "priority": 60,
                    }
                ],
                config=load_config(),
            )
            self.assertEqual(report["ready_count"], 1)
            self.assertEqual(report["blocked_count"], 0)

    def test_execute_gpu_tasks_dispatches_transformers_client(self):
        """With dry_run disabled, a transformers-backed task is executed via the client."""
        with tempfile.TemporaryDirectory() as workdir:
            figure = Path(workdir) / "figure.png"
            figure.write_bytes(b"fake")
            task = {
                "task_id": "gt1",
                "task_type": "figure_description",
                "doc_id": "d1",
                "page_nums": [1],
                "image_path": str(figure),
                "priority": 60,
                "backend": "transformers",
                "model_role": "vlm",
                "model_id": "local-test-model",
            }
            # Patch the client class so no real model is loaded or invoked.
            with patch("zsgdp.gpu.worker.TransformersClient") as client_class:
                client_class.return_value.execute_task.return_value = {
                    "status": "executed",
                    "text": "Figure description.",
                }
                report = dry_run_gpu_tasks([task], config=load_config(), dry_run=False)
                self.assertFalse(report["dry_run"])
                self.assertEqual(report["executed_count"], 1)
                self.assertEqual(report["failed_count"], 0)
                self.assertEqual(report["batches"][0]["status"], "execute_complete")
                client_class.return_value.execute_task.assert_called_once()

    def test_load_gpu_tasks_accepts_file_path(self):
        """load_gpu_tasks reads a JSONL file path into a list of task dicts."""
        with tempfile.TemporaryDirectory() as workdir:
            manifest = Path(workdir) / "tasks.jsonl"
            write_jsonl(manifest, [{"task_id": "gt1"}])
            loaded = load_gpu_tasks(manifest)
            self.assertEqual(loaded[0]["task_id"], "gt1")

    def test_run_gpu_tasks_cli(self):
        """The run-gpu-tasks subcommand exits 0 and writes a report, even for blocked tasks."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            manifest = root / "gpu_tasks.jsonl"
            report_file = root / "report.json"
            write_jsonl(
                manifest,
                [
                    {
                        "task_id": "gt1",
                        "task_type": "figure_description",
                        "doc_id": "d1",
                        "page_nums": [1],
                        # Deliberately missing on disk: the CLI should still succeed.
                        "image_path": str(root / "missing.png"),
                        "priority": 60,
                    }
                ],
            )
            exit_code = main(["run-gpu-tasks", "--input", str(manifest), "--output", str(report_file)])
            self.assertEqual(exit_code, 0)
            self.assertTrue(report_file.exists())

    def test_run_gpu_tasks_cli_execute(self):
        """The --execute flag drives real execution through the (patched) client."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            figure = root / "figure.png"
            figure.write_bytes(b"fake")
            manifest = root / "gpu_tasks.jsonl"
            report_file = root / "report.json"
            write_jsonl(
                manifest,
                [
                    {
                        "task_id": "gt1",
                        "task_type": "figure_description",
                        "doc_id": "d1",
                        "page_nums": [1],
                        "image_path": str(figure),
                        "priority": 60,
                        "backend": "transformers",
                        "model_role": "vlm",
                        "model_id": "local-test-model",
                    }
                ],
            )
            with patch("zsgdp.gpu.worker.TransformersClient") as client_class:
                client_class.return_value.execute_task.return_value = {"status": "executed", "text": "done"}
                exit_code = main(
                    ["run-gpu-tasks", "--input", str(manifest), "--output", str(report_file), "--execute"]
                )
                self.assertEqual(exit_code, 0)
                persisted = json.loads(report_file.read_text(encoding="utf-8"))
                self.assertEqual(persisted["executed_count"], 1)
if __name__ == "__main__":
    # Support running this test module directly with `python <file>`.
    unittest.main()