# File size: 6,840 Bytes
# 6cc17e0
"""Sandboxed .pt validator HF ZeroGPU Space.

`.pt` files are Python pickles; loading an untrusted one inside our backend
process is a remote-code-execution risk. This Space runs a short-lived GPU
task that:

  1. Downloads the weights from a signed URL supplied by the backend
  2. Sanity-checks the file size
  3. Loads the weights as an Ultralytics YOLO model
  4. Runs a single inference on a synthetic image to confirm the model
     produces a predictable structure
  5. Returns {ok, class_names, metrics, warnings} or {ok: false, error}

This follows the same pattern as the other Cadayn ZeroGPU spaces
(sam3-zerogpu, parakeet-zerogpu, qwen3vl-zerogpu, etc.) -- Gradio app with
an api_* function exposed as the Space's remote API.
"""

from __future__ import annotations

try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    # The `spaces` package is not importable here; fall back to a no-op
    # stand-in so the module still imports and the GPU-decorated function
    # runs as a plain function.
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """No-op stand-in exposing the subset of `spaces` this module uses."""

        @staticmethod
        def GPU(func=None, *, duration=60, **_kwargs):
            """Return the decorated function unchanged.

            Supports both decorator spellings so code written against the
            real package keeps working with the stub:

              * ``@spaces.GPU`` (bare) -- ``func`` is the decorated callable
              * ``@spaces.GPU(duration=...)`` -- returns a pass-through
                decorator; ``duration`` and any other options are ignored
            """
            if callable(func):
                # Bare usage: we were handed the function itself.
                return func

            def decorator(inner):
                return inner

            return decorator

    spaces = _spaces_stub()

import gc
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any

import gradio as gr
import httpx
import numpy as np

# Upper bound on the size of a downloaded weights file. Enforced while
# streaming in _download_weights and re-checked in api_validate_weights.
MAX_MODEL_BYTES = 200 * 1024 * 1024  # 200 MB
# Passed as `duration` to the spaces.GPU decorator on _load_and_probe
# (presumably seconds, per the `_S` suffix).
GPU_DURATION_S = 120


def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Stream a .pt file to a temporary location and return its local path.

    The caller owns the returned file and is responsible for deleting it.
    If the download fails for any reason the temporary file is removed
    before the exception propagates, so no partial file is left behind.

    Args:
        source_url: URL to fetch the weights from (redirects are followed).
        timeout_s: Timeout handed to httpx for the request.

    Returns:
        Path of the downloaded temporary ``.pt`` file.

    Raises:
        httpx.HTTPError: On transport errors or a non-success status code.
        ValueError: If the body is larger than ``MAX_MODEL_BYTES``.
    """
    # Reserve a unique path, then reopen it below for the streamed write.
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    tmp.close()

    bytes_seen = 0
    try:
        with httpx.stream("GET", source_url, timeout=timeout_s, follow_redirects=True) as response:
            response.raise_for_status()
            with tmp_path.open("wb") as f:
                for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                    # Count bytes ourselves rather than stat()-ing the open
                    # file: the on-disk size can lag behind buffered writes,
                    # and checking first avoids writing past the limit.
                    bytes_seen += len(chunk)
                    if bytes_seen > MAX_MODEL_BYTES:
                        raise ValueError(f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit")
                    f.write(chunk)
    except BaseException:
        # The caller never learns the path on failure, so clean up here.
        tmp_path.unlink(missing_ok=True)
        raise
    return tmp_path


@spaces.GPU(duration=GPU_DURATION_S)
def _load_and_probe(weights_path: str, expected_classes: list[str]) -> dict[str, Any]:
    """Load the YOLO model on the GPU worker and run one synthetic prediction.

    Args:
        weights_path: Local filesystem path of the downloaded ``.pt`` file.
        expected_classes: Class names declared by the uploader. An empty
            list, or one containing only blank strings, skips the comparison.

    Returns:
        ``{"ok": True, "class_names", "metrics", "warnings"}`` on success, or
        ``{"ok": False, "error"}`` when the synthetic prediction yields no
        results. Exceptions raised while loading or predicting propagate to
        the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    local_path = Path(weights_path)
    # Pre-bind so the cleanup in `finally` can always drop these references.
    model = None
    results = None

    try:
        # NOTE(security): `.pt` files are pickles, so this load can execute
        # code embedded in the file. It is only acceptable because this
        # function runs in the sandboxed Space, not in the backend process.
        model = YOLO(str(local_path))

        # `names` may be an index->name mapping or a plain sequence.
        raw_names = getattr(model, "names", {}) or {}
        class_names = (
            [raw_names[k] for k in sorted(raw_names.keys())] if isinstance(raw_names, dict) else list(raw_names)
        )

        # One all-black 640x640 frame is enough to exercise the predict path.
        synthetic_frame = np.zeros((640, 640, 3), dtype=np.uint8)
        results = model.predict(source=synthetic_frame, verbose=False)
        if not results:
            return {"ok": False, "error": "Model returned no results from synthetic input"}

        metrics = {
            "task": getattr(model, "task", "unknown"),
            "num_classes": len(class_names),
            "model_file_mb": round(local_path.stat().st_size / (1024 * 1024), 2),
        }

        warnings: list[str] = []
        expected_set = {c.strip() for c in expected_classes if c.strip()}
        # Compare only when something was actually declared; a list of blank
        # strings must not flag every class in the weights as "extra".
        if expected_set:
            present = set(class_names)
            missing = sorted(expected_set - present)
            extra = sorted(present - expected_set)
            if missing:
                warnings.append(f"Missing declared classes: {missing}")
            if extra:
                warnings.append(f"Weights include extra classes not declared: {extra}")

        return {
            "ok": True,
            "class_names": class_names,
            "metrics": metrics,
            "warnings": warnings,
        }
    finally:
        # Release our references first; while these locals still point at the
        # model and its outputs, neither gc nor empty_cache can reclaim them.
        model = None
        results = None
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup: never let a cache-flush failure mask the
            # real result or the real exception.
            pass


def api_validate_weights(
    weights_url: str | None = None,
    expected_class_names: list[str] | None = None,
) -> dict[str, Any]:
    """API endpoint for sandboxed .pt weights validation from the backend.

    Args:
        weights_url: Signed URL from which the .pt file can be downloaded.
        expected_class_names: Optional class list the user declared on upload;
            warnings are emitted for mismatches against the weights' classes.

    Returns:
        { ok: bool, class_names?: list[str], metrics?: dict, warnings?: list[str],
          error?: str, elapsed_s: float }
    """
    if not weights_url:
        return {"ok": False, "error": "weights_url is required"}

    start = time.monotonic()
    local_path: Path | None = None

    try:
        local_path = _download_weights(weights_url)
        size_bytes = local_path.stat().st_size
        if size_bytes == 0:
            return {"ok": False, "error": "Downloaded weights file is empty"}
        if size_bytes > MAX_MODEL_BYTES:
            return {
                "ok": False,
                "error": f"Weights file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
            }

        try:
            probe = _load_and_probe(str(local_path), expected_class_names or [])
        except Exception as exc:
            return {
                "ok": False,
                "error": f"Failed to load model: {exc}",
                "traceback": traceback.format_exc(),
            }

        probe["elapsed_s"] = round(time.monotonic() - start, 3)
        return probe

    except httpx.HTTPError as exc:
        return {"ok": False, "error": f"Download failed: {exc}"}
    except ValueError as exc:
        return {"ok": False, "error": str(exc)}
    finally:
        if local_path is not None:
            try:
                os.unlink(local_path)
            except FileNotFoundError:
                pass


# ─── Gradio wiring ────────────────────────────────────────────────────────────

# Build the Gradio app. `demo` is the object launched at the bottom of the
# file; the event binding below registers api_validate_weights under an
# explicit api_name.
with gr.Blocks(title="Cadayn Model Validator") as demo:
    gr.Markdown(
        """
        # Cadayn Custom Model Validator

        Validates an uploaded Ultralytics YOLO `.pt` file by loading it and
        running a single synthetic-image prediction. Returns the declared class
        names and basic metrics, or a structured error payload.

        **API endpoint:** `/api/api_validate_weights`
        """
    )

    # Two inputs (URL and expected class list) and one JSON output, in one row.
    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_expected_classes = gr.JSON(label="Expected class names", value=[])
        api_output = gr.JSON(label="Validation result")

    # NOTE(review): the handler is bound to the textbox's change event, so in
    # the interactive UI it presumably runs on every edit of the URL field --
    # confirm this is intended rather than an explicit submit/click trigger.
    api_weights_url.change(
        fn=api_validate_weights,
        inputs=[api_weights_url, api_expected_classes],
        outputs=api_output,
        api_name="api_validate_weights",
    )


# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()