Spaces:
Sleeping
Sleeping
File size: 6,840 Bytes
"""Sandboxed .pt validator — HF ZeroGPU Space.
`.pt` files are Python pickles; loading an untrusted one inside our backend
process is a remote-code-execution risk. This Space runs a short-lived GPU
task that:
1. Downloads the weights from a signed URL supplied by the backend
2. Sanity-checks the file size
3. Loads the weights as an Ultralytics YOLO model
4. Runs a single inference on a blank synthetic image to confirm the model
loads and returns results
5. Returns {ok: true, class_names, metrics, warnings, elapsed_s} or
{ok: false, error}
This follows the same pattern as the other Cadayn ZeroGPU spaces
(sam3-zerogpu, parakeet-zerogpu, qwen3vl-zerogpu, etc.) -- Gradio app with
an api_* function exposed as the Space's remote API.
"""
from __future__ import annotations
try:
    import spaces
except ImportError:
    # Local/dev fallback: the ZeroGPU ``spaces`` package is not installed, so
    # provide an object whose ``GPU(duration=...)`` returns a no-op decorator.
    class _spaces_stub:
        """Stand-in for ``spaces`` exposing a pass-through ``GPU`` factory."""

        @staticmethod
        def GPU(duration=60):
            # The requested duration is ignored; functions are returned as-is.
            return lambda func: func

    spaces = _spaces_stub()
    ZEROGPU_AVAILABLE = False
else:
    ZEROGPU_AVAILABLE = True
import gc
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any
import gradio as gr
import httpx
import numpy as np
# Upper bound on accepted weight files: 200 MiB (200 * 2**20 bytes).
MAX_MODEL_BYTES = 200 << 20
# Duration passed to ``spaces.GPU`` for each load-and-probe call (seconds).
GPU_DURATION_S = 120
def _download_weights(
    source_url: str,
    timeout_s: float = 120.0,
    max_bytes: int | None = None,
) -> Path:
    """Stream a .pt file from *source_url* into a temp file and return its path.

    The size cap is enforced while streaming, so an oversized upload is
    rejected as soon as it crosses the limit instead of first being written
    to disk in full.

    Args:
        source_url: URL to GET (redirects are followed).
        timeout_s: httpx timeout for the request, in seconds.
        max_bytes: Size cap in bytes; ``None`` means ``MAX_MODEL_BYTES``.

    Returns:
        Path of the downloaded file. The caller owns it and must delete it.

    Raises:
        httpx.HTTPError: transport failure or non-2xx response.
        ValueError: the payload exceeds the size cap.
    """
    limit = MAX_MODEL_BYTES if max_bytes is None else max_bytes
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    tmp.close()
    try:
        with httpx.stream("GET", source_url, timeout=timeout_s, follow_redirects=True) as response:
            response.raise_for_status()
            written = 0
            with tmp_path.open("wb") as f:
                for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                    written += len(chunk)
                    # Abort before writing past the cap rather than after the
                    # whole body has already landed on disk.
                    if written > limit:
                        raise ValueError(f"Weights exceed {limit // (1024 * 1024)}MB limit")
                    f.write(chunk)
    except BaseException:
        # When we raise, the caller never receives tmp_path and so cannot
        # delete it; remove the partial/oversized file here, then re-raise.
        tmp_path.unlink(missing_ok=True)
        raise
    return tmp_path
@spaces.GPU(duration=GPU_DURATION_S)
def _load_and_probe(weights_path: str, expected_classes: list[str]) -> dict[str, Any]:
    """Load the YOLO model on the GPU worker and run one synthetic prediction.

    NOTE(security): ``YOLO(path)`` deserializes the ``.pt`` pickle, which can
    execute code embedded in the upload. That is why this runs inside the
    isolated Space and never in the backend process.

    Args:
        weights_path: Local path of the downloaded ``.pt`` file.
        expected_classes: Class names declared by the user. Entries are
            stringified and stripped; blank ones are ignored. If nothing
            remains, the class comparison is skipped.

    Returns:
        ``{"ok": True, "class_names", "metrics", "warnings"}`` on success, or
        ``{"ok": False, "error"}`` when the prediction yields no results.
        Exceptions from loading or inference propagate to the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    local_path = Path(weights_path)
    # Pre-bind so the ``finally`` clause can always release them.
    model = None
    results = None
    try:
        model = YOLO(str(local_path))
        raw_names = getattr(model, "names", {}) or {}
        # A dict of names is ordered by its keys (list position follows key
        # order); any other iterable is taken as-is.
        if isinstance(raw_names, dict):
            class_names = [raw_names[k] for k in sorted(raw_names)]
        else:
            class_names = list(raw_names)

        synthetic_frame = np.zeros((640, 640, 3), dtype=np.uint8)
        results = model.predict(source=synthetic_frame, verbose=False)
        if not results:
            return {"ok": False, "error": "Model returned no results from synthetic input"}

        metrics = {
            "task": getattr(model, "task", "unknown"),
            "num_classes": len(class_names),
            "model_file_mb": round(local_path.stat().st_size / (1024 * 1024), 2),
        }

        warnings: list[str] = []
        # Drop blank entries; an all-blank declaration would otherwise flag
        # every class in the weights as "extra".
        expected_set = {str(c).strip() for c in (expected_classes or [])} - {""}
        if expected_set:
            present = set(class_names)
            missing = sorted(expected_set - present)
            extra = sorted(present - expected_set)
            if missing:
                warnings.append(f"Missing declared classes: {missing}")
            if extra:
                warnings.append(f"Weights include extra classes not declared: {extra}")

        return {
            "ok": True,
            "class_names": class_names,
            "metrics": metrics,
            "warnings": warnings,
        }
    finally:
        # Release our references first: gc.collect() cannot reclaim the model
        # or its prediction outputs while these locals still point at them.
        del model, results
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass  # best effort only; cleanup must never mask the real outcome
def api_validate_weights(
weights_url: str | None = None,
expected_class_names: list[str] | None = None,
) -> dict[str, Any]:
"""API endpoint for sandboxed .pt weights validation from the backend.
Args:
weights_url: Signed URL from which the .pt file can be downloaded.
expected_class_names: Optional class list the user declared on upload;
warnings are emitted for mismatches against the weights' classes.
Returns:
{ ok: bool, class_names?: list[str], metrics?: dict, warnings?: list[str],
error?: str, elapsed_s: float }
"""
if not weights_url:
return {"ok": False, "error": "weights_url is required"}
start = time.monotonic()
local_path: Path | None = None
try:
local_path = _download_weights(weights_url)
size_bytes = local_path.stat().st_size
if size_bytes == 0:
return {"ok": False, "error": "Downloaded weights file is empty"}
if size_bytes > MAX_MODEL_BYTES:
return {
"ok": False,
"error": f"Weights file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
}
try:
probe = _load_and_probe(str(local_path), expected_class_names or [])
except Exception as exc:
return {
"ok": False,
"error": f"Failed to load model: {exc}",
"traceback": traceback.format_exc(),
}
probe["elapsed_s"] = round(time.monotonic() - start, 3)
return probe
except httpx.HTTPError as exc:
return {"ok": False, "error": f"Download failed: {exc}"}
except ValueError as exc:
return {"ok": False, "error": str(exc)}
finally:
if local_path is not None:
try:
os.unlink(local_path)
except FileNotFoundError:
pass
# ─── Gradio wiring ────────────────────────────────────────────────────────────
with gr.Blocks(title="Cadayn Model Validator") as demo:
    gr.Markdown(
        """
# Cadayn Custom Model Validator
Validates an uploaded Ultralytics YOLO `.pt` file by loading it and
running a single synthetic-image prediction. Returns the declared class
names and basic metrics, or a structured error payload.
**API endpoint:** `/api/api_validate_weights`
"""
    )
    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_expected_classes = gr.JSON(label="Expected class names", value=[])
    api_validate_button = gr.Button("Validate")
    api_output = gr.JSON(label="Validation result")
    # Run only on an explicit click. Textbox.change fires on every edit, which
    # would start a download + GPU validation for each keystroke. The
    # api_name and inputs are unchanged, so API callers are unaffected.
    api_validate_button.click(
        fn=api_validate_weights,
        inputs=[api_weights_url, api_expected_classes],
        outputs=api_output,
        api_name="api_validate_weights",
    )

if __name__ == "__main__":
    demo.launch()
|