"""Sandboxed .pt validator HF ZeroGPU Space. `.pt` files are Python pickles; loading an untrusted one inside our backend process is a remote-code-execution risk. This Space runs a short-lived GPU task that: 1. Downloads the weights from a signed URL supplied by the backend 2. Sanity-checks the file size 3. Loads the weights as an Ultralytics YOLO model 4. Runs a single inference on a synthetic image to confirm the model produces a predictable structure 5. Returns {ok, class_names, metrics} or {ok: false, error} This follows the same pattern as the other Cadayn ZeroGPU spaces (sam3-zerogpu, parakeet-zerogpu, qwen3vl-zerogpu, etc.) -- Gradio app with an api_* function exposed as the Space's remote API. """ from __future__ import annotations try: import spaces ZEROGPU_AVAILABLE = True except ImportError: ZEROGPU_AVAILABLE = False class _spaces_stub: @staticmethod def GPU(duration=60): def decorator(func): return func return decorator spaces = _spaces_stub() import gc import os import tempfile import time import traceback from pathlib import Path from typing import Any import gradio as gr import httpx import numpy as np MAX_MODEL_BYTES = 200 * 1024 * 1024 # 200 MB GPU_DURATION_S = 120 def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path: """Download a .pt file to a temp location and return the local path.""" tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) tmp_path = Path(tmp.name) tmp.close() with httpx.stream("GET", source_url, timeout=timeout_s, follow_redirects=True) as response: response.raise_for_status() with tmp_path.open("wb") as f: for chunk in response.iter_bytes(chunk_size=1024 * 1024): f.write(chunk) if tmp_path.stat().st_size > MAX_MODEL_BYTES: raise ValueError(f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit") return tmp_path @spaces.GPU(duration=GPU_DURATION_S) def _load_and_probe(weights_path: str, expected_classes: list[str]) -> dict[str, Any]: """Load the YOLO model on the GPU worker and run one synthetic prediction.""" import torch # deferred so ZeroGPU handles CUDA import order from ultralytics import YOLO local_path = Path(weights_path) try: model = YOLO(str(local_path)) raw_names = getattr(model, "names", {}) or {} class_names = ( [raw_names[k] for k in sorted(raw_names.keys())] if isinstance(raw_names, dict) else list(raw_names) ) synthetic_frame = np.zeros((640, 640, 3), dtype=np.uint8) results = model.predict(source=synthetic_frame, verbose=False) if not results: return {"ok": False, "error": "Model returned no results from synthetic input"} metrics = { "task": getattr(model, "task", "unknown"), "num_classes": len(class_names), "model_file_mb": round(local_path.stat().st_size / (1024 * 1024), 2), } warnings: list[str] = [] if expected_classes: expected_set = {c.strip() for c in expected_classes if c.strip()} present = set(class_names) missing = sorted(expected_set - present) extra = sorted(present - expected_set) if missing: warnings.append(f"Missing declared classes: {missing}") if extra: warnings.append(f"Weights include extra classes not declared: {extra}") return { "ok": True, "class_names": class_names, "metrics": metrics, "warnings": warnings, } finally: gc.collect() try: torch.cuda.empty_cache() except Exception: pass def api_validate_weights( weights_url: str | None = None, expected_class_names: list[str] | None = None, ) -> dict[str, Any]: """API endpoint for sandboxed .pt weights validation from the backend. Args: weights_url: Signed URL from which the .pt file can be downloaded. expected_class_names: Optional class list the user declared on upload; warnings are emitted for mismatches against the weights' classes. Returns: { ok: bool, class_names?: list[str], metrics?: dict, warnings?: list[str], error?: str, elapsed_s: float } """ if not weights_url: return {"ok": False, "error": "weights_url is required"} start = time.monotonic() local_path: Path | None = None try: local_path = _download_weights(weights_url) size_bytes = local_path.stat().st_size if size_bytes == 0: return {"ok": False, "error": "Downloaded weights file is empty"} if size_bytes > MAX_MODEL_BYTES: return { "ok": False, "error": f"Weights file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit", } try: probe = _load_and_probe(str(local_path), expected_class_names or []) except Exception as exc: return { "ok": False, "error": f"Failed to load model: {exc}", "traceback": traceback.format_exc(), } probe["elapsed_s"] = round(time.monotonic() - start, 3) return probe except httpx.HTTPError as exc: return {"ok": False, "error": f"Download failed: {exc}"} except ValueError as exc: return {"ok": False, "error": str(exc)} finally: if local_path is not None: try: os.unlink(local_path) except FileNotFoundError: pass # ─── Gradio wiring ──────────────────────────────────────────────────────────── with gr.Blocks(title="Cadayn Model Validator") as demo: gr.Markdown( """ # Cadayn Custom Model Validator Validates an uploaded Ultralytics YOLO `.pt` file by loading it and running a single synthetic-image prediction. Returns the declared class names and basic metrics, or a structured error payload. **API endpoint:** `/api/api_validate_weights` """ ) with gr.Row(): api_weights_url = gr.Textbox(label="Signed .pt URL") api_expected_classes = gr.JSON(label="Expected class names", value=[]) api_output = gr.JSON(label="Validation result") api_weights_url.change( fn=api_validate_weights, inputs=[api_weights_url, api_expected_classes], outputs=api_output, api_name="api_validate_weights", ) if __name__ == "__main__": demo.launch()