Spaces:
Sleeping
Sleeping
| """Sandboxed .pt validator HF ZeroGPU Space. | |
| `.pt` files are Python pickles; loading an untrusted one inside our backend | |
| process is a remote-code-execution risk. This Space runs a short-lived GPU | |
| task that: | |
| 1. Downloads the weights from a signed URL supplied by the backend | |
| 2. Sanity-checks the file size | |
| 3. Loads the weights as an Ultralytics YOLO model | |
| 4. Runs a single inference on a synthetic image to confirm the model | |
| produces a predictable structure | |
| 5. Returns {ok, class_names, metrics} or {ok: false, error} | |
| This follows the same pattern as the other Cadayn ZeroGPU spaces | |
| (sam3-zerogpu, parakeet-zerogpu, qwen3vl-zerogpu, etc.) -- Gradio app with | |
| an api_* function exposed as the Space's remote API. | |
| """ | |
| from __future__ import annotations | |
try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """No-op stand-in used when the `spaces` package is not installed."""

        @staticmethod
        def GPU(func=None, duration=60, **_kwargs):
            """Pass-through replacement for the `spaces.GPU` decorator.

            Works both bare (`@spaces.GPU`) and parameterised
            (`@spaces.GPU(duration=120)`); in either form the decorated
            function is returned unchanged.

            It has to be a staticmethod: the stub is used through an
            instance, and a plain function here would receive that instance
            as its first argument and fail with a TypeError.
            """

            def decorator(fn):
                return fn

            # Bare usage hands us the function itself; anything else (None,
            # or a positional duration) means "return the real decorator".
            if callable(func):
                return func
            return decorator

    spaces = _spaces_stub()
| import gc | |
| import os | |
| import tempfile | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| import httpx | |
| import numpy as np | |
# Largest weights file the validator will accept: 200 MB, expressed in bytes.
MAX_MODEL_BYTES = 200 * 1024**2
# NOTE(review): presumably the ZeroGPU time budget in seconds; nothing in the
# visible code references it -- confirm where it is meant to be applied.
GPU_DURATION_S = 120
def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Download a .pt file to a temp location and return the local path.

    The response is streamed to disk in 1 MiB chunks and the running byte
    count is checked as it grows, so an oversized payload is rejected without
    first being written out in full. The temp file is removed again if
    anything goes wrong, because the caller only learns the path on success
    and therefore cannot clean it up itself.

    Args:
        source_url: URL to fetch the weights from (redirects are followed).
        timeout_s: Timeout, in seconds, handed to httpx for the request.

    Returns:
        Path of the downloaded temp file. The caller owns it and is
        responsible for deleting it.

    Raises:
        httpx.HTTPError: On transport failures or an error response status.
        ValueError: If the payload is larger than MAX_MODEL_BYTES.
    """
    size_error = f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit"
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    try:
        # `with tmp` only closes the handle (delete=False keeps the file).
        with tmp, httpx.stream(
            "GET", source_url, timeout=timeout_s, follow_redirects=True
        ) as response:
            response.raise_for_status()

            # Cheap early exit when the server announces an oversized body.
            declared = response.headers.get("content-length", "")
            if declared.isdigit() and int(declared) > MAX_MODEL_BYTES:
                raise ValueError(size_error)

            # The header may be absent or wrong, so also count what arrives.
            written = 0
            for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                written += len(chunk)
                if written > MAX_MODEL_BYTES:
                    raise ValueError(size_error)
                tmp.write(chunk)
    except BaseException:
        # Do not leave a partial or oversized file behind on any failure.
        tmp_path.unlink(missing_ok=True)
        raise
    return tmp_path
def _load_and_probe(weights_path: str, expected_classes: list[str]) -> dict[str, Any]:
    """Load the YOLO model and run one synthetic prediction.

    SECURITY: `YOLO(path)` unpickles the weights file, which can execute
    arbitrary code embedded in it. That is the reason this check runs in a
    sandboxed Space and not in the backend process; never call this on
    untrusted files anywhere else.

    NOTE(review): no `@spaces.GPU` decorator is applied in the visible code
    and `GPU_DURATION_S` is unused -- confirm whether this function is meant
    to be wrapped as the ZeroGPU task.

    Args:
        weights_path: Filesystem path of the downloaded .pt file.
        expected_classes: Class names declared by the uploader. An empty
            list skips the comparison.

    Returns:
        On success: {"ok": True, "class_names": [...], "metrics": {...},
        "warnings": [...]}. If the prediction yields no results:
        {"ok": False, "error": "..."}.

    Raises:
        Exception: Whatever ultralytics/torch raise while loading or
            predicting is propagated to the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    local_path = Path(weights_path)
    # Bound up front so the cleanup below can always clear them.
    model = None
    results = None
    try:
        model = YOLO(str(local_path))

        # `names` is normally an {index: name} mapping; order it by index.
        raw_names = getattr(model, "names", {}) or {}
        if isinstance(raw_names, dict):
            class_names = [raw_names[k] for k in sorted(raw_names.keys())]
        else:
            class_names = list(raw_names)

        # An all-black 640x640 image is enough to exercise the forward pass.
        synthetic_frame = np.zeros((640, 640, 3), dtype=np.uint8)
        results = model.predict(source=synthetic_frame, verbose=False)
        if not results:
            return {"ok": False, "error": "Model returned no results from synthetic input"}

        metrics = {
            "task": getattr(model, "task", "unknown"),
            "num_classes": len(class_names),
            "model_file_mb": round(local_path.stat().st_size / (1024 * 1024), 2),
        }

        # Class mismatches are reported as warnings, not as failures.
        warnings: list[str] = []
        if expected_classes:
            expected_set = {c.strip() for c in expected_classes if c.strip()}
            present = set(class_names)
            missing = sorted(expected_set - present)
            extra = sorted(present - expected_set)
            if missing:
                warnings.append(f"Missing declared classes: {missing}")
            if extra:
                warnings.append(f"Weights include extra classes not declared: {extra}")

        return {
            "ok": True,
            "class_names": class_names,
            "metrics": metrics,
            "warnings": warnings,
        }
    finally:
        # Drop our references first; while `model`/`results` are still alive
        # neither gc.collect() nor empty_cache() can reclaim their memory.
        model = None
        results = None
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup: must never mask the real result or error.
            pass
def api_validate_weights(
    weights_url: str | None = None,
    expected_class_names: list[str] | None = None,
) -> dict[str, Any]:
    """API endpoint for sandboxed .pt weights validation from the backend.

    Downloads the weights, checks the file size, loads and probes the model,
    and always deletes the downloaded file afterwards. Failures are reported
    in the returned payload instead of being raised.

    Args:
        weights_url: Signed URL from which the .pt file can be downloaded.
        expected_class_names: Optional class list the user declared on upload;
            warnings are emitted for mismatches against the weights' classes.

    Returns:
        { ok: bool, class_names?: list[str], metrics?: dict, warnings?: list[str],
          error?: str, traceback?: str, elapsed_s: float }
        `elapsed_s` is present on every response, including errors.
    """
    start = time.monotonic()

    def _finish(payload: dict[str, Any]) -> dict[str, Any]:
        # Single exit point so every response carries the documented timing.
        payload["elapsed_s"] = round(time.monotonic() - start, 3)
        return payload

    if not weights_url:
        return _finish({"ok": False, "error": "weights_url is required"})

    local_path: Path | None = None
    try:
        local_path = _download_weights(weights_url)

        size_bytes = local_path.stat().st_size
        if size_bytes == 0:
            return _finish({"ok": False, "error": "Downloaded weights file is empty"})
        if size_bytes > MAX_MODEL_BYTES:
            return _finish(
                {
                    "ok": False,
                    "error": f"Weights file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
                }
            )

        try:
            probe = _load_and_probe(str(local_path), expected_class_names or [])
        except Exception as exc:
            # Broad on purpose: an arbitrary uploaded file can fail in
            # arbitrary ways inside the loader, and all of them mean "invalid".
            # NOTE(review): the traceback is returned to the caller -- confirm
            # that exposing it is intended.
            return _finish(
                {
                    "ok": False,
                    "error": f"Failed to load model: {exc}",
                    "traceback": traceback.format_exc(),
                }
            )
        return _finish(probe)
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # InvalidURL is not an HTTPError subclass, so it is listed explicitly.
        return _finish({"ok": False, "error": f"Download failed: {exc}"})
    except ValueError as exc:
        # Raised by the download step when the size limit is exceeded.
        return _finish({"ok": False, "error": str(exc)})
    except OSError as exc:
        # Local filesystem problems (e.g. temp dir full or unwritable).
        return _finish({"ok": False, "error": f"Could not store weights: {exc}"})
    finally:
        if local_path is not None:
            try:
                os.unlink(local_path)
            except FileNotFoundError:
                pass
# ─── Gradio wiring ────────────────────────────────────────────────────────────
# The UI is built at import time; `demo` is the app object launched at the
# bottom of the file.
with gr.Blocks(title="Cadayn Model Validator") as demo:
    # NOTE(review): the original whitespace inside this string was lost in
    # extraction; the indentation used here is an assumption -- confirm it
    # renders as intended.
    gr.Markdown(
        """
        # Cadayn Custom Model Validator
        Validates an uploaded Ultralytics YOLO `.pt` file by loading it and
        running a single synthetic-image prediction. Returns the declared class
        names and basic metrics, or a structured error payload.
        **API endpoint:** `/api/api_validate_weights`
        """
    )
    with gr.Row():
        # Inputs map positionally onto api_validate_weights(weights_url,
        # expected_class_names).
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_expected_classes = gr.JSON(label="Expected class names", value=[])
    api_output = gr.JSON(label="Validation result")
    # Validation is bound to the textbox's change event; `api_name` publishes
    # the same handler under a named API endpoint for remote callers.
    api_weights_url.change(
        fn=api_validate_weights,
        inputs=[api_weights_url, api_expected_classes],
        outputs=api_output,
        api_name="api_validate_weights",
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()