Spaces:
Running on Zero
Running on Zero
| """GPU worker extension point.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from time import perf_counter | |
| from zsgdp.gpu.transformers_client import TransformersClient | |
| from zsgdp.gpu.vllm_client import VLLMClient | |
| from zsgdp.logging_config import get_logger | |
| logger = get_logger(__name__) | |
class GPUWorker:
    """Run GPU tasks, singly or in batches, against a configured model backend.

    ``config`` is the pipeline configuration. This class reads its ``"gpu"``
    section directly (``provider``, ``space_name``, ``backend``,
    ``vllm_endpoint``, ``endpoint``, ``api_key``) and hands the whole config to
    ``_model_config_for_task`` for per-role model settings. Both ``run`` and
    ``run_batch`` default to dry-run mode. In that mode only input readiness is
    reported and no backend client is created.
    """

    def __init__(self, config: dict):
        self.config = config

    def run(self, task: dict, *, dry_run: bool = True) -> dict:
        """Check readiness for one task and, unless ``dry_run``, execute it.

        The returned dict always carries the task's identifying fields, the
        resolved provider/space/backend, a ``readiness`` report and a ``status``:

        - dry run: ``"ready_dry_run"`` or ``"blocked_missing_inputs"``;
        - real run with missing inputs: ``"blocked_missing_inputs"``;
        - real run: the status reported by the backend client. This is
          ``"execution_failed"`` when the client reports none, raises, or returns
          a non-dict. The client result (or error record) is included under
          ``"output"``.
        """
        readiness = _task_readiness(task)
        gpu = self._gpu_config()
        base = {
            "dry_run": dry_run,
            "task_id": task.get("task_id"),
            "task_type": task.get("task_type"),
            "region_id": task.get("region_id"),
            "provider": task.get("provider") or gpu.get("provider", "huggingface_spaces"),
            "space_name": task.get("space_name") or gpu.get("space_name", "zeroshotGPU"),
            "backend": self._resolve_backend(task),
            "model_role": task.get("model_role"),
            "model_id": task.get("model_id"),
            "readiness": readiness,
        }
        if dry_run:
            return {**base, "status": "ready_dry_run" if readiness["ready"] else "blocked_missing_inputs"}
        if not readiness["ready"]:
            logger.warning(
                "gpu_task_blocked",
                extra={
                    "task_id": task.get("task_id"),
                    "task_type": task.get("task_type"),
                    "missing_inputs": readiness.get("missing_inputs"),
                },
            )
            return {**base, "status": "blocked_missing_inputs"}
        started = perf_counter()
        result = self._execute(task)
        status = result.get("status", "execution_failed")
        elapsed = round(perf_counter() - started, 3)
        log_method = logger.info if status == "executed" else logger.warning
        log_method(
            "gpu_task_executed",
            extra={
                "task_id": task.get("task_id"),
                "task_type": task.get("task_type"),
                "model_id": task.get("model_id"),
                "backend": base["backend"],
                "status": status,
                "elapsed_seconds": elapsed,
            },
        )
        return {
            **base,
            "status": status,
            "output": result,
        }

    def run_batch(self, batch: dict, *, dry_run: bool = True) -> dict:
        """Run every task in ``batch["tasks"]`` and aggregate their outcomes.

        Batch-level identifying fields are copied from ``batch``. The counts are
        derived from the per-task results:

        - ``ready_count``: tasks whose readiness passed;
        - ``blocked_count``: all remaining tasks;
        - ``executed_count``: tasks with status ``"executed"``;
        - ``failed_count``: ready tasks that neither executed nor were dry runs.

        ``status`` is ``"dry_run_complete"`` in dry-run mode, otherwise the
        summary computed by ``_batch_status``.
        """
        task_results = [self.run(task, dry_run=dry_run) for task in batch.get("tasks", [])]
        ready_count = sum(1 for result in task_results if result["readiness"]["ready"])
        executed_count = sum(1 for result in task_results if result["status"] == "executed")
        failed_count = sum(
            1
            for result in task_results
            if result["readiness"]["ready"] and result["status"] not in {"ready_dry_run", "executed"}
        )
        return {
            "batch_id": batch.get("batch_id"),
            "task_type": batch.get("task_type"),
            "provider": batch.get("provider"),
            "space_name": batch.get("space_name"),
            "backend": batch.get("backend"),
            "model_role": batch.get("model_role"),
            "model_id": batch.get("model_id"),
            "task_count": len(task_results),
            "ready_count": ready_count,
            "blocked_count": len(task_results) - ready_count,
            "executed_count": executed_count,
            "failed_count": failed_count,
            "status": "dry_run_complete" if dry_run else _batch_status(task_results),
            "results": task_results,
        }

    def _gpu_config(self) -> dict:
        """Return the ``"gpu"`` config section, or ``{}`` when it is absent or not a dict."""
        gpu = self.config.get("gpu")
        # A present-but-empty section (e.g. loaded as None) is treated like a missing one.
        return gpu if isinstance(gpu, dict) else {}

    def _resolve_backend(self, task: dict):
        """Return the task's ``backend`` override, else the configured default, else ``"transformers"``."""
        return task.get("backend") or self._gpu_config().get("backend", "transformers")

    def _execute(self, task: dict) -> dict:
        """Build the backend client for ``task`` and execute it, always returning a result dict.

        Exceptions raised while constructing the client or executing the task are
        logged with their traceback. They, and non-dict client results, are
        returned as ``{"status": "execution_failed", "error": ...}``. This keeps
        one failing task from aborting ``run_batch``, whose counters already
        expect failures to be reported as statuses.
        """
        try:
            client = self._client_for_task(task)
            result = client.execute_task(task)
        except Exception as exc:  # execution boundary: report, don't propagate
            logger.exception(
                "gpu_task_exception",
                extra={"task_id": task.get("task_id"), "task_type": task.get("task_type")},
            )
            return {"status": "execution_failed", "error": f"{type(exc).__name__}: {exc}"}
        if not isinstance(result, dict):
            return {
                "status": "execution_failed",
                "error": f"unexpected client result type: {type(result).__name__}",
            }
        return result

    def _client_for_task(self, task: dict):
        """Create the backend client (vLLM or Transformers) that should execute ``task``.

        The model id comes from the task, falling back to the resolved model
        config. For the ``"vllm"`` backend, the endpoint is taken from the model
        config, then ``gpu.vllm_endpoint``, then ``gpu.endpoint``.
        """
        backend = str(self._resolve_backend(task))
        model_config = _model_config_for_task(self.config, task)
        model_id = task.get("model_id") or model_config.get("model_id")
        if backend == "vllm":
            gpu = self._gpu_config()
            endpoint = model_config.get("endpoint") or gpu.get("vllm_endpoint") or gpu.get("endpoint")
            return VLLMClient(endpoint=endpoint, model_id=model_id, api_key=gpu.get("api_key"))
        return TransformersClient(model_id=model_id, model_config=model_config)
def _task_readiness(task: dict) -> dict:
    """Report whether ``task`` carries the inputs required to execute it.

    Every task needs a truthy ``doc_id`` and ``page_nums``. The image task types
    ``ocr_page``, ``table_vlm_repair``, ``figure_description`` and
    ``vlm_route_repair`` also need an existing ``image_path``. For
    ``vlm_route_repair`` the image is optional, but a path that is given must
    exist.

    Returns a dict with ``ready`` (nothing missing), ``missing_inputs`` (names
    of missing or invalid fields, in check order) and ``image_path_exists``.
    """
    missing: list[str] = []
    task_type = task.get("task_type")
    image_path = task.get("image_path")
    # Stat the path once, so "missing_inputs" and "image_path_exists" always agree
    # and the filesystem is not hit twice per task.
    image_path_exists = bool(image_path) and Path(str(image_path)).exists()
    if task_type in {"ocr_page", "table_vlm_repair", "figure_description", "vlm_route_repair"}:
        if image_path:
            if not image_path_exists:
                missing.append("image_path")
        elif task_type != "vlm_route_repair":
            missing.append("image_path")
    if not task.get("doc_id"):
        missing.append("doc_id")
    if not task.get("page_nums"):
        missing.append("page_nums")
    return {
        "ready": not missing,
        "missing_inputs": missing,
        "image_path_exists": image_path_exists,
    }
def _model_config_for_task(config: dict, task: dict) -> dict:
    """Return a (shallow) copy of the model configuration to use for ``task``.

    A non-empty ``task["model_config"]`` dict takes precedence. Otherwise the
    entry for ``task["model_role"]`` (default ``"vlm"``) is looked up under
    ``config["gpu"]["models"]``. A missing or non-dict value at any level of
    that path yields ``{}``.
    """
    task_model_config = task.get("model_config")
    if isinstance(task_model_config, dict) and task_model_config:
        return dict(task_model_config)
    role = task.get("model_role") or "vlm"
    gpu = config.get("gpu")
    # Guard the "gpu" section the same way "models" and the role entry are guarded:
    # a present-but-empty section (e.g. loaded as None) must not raise AttributeError.
    models = gpu.get("models") if isinstance(gpu, dict) else None
    model_config = models.get(role) if isinstance(models, dict) else None
    return dict(model_config) if isinstance(model_config, dict) else {}
def _batch_status(task_results: list[dict]) -> str:
    """Collapse per-task results into a single batch execution status.

    Returns one of:

    - ``"execute_complete"``: every task executed, or the batch is empty;
    - ``"execute_partial"``: at least one task executed, but not all;
    - ``"blocked_missing_inputs"``: nothing executed and no task was ready;
    - ``"execute_failed"``: anything else.
    """
    def was_executed(result: dict) -> bool:
        return result["status"] == "executed"

    # all() is vacuously True for an empty sequence, which covers the empty batch.
    if all(map(was_executed, task_results)):
        return "execute_complete"
    if any(map(was_executed, task_results)):
        return "execute_partial"
    none_ready = not any(result["readiness"]["ready"] for result in task_results)
    return "blocked_missing_inputs" if none_ready else "execute_failed"