magboola's picture
deploy model-validator-zerogpu
6cc17e0 verified
"""Sandboxed `.pt` validator running as a Hugging Face ZeroGPU Space.
`.pt` files are Python pickles; loading an untrusted one inside our backend
process is a remote-code-execution risk. This Space runs a short-lived GPU
task that:
1. Downloads the weights from a signed URL supplied by the backend
2. Sanity-checks the file size
3. Loads the weights as an Ultralytics YOLO model
4. Runs a single inference on a synthetic blank image to confirm the model
   actually produces a (non-empty) prediction result
5. Returns {ok: true, class_names, metrics, warnings, elapsed_s} or
   {ok: false, error}
This follows the same pattern as the other Cadayn ZeroGPU spaces
(sam3-zerogpu, parakeet-zerogpu, qwen3vl-zerogpu, etc.) -- Gradio app with
an api_* function exposed as the Space's remote API.
"""
from __future__ import annotations
try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package may be missing; fall
    # back to a no-op stand-in so the app still runs locally (CPU / dev).
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """Minimal no-op replacement for the `spaces` module."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """Mimic `spaces.GPU` as a pass-through decorator.

            Supports both the bare form (`@spaces.GPU`) and the called form
            (`@spaces.GPU(duration=...)`). A non-callable positional argument
            (the old `duration` slot) and any keyword arguments are accepted
            and ignored.
            """

            def decorator(f):
                return f

            if callable(func):
                # Bare usage: the decorated function was passed directly.
                return func
            return decorator

    spaces = _spaces_stub()
import gc
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any
import gradio as gr
import httpx
import numpy as np
# Largest weights file we are willing to download and load: 200 MiB.
MAX_MODEL_BYTES = 200 << 20
# ZeroGPU time budget (seconds) requested for each load-and-probe call.
GPU_DURATION_S = 2 * 60
def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Download a .pt file to a temp location and return the local path.

    The response body is streamed to disk in 1 MiB chunks. The transfer is
    aborted as soon as more than ``MAX_MODEL_BYTES`` have been received, so an
    oversized (or never-ending) response cannot exhaust disk space.

    Args:
        source_url: URL to fetch; redirects are followed.
        timeout_s: httpx timeout for the request.

    Returns:
        Path of the downloaded file. The caller owns it and must delete it.

    Raises:
        httpx exceptions (e.g. ``httpx.HTTPError``) on transport failure or a
            non-2xx status.
        ValueError: if the body exceeds ``MAX_MODEL_BYTES``.

    On any failure the partially written temp file is removed before the
    exception propagates, because the caller never receives its path and so
    cannot clean it up.
    """
    fd, tmp_name = tempfile.mkstemp(suffix=".pt")
    tmp_path = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as f, httpx.stream(
            "GET", source_url, timeout=timeout_s, follow_redirects=True
        ) as response:
            response.raise_for_status()
            received = 0
            for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                received += len(chunk)
                # Enforce the cap while streaming instead of after the fact.
                if received > MAX_MODEL_BYTES:
                    raise ValueError(
                        f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit"
                    )
                f.write(chunk)
    except BaseException:
        # The file handle is closed by now (the `with` has exited).
        tmp_path.unlink(missing_ok=True)
        raise
    return tmp_path
def _class_mismatch_warnings(
    class_names: list[str], expected_classes: list[Any] | None
) -> list[str]:
    """Describe differences between the weights' classes and the declared ones.

    Declared entries are converted to ``str`` and stripped; blank entries are
    ignored. Returns an empty list when nothing usable was declared, so a list
    of blanks does not flag every real class as "extra".
    """
    expected_set = {str(c).strip() for c in expected_classes or ()}
    expected_set.discard("")
    if not expected_set:
        return []
    present = set(class_names)
    warnings: list[str] = []
    missing = sorted(expected_set - present)
    if missing:
        warnings.append(f"Missing declared classes: {missing}")
    extra = sorted(present - expected_set)
    if extra:
        warnings.append(f"Weights include extra classes not declared: {extra}")
    return warnings


@spaces.GPU(duration=GPU_DURATION_S)
def _load_and_probe(weights_path: str, expected_classes: list[str]) -> dict[str, Any]:
    """Load the YOLO model on the GPU worker and run one synthetic prediction.

    Args:
        weights_path: Local path of the downloaded ``.pt`` file.
        expected_classes: Class names the user declared; may be empty.

    Returns:
        ``{"ok": True, "class_names", "metrics", "warnings"}`` on success, or
        ``{"ok": False, "error"}`` if the probe prediction came back empty.
        Exceptions from loading or predicting propagate to the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    local_path = Path(weights_path)
    model = None
    results = None
    try:
        # NOTE(security): loading a .pt unpickles it, which can execute
        # arbitrary code. This is why it runs inside this isolated Space
        # rather than in the backend process.
        model = YOLO(str(local_path))
        raw_names = getattr(model, "names", {}) or {}
        # Accept both a dict keyed by class index (ordered by key) and a plain
        # sequence of names.
        class_names = (
            [raw_names[k] for k in sorted(raw_names.keys())]
            if isinstance(raw_names, dict)
            else list(raw_names)
        )

        synthetic_frame = np.zeros((640, 640, 3), dtype=np.uint8)
        results = model.predict(source=synthetic_frame, verbose=False)
        if not results:
            return {"ok": False, "error": "Model returned no results from synthetic input"}

        metrics = {
            "task": getattr(model, "task", "unknown"),
            "num_classes": len(class_names),
            "model_file_mb": round(local_path.stat().st_size / (1024 * 1024), 2),
        }
        return {
            "ok": True,
            "class_names": class_names,
            "metrics": metrics,
            "warnings": _class_mismatch_warnings(class_names, expected_classes),
        }
    finally:
        # Drop this frame's references first; otherwise gc.collect() and
        # empty_cache() cannot reclaim the model's GPU memory while the
        # function is still holding it.
        del model, results
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup only; never mask the probe's own outcome.
            pass
def api_validate_weights(
weights_url: str | None = None,
expected_class_names: list[str] | None = None,
) -> dict[str, Any]:
"""API endpoint for sandboxed .pt weights validation from the backend.
Args:
weights_url: Signed URL from which the .pt file can be downloaded.
expected_class_names: Optional class list the user declared on upload;
warnings are emitted for mismatches against the weights' classes.
Returns:
{ ok: bool, class_names?: list[str], metrics?: dict, warnings?: list[str],
error?: str, elapsed_s: float }
"""
if not weights_url:
return {"ok": False, "error": "weights_url is required"}
start = time.monotonic()
local_path: Path | None = None
try:
local_path = _download_weights(weights_url)
size_bytes = local_path.stat().st_size
if size_bytes == 0:
return {"ok": False, "error": "Downloaded weights file is empty"}
if size_bytes > MAX_MODEL_BYTES:
return {
"ok": False,
"error": f"Weights file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
}
try:
probe = _load_and_probe(str(local_path), expected_class_names or [])
except Exception as exc:
return {
"ok": False,
"error": f"Failed to load model: {exc}",
"traceback": traceback.format_exc(),
}
probe["elapsed_s"] = round(time.monotonic() - start, 3)
return probe
except httpx.HTTPError as exc:
return {"ok": False, "error": f"Download failed: {exc}"}
except ValueError as exc:
return {"ok": False, "error": str(exc)}
finally:
if local_path is not None:
try:
os.unlink(local_path)
except FileNotFoundError:
pass
# ─── Gradio wiring ────────────────────────────────────────────────────────────
with gr.Blocks(title="Cadayn Model Validator") as demo:
    gr.Markdown(
        """
# Cadayn Custom Model Validator
Validates an uploaded Ultralytics YOLO `.pt` file by loading it and
running a single synthetic-image prediction. Returns the declared class
names and basic metrics, or a structured error payload.
**API endpoint:** `/api/api_validate_weights`
"""
    )
    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_expected_classes = gr.JSON(label="Expected class names", value=[])
    api_validate_button = gr.Button("Validate", variant="primary")
    api_output = gr.JSON(label="Validation result")

    # Trigger validation only on an explicit click. A `.change` listener on the
    # textbox fires for every keystroke in the UI, which would start a
    # download plus a GPU job for each partial URL. fn, inputs, outputs and
    # api_name are unchanged, so the exposed API endpoint is the same.
    api_validate_button.click(
        fn=api_validate_weights,
        inputs=[api_weights_url, api_expected_classes],
        outputs=api_output,
        api_name="api_validate_weights",
    )

if __name__ == "__main__":
    demo.launch()