# zeroshotGPU / zsgdp/gpu/transformers_client.py
# Author: Arjunvir Singh
# Last commit: Lower @spaces.GPU duration from 180/120 to 60 (ZeroGPU max-duration cap)
# Commit: de03f34
"""Transformers backend client.
ZeroGPU note (mirrors zsgdp/benchmarks/embedding_retriever.py): the
@spaces.GPU decorator runs the wrapped function in a separate worker
process. Mutations to `self` inside the worker — including
cached_property results — do NOT propagate back to the caller. The
worker exits at the end of each call. So we keep `execute_task` in the
main process and offload the GPU-bound pipeline load + inference to the
stateless module-level helper `_gpu_run_pipeline(...)`. Only picklable values
cross the boundary (strings, dicts of strings/numbers); the helper
returns the extracted text string, not the raw pipeline output.
"""
from __future__ import annotations
from typing import Any
from zsgdp.gpu.worker_prompts import prompt_for_task
from zsgdp.gpu.zero_gpu import gpu as zero_gpu_slot
@zero_gpu_slot(duration=60)
def _gpu_run_pipeline(
    model_id: str,
    pipeline_task: str,
    dtype: str | None,
    device: str | int | None,
    prompt: str,
    image_path: str | None,
) -> str:
    """Load a transformers pipeline and run one inference under a ZeroGPU slot.

    Stateless by design: takes only picklable inputs and returns a single
    string (the extracted output text), so both arguments and result can
    cross the worker-process boundary described in the module docstring.
    Nothing is cached here: every call re-loads the model, on any hardware.
    That is the accepted cost of bursty ZeroGPU usage.

    Args:
        model_id: Model identifier passed as ``pipeline(model=...)``.
        pipeline_task: Pipeline task name, e.g. ``"image-text-to-text"``.
        dtype: Forwarded as ``torch_dtype`` when truthy; otherwise omitted.
        device: ``"auto"`` is translated to ``device_map="auto"``. Any other
            value except ``None``/``""`` (including the integer index ``0``)
            is forwarded as ``device``. ``None``/``""`` leave placement to
            transformers' defaults.
        prompt: Text prompt to run.
        image_path: Optional image path. When set, image and prompt are sent
            together as one ``{"image": ..., "text": ...}`` input.

    Returns:
        The output text as reduced by ``_extract_text``.
    """
    # Imported inside the function so this module stays importable when
    # transformers is not installed.
    from transformers import pipeline  # type: ignore

    kwargs: dict[str, Any] = {"model": model_id}
    if dtype:
        kwargs["torch_dtype"] = dtype
    if device == "auto":
        kwargs["device_map"] = "auto"
    elif device is not None and device != "":
        # Explicit None/empty check instead of truthiness: device index 0
        # (the first GPU) is falsy and must still be forwarded.
        kwargs["device"] = device
    pipe = pipeline(pipeline_task, **kwargs)
    if image_path:
        output = pipe({"image": image_path, "text": prompt})
    else:
        output = pipe(prompt)
    return _extract_text(output)
class TransformersClient:
def __init__(self, model_id: str | None, model_config: dict[str, Any] | None = None):
self.model_id = model_id
self.model_config = model_config or {}
@property
def available(self) -> bool:
if not self.model_id:
return False
try:
import transformers # noqa: F401
except Exception:
return False
return True
def execute_task(self, task: dict[str, Any]) -> dict[str, Any]:
# NOT decorated with @zero_gpu_slot — see module docstring. The GPU
# work is offloaded to the stateless _gpu_run_pipeline helper.
if not self.available:
return {
"status": "backend_unavailable",
"error": "Transformers is not installed or model_id is missing.",
}
prompt = prompt_for_task(task)
image_path = task.get("image_path")
pipeline_task = str(self.model_config.get("task", "image-text-to-text"))
try:
text = _gpu_run_pipeline(
model_id=str(self.model_id),
pipeline_task=pipeline_task,
dtype=self.model_config.get("dtype"),
device=self.model_config.get("device"),
prompt=prompt,
image_path=str(image_path) if image_path else None,
)
except Exception as exc:
return {"status": "execution_failed", "error": str(exc)}
return {
"status": "executed",
"text": text,
}
def _extract_text(output: Any) -> str:
if isinstance(output, str):
return output
if isinstance(output, list) and output:
return _extract_text(output[0])
if isinstance(output, dict):
for key in ("generated_text", "text", "summary_text", "answer"):
if output.get(key):
return str(output[key])
return str(output)