"""GPU worker extension point."""

from __future__ import annotations

from pathlib import Path
from time import perf_counter

from zsgdp.gpu.transformers_client import TransformersClient
from zsgdp.gpu.vllm_client import VLLMClient
from zsgdp.logging_config import get_logger

logger = get_logger(__name__)

# Task types whose readiness is checked against a page image on disk.
# ("vlm_route_repair" may omit the image, but a supplied path must exist.)
_IMAGE_TASK_TYPES = frozenset({"ocr_page", "table_vlm_repair", "figure_description", "vlm_route_repair"})


class GPUWorker:
    """Dispatch GPU tasks to a model backend client, or report readiness in dry-run mode.

    ``config`` is the pipeline configuration mapping. Only its ``"gpu"`` section
    is read here: ``provider``, ``space_name``, ``backend``, ``models``,
    ``vllm_endpoint`` / ``endpoint`` and ``api_key``. A missing or null
    ``"gpu"`` section is treated as empty.
    """

    def __init__(self, config: dict):
        self.config = config

    def run(self, task: dict, *, dry_run: bool = True) -> dict:
        """Run (or dry-run) a single task and return a result record.

        The record always carries the task identifiers, the resolved
        provider / space / backend (task values override the ``gpu`` config),
        and the ``readiness`` report from ``_task_readiness``.

        ``status`` is one of:

        * ``"ready_dry_run"`` or ``"blocked_missing_inputs"`` when ``dry_run``
          is true (no client is created);
        * ``"blocked_missing_inputs"`` when not a dry run and inputs are
          missing (a warning is logged and no client is created);
        * otherwise the ``"status"`` reported by the backend client, defaulting
          to ``"execution_failed"`` when the client reports none. In this case
          the raw client result is included under ``"output"``.

        Exceptions raised while building or calling the client are not caught
        here and propagate to the caller.
        """
        gpu = _gpu_config(self.config)
        readiness = _task_readiness(task)
        base = {
            "dry_run": dry_run,
            "task_id": task.get("task_id"),
            "task_type": task.get("task_type"),
            "region_id": task.get("region_id"),
            "provider": task.get("provider") or gpu.get("provider", "huggingface_spaces"),
            "space_name": task.get("space_name") or gpu.get("space_name", "zeroshotGPU"),
            "backend": _resolve_backend(self.config, task),
            "model_role": task.get("model_role"),
            "model_id": task.get("model_id"),
            "readiness": readiness,
        }
        if dry_run:
            return {**base, "status": "ready_dry_run" if readiness["ready"] else "blocked_missing_inputs"}
        if not readiness["ready"]:
            logger.warning(
                "gpu_task_blocked",
                extra={
                    "task_id": task.get("task_id"),
                    "task_type": task.get("task_type"),
                    "missing_inputs": readiness.get("missing_inputs"),
                },
            )
            return {**base, "status": "blocked_missing_inputs"}

        # The timing deliberately includes client construction as well as execution.
        started = perf_counter()
        client = self._client_for_task(task)
        result = client.execute_task(task)
        status = result.get("status", "execution_failed")
        elapsed = round(perf_counter() - started, 3)

        # Anything other than a clean "executed" is surfaced at warning level.
        log_method = logger.info if status == "executed" else logger.warning
        log_method(
            "gpu_task_executed",
            extra={
                "task_id": task.get("task_id"),
                "task_type": task.get("task_type"),
                "model_id": task.get("model_id"),
                "backend": base["backend"],
                "status": status,
                "elapsed_seconds": elapsed,
            },
        )
        return {
            **base,
            "status": status,
            "output": result,
        }

    def run_batch(self, batch: dict, *, dry_run: bool = True) -> dict:
        """Run every task in ``batch["tasks"]`` sequentially via ``run`` and aggregate the results.

        Batch metadata (``batch_id``, ``task_type``, ``provider``, ``space_name``,
        ``backend``, ``model_role``, ``model_id``) is copied from ``batch`` as-is.

        The counts are defined as follows:

        * ``ready_count``: tasks whose inputs were all present;
        * ``blocked_count``: all remaining tasks;
        * ``executed_count``: tasks whose status is ``"executed"``;
        * ``failed_count``: ready tasks whose status is neither
          ``"ready_dry_run"`` nor ``"executed"``.

        ``status`` is ``"dry_run_complete"`` for a dry run, otherwise the value
        returned by ``_batch_status``. A missing or null ``"tasks"`` entry is
        treated as an empty batch.
        """
        # NOTE(review): an exception from one task's client propagates out of
        # run() and aborts the whole batch, discarding results already
        # collected. Confirm whether per-task error capture is wanted.
        task_results = [self.run(task, dry_run=dry_run) for task in batch.get("tasks") or []]
        ready_count = sum(1 for result in task_results if result["readiness"]["ready"])
        executed_count = sum(1 for result in task_results if result["status"] == "executed")
        failed_count = sum(
            1
            for result in task_results
            if result["readiness"]["ready"] and result["status"] not in {"ready_dry_run", "executed"}
        )
        return {
            "batch_id": batch.get("batch_id"),
            "task_type": batch.get("task_type"),
            "provider": batch.get("provider"),
            "space_name": batch.get("space_name"),
            "backend": batch.get("backend"),
            "model_role": batch.get("model_role"),
            "model_id": batch.get("model_id"),
            "task_count": len(task_results),
            "ready_count": ready_count,
            "blocked_count": len(task_results) - ready_count,
            "executed_count": executed_count,
            "failed_count": failed_count,
            "status": "dry_run_complete" if dry_run else _batch_status(task_results),
            "results": task_results,
        }

    def _client_for_task(self, task: dict):
        """Build the backend client for ``task``.

        A ``"vllm"`` backend gets a ``VLLMClient``. The endpoint is taken from
        the model config first, then ``gpu.vllm_endpoint``, then
        ``gpu.endpoint``. Any other backend value falls back to
        ``TransformersClient``. The task's ``model_id`` overrides the model
        config's ``model_id``.
        """
        backend = str(_resolve_backend(self.config, task))
        model_config = _model_config_for_task(self.config, task)
        model_id = task.get("model_id") or model_config.get("model_id")
        if backend == "vllm":
            gpu = _gpu_config(self.config)
            endpoint = model_config.get("endpoint") or gpu.get("vllm_endpoint") or gpu.get("endpoint")
            return VLLMClient(endpoint=endpoint, model_id=model_id, api_key=gpu.get("api_key"))
        return TransformersClient(model_id=model_id, model_config=model_config)


def _gpu_config(config: dict) -> dict:
    """Return the ``"gpu"`` section of ``config``, or ``{}`` when it is absent, null, or not a mapping."""
    gpu = config.get("gpu")
    return gpu if isinstance(gpu, dict) else {}


def _resolve_backend(config: dict, task: dict):
    """Return the task's ``backend`` if truthy, else ``gpu.backend`` from config (default ``"transformers"``)."""
    return task.get("backend") or _gpu_config(config).get("backend", "transformers")


def _task_readiness(task: dict) -> dict:
    """Check that ``task`` carries the inputs needed to execute it.

    The checks are:

    * ``doc_id`` and ``page_nums`` must always be truthy.
    * For the task types in ``_IMAGE_TASK_TYPES``, a given ``image_path`` must
      exist on disk.
    * ``image_path`` is required for those types except ``"vlm_route_repair"``.

    Returns a dict with ``ready`` (no inputs missing), ``missing_inputs``
    (names of the missing inputs, in check order) and ``image_path_exists``
    (reported for every task type).
    """
    missing: list[str] = []
    image_path = task.get("image_path")
    # Stat the filesystem once and reuse the answer for both the check and the report.
    image_exists = bool(image_path) and Path(str(image_path)).exists()
    task_type = task.get("task_type")
    if task_type in _IMAGE_TASK_TYPES:
        if image_path:
            if not image_exists:
                missing.append("image_path")
        elif task_type != "vlm_route_repair":
            missing.append("image_path")
    if not task.get("doc_id"):
        missing.append("doc_id")
    if not task.get("page_nums"):
        missing.append("page_nums")
    return {
        "ready": not missing,
        "missing_inputs": missing,
        "image_path_exists": image_exists,
    }


def _model_config_for_task(config: dict, task: dict) -> dict:
    """Return a shallow copy of the model config to use for ``task``.

    A non-empty dict in ``task["model_config"]`` wins. Otherwise the entry for
    the task's ``model_role`` (default ``"vlm"``) under ``gpu.models`` is used.
    Non-dict values anywhere along that path yield ``{}``.
    """
    override = task.get("model_config")
    if isinstance(override, dict) and override:
        return dict(override)
    role = task.get("model_role") or "vlm"
    models = _gpu_config(config).get("models", {})
    model_config = models.get(role, {}) if isinstance(models, dict) else {}
    return dict(model_config) if isinstance(model_config, dict) else {}


def _batch_status(task_results: list[dict]) -> str:
    """Summarise non-dry-run task results into a single batch status.

    The rules, in order:

    * an empty batch, or one where every task executed, gives ``"execute_complete"``;
    * any executed task gives ``"execute_partial"``;
    * every task blocked on missing inputs gives ``"blocked_missing_inputs"``;
    * otherwise (some ready task failed and none executed) gives ``"execute_failed"``.
    """
    if not task_results:
        return "execute_complete"
    if all(result["status"] == "executed" for result in task_results):
        return "execute_complete"
    if any(result["status"] == "executed" for result in task_results):
        return "execute_partial"
    if all(not result["readiness"]["ready"] for result in task_results):
        return "blocked_missing_inputs"
    return "execute_failed"