# zeroshotGPU/zsgdp/gpu/worker.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""GPU worker extension point."""
from __future__ import annotations
from pathlib import Path
from time import perf_counter
from zsgdp.gpu.transformers_client import TransformersClient
from zsgdp.gpu.vllm_client import VLLMClient
from zsgdp.logging_config import get_logger
logger = get_logger(__name__)
class GPUWorker:
    """Dispatch GPU tasks to a model client, or report whether they could run.

    Defaults (provider, space name, backend, per-role model settings, vLLM
    endpoint and API key) are read from the ``"gpu"`` section of the
    configuration; individual tasks may override provider, space name,
    backend, model id and model config.
    """

    def __init__(self, config: dict):
        """Store the configuration mapping used to resolve per-task defaults.

        Args:
            config: Configuration mapping; this class reads its ``"gpu"`` key.
        """
        self.config = config

    def _gpu_config(self) -> dict:
        """Return the ``"gpu"`` config section, or ``{}`` if absent or not a mapping."""
        # A config with ``"gpu": None`` (e.g. an empty YAML key) must behave like a
        # missing section instead of raising AttributeError on ``.get``.
        config = self.config if isinstance(self.config, dict) else {}
        gpu = config.get("gpu")
        return gpu if isinstance(gpu, dict) else {}

    def run(self, task: dict, *, dry_run: bool = True) -> dict:
        """Check a single task's readiness and, unless ``dry_run``, execute it.

        Args:
            task: Task description. ``task_id``, ``task_type``, ``region_id``,
                ``model_role`` and ``model_id`` are echoed into the result;
                ``provider``, ``space_name`` and ``backend`` fall back to the
                ``"gpu"`` config section when falsy on the task.
            dry_run: When true (the default) nothing is executed and only the
                readiness verdict is reported.

        Returns:
            A dict with the echoed task fields, the ``"readiness"`` report and a
            ``"status"``: ``"ready_dry_run"`` or ``"blocked_missing_inputs"``
            for dry runs and blocked tasks, otherwise the status reported by the
            client (``"execution_failed"`` if the client result has none). An
            executed task additionally carries the raw client result under
            ``"output"``.

        Raises:
            Nothing is caught here: any exception raised while building the
            client or executing the task propagates to the caller.
        """
        gpu = self._gpu_config()
        readiness = _task_readiness(task)
        base = {
            "dry_run": dry_run,
            "task_id": task.get("task_id"),
            "task_type": task.get("task_type"),
            "region_id": task.get("region_id"),
            "provider": task.get("provider") or gpu.get("provider", "huggingface_spaces"),
            "space_name": task.get("space_name") or gpu.get("space_name", "zeroshotGPU"),
            "backend": task.get("backend") or gpu.get("backend", "transformers"),
            "model_role": task.get("model_role"),
            "model_id": task.get("model_id"),
            "readiness": readiness,
        }

        if dry_run:
            return {**base, "status": "ready_dry_run" if readiness["ready"] else "blocked_missing_inputs"}

        if not readiness["ready"]:
            logger.warning(
                "gpu_task_blocked",
                extra={
                    "task_id": task.get("task_id"),
                    "task_type": task.get("task_type"),
                    "missing_inputs": readiness.get("missing_inputs"),
                },
            )
            return {**base, "status": "blocked_missing_inputs"}

        # The timer covers client construction as well as the execution itself.
        started = perf_counter()
        client = self._client_for_task(task)
        result = client.execute_task(task)
        status = result.get("status", "execution_failed")
        elapsed = round(perf_counter() - started, 3)

        # Anything other than a clean "executed" is surfaced at warning level.
        log_method = logger.info if status == "executed" else logger.warning
        log_method(
            "gpu_task_executed",
            extra={
                "task_id": task.get("task_id"),
                "task_type": task.get("task_type"),
                "model_id": task.get("model_id"),
                "backend": base["backend"],
                "status": status,
                "elapsed_seconds": elapsed,
            },
        )
        return {
            **base,
            "status": status,
            "output": result,
        }

    def run_batch(self, batch: dict, *, dry_run: bool = True) -> dict:
        """Run every task of a batch through :meth:`run` and summarise the outcome.

        Args:
            batch: Batch description. Its ``"tasks"`` entry is the list of task
                dicts; a missing or ``None`` entry is treated as an empty batch.
                The remaining batch-level fields are echoed into the summary.
            dry_run: Forwarded unchanged to :meth:`run` for every task.

        Returns:
            A summary dict with the echoed batch fields, the counters
            ``task_count``, ``ready_count``, ``blocked_count``,
            ``executed_count`` and ``failed_count``, an overall ``"status"``
            (``"dry_run_complete"`` for dry runs, otherwise the verdict of
            ``_batch_status``) and the per-task ``"results"`` list.
        """
        # ``or []`` also covers an explicit ``"tasks": None``, which the plain
        # ``.get`` default would pass through to the loop and crash on.
        tasks = batch.get("tasks") or []
        task_results = [self.run(task, dry_run=dry_run) for task in tasks]

        ready_count = sum(1 for result in task_results if result["readiness"]["ready"])
        executed_count = sum(1 for result in task_results if result["status"] == "executed")
        # A failure is a task that was ready to run but did not end in a success state.
        failed_count = sum(
            1
            for result in task_results
            if result["readiness"]["ready"] and result["status"] not in {"ready_dry_run", "executed"}
        )
        return {
            "batch_id": batch.get("batch_id"),
            "task_type": batch.get("task_type"),
            "provider": batch.get("provider"),
            "space_name": batch.get("space_name"),
            "backend": batch.get("backend"),
            "model_role": batch.get("model_role"),
            "model_id": batch.get("model_id"),
            "task_count": len(task_results),
            "ready_count": ready_count,
            "blocked_count": len(task_results) - ready_count,
            "executed_count": executed_count,
            "failed_count": failed_count,
            "status": "dry_run_complete" if dry_run else _batch_status(task_results),
            "results": task_results,
        }

    def _client_for_task(self, task: dict):
        """Build the model client for ``task``.

        The backend comes from the task, then the ``"gpu"`` config, then
        defaults to ``"transformers"``. ``"vllm"`` yields a ``VLLMClient``;
        every other backend value yields a ``TransformersClient``.
        """
        gpu = self._gpu_config()
        backend = str(task.get("backend") or gpu.get("backend", "transformers"))
        model_config = _model_config_for_task(self.config, task)
        model_id = task.get("model_id") or model_config.get("model_id")
        if backend == "vllm":
            # Endpoint precedence: model-specific, then the vLLM-wide setting,
            # then the generic GPU endpoint.
            endpoint = model_config.get("endpoint") or gpu.get("vllm_endpoint") or gpu.get("endpoint")
            return VLLMClient(endpoint=endpoint, model_id=model_id, api_key=gpu.get("api_key"))
        return TransformersClient(model_id=model_id, model_config=model_config)
def _task_readiness(task: dict) -> dict:
    """Report whether ``task`` carries the inputs it needs to run.

    Args:
        task: Task description. ``task_type``, ``image_path``, ``doc_id`` and
            ``page_nums`` are inspected.

    Returns:
        A dict with ``"ready"`` (true when nothing is missing),
        ``"missing_inputs"`` (names of missing inputs, in the order
        ``image_path``, ``doc_id``, ``page_nums``) and ``"image_path_exists"``
        (whether a non-empty ``image_path`` points at an existing filesystem
        entry).
    """
    image_task_types = {"ocr_page", "table_vlm_repair", "figure_description", "vlm_route_repair"}
    task_type = task.get("task_type")
    image_path = task.get("image_path")

    # Stat the path once so "missing_inputs" and "image_path_exists" are derived
    # from the same observation and cannot disagree.
    image_path_exists = bool(image_path) and Path(str(image_path)).exists()

    missing: list[str] = []
    if task_type in image_task_types:
        if image_path:
            # A path was supplied, so it has to resolve for every image task type.
            if not image_path_exists:
                missing.append("image_path")
        elif task_type != "vlm_route_repair":
            # "vlm_route_repair" is the only image task allowed to omit the path.
            missing.append("image_path")
    if not task.get("doc_id"):
        missing.append("doc_id")
    if not task.get("page_nums"):
        missing.append("page_nums")

    return {
        "ready": not missing,
        "missing_inputs": missing,
        "image_path_exists": image_path_exists,
    }
def _model_config_for_task(config: dict, task: dict) -> dict:
    """Resolve the model settings that apply to ``task``.

    A non-empty ``task["model_config"]`` dict takes precedence. Otherwise the
    entry for the task's ``model_role`` (``"vlm"`` when unset) is looked up
    under ``config["gpu"]["models"]``.

    Args:
        config: Configuration mapping holding the ``"gpu"`` section.
        task: Task description.

    Returns:
        A shallow copy of the selected settings, or ``{}`` when no usable
        mapping is found at any level of the lookup.
    """
    override = task.get("model_config")
    if isinstance(override, dict) and override:
        return dict(override)

    # Each level is type-checked so that a ``None`` or otherwise malformed
    # section (e.g. an empty YAML key) yields "no settings" rather than an
    # AttributeError.
    gpu = config.get("gpu") if isinstance(config, dict) else None
    models = gpu.get("models") if isinstance(gpu, dict) else None
    if not isinstance(models, dict):
        return {}

    role = task.get("model_role") or "vlm"
    role_config = models.get(role)
    return dict(role_config) if isinstance(role_config, dict) else {}
def _batch_status(task_results: list[dict]) -> str:
    """Collapse per-task results into one batch-level status string.

    Returns ``"execute_complete"`` when every task executed (including the
    empty batch), ``"execute_partial"`` when only some did,
    ``"blocked_missing_inputs"`` when none executed and no task was ready,
    and ``"execute_failed"`` in every other case.
    """

    def _was_executed(entry: dict) -> bool:
        return entry["status"] == "executed"

    # all() over an empty iterable is True, so an empty batch counts as complete.
    if all(map(_was_executed, task_results)):
        return "execute_complete"
    if any(map(_was_executed, task_results)):
        return "execute_partial"

    # Nothing executed: tell "nothing could run" apart from "something ran and failed".
    some_task_was_ready = any(entry["readiness"]["ready"] for entry in task_results)
    if some_task_was_ready:
        return "execute_failed"
    return "blocked_missing_inputs"