File size: 10,568 Bytes
93bb3dd
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571166c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5c3754
93bb3dd
 
 
c5c3754
93bb3dd
c5c3754
 
 
 
93bb3dd
 
 
 
 
 
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1a1d98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5c3754
 
571166c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571166c
c5c3754
 
571166c
 
 
 
 
 
c5c3754
 
 
 
571166c
 
 
c5c3754
 
571166c
 
c5c3754
 
93bb3dd
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93bb3dd
 
 
 
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1163a34
c5c3754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93bb3dd
c5c3754
 
93bb3dd
 
 
c5c3754
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger.

The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle`
and downloaded on first launch via `hf_hub_download`. Preprocessing
(letterbox + normalize) matches the training pipeline.
"""
from __future__ import annotations

import json
import os
import time
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps

# The HF Spaces runtime ships the `spaces` package. When developing locally it
# is usually absent, so substitute a stand-in whose `GPU` decorator is a no-op;
# that keeps `@spaces.GPU` / `@spaces.GPU(duration=...)` usable everywhere.
try:
    import spaces  # type: ignore
    _IS_HF_SPACES = True
except ImportError:  # pragma: no cover - local dev path
    _IS_HF_SPACES = False

    class _SpacesShim:
        """Minimal stand-in for the `spaces` module exposing a pass-through `GPU`."""

        @staticmethod
        def GPU(*dargs, **dkwargs):  # noqa: N802 - matches HF API
            # Bare form: `@spaces.GPU` hands us the function itself.
            used_bare = bool(dargs) and callable(dargs[0]) and not dkwargs
            if used_bare:
                return dargs[0]
            # Factory form: `@spaces.GPU(...)` must return a decorator.
            return lambda fn: fn

    spaces = _SpacesShim()  # type: ignore

# Source repo on the Hub and the V1.1 export artifacts inside it.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

# Letterbox geometry and per-channel normalization constants. Shaped (1, 1, 3) so
# they broadcast over the channel axis of an HWC float image; with mean = std = 0.5
# a [0, 1] pixel maps to [-1, 1].
IMAGE_SIZE = 448
PAD_COLOR = (114, 114, 114)
_MEAN = np.full((1, 1, 3), 0.5, dtype=np.float32)
_STD = np.full((1, 1, 3), 0.5, dtype=np.float32)

# PR_BREAKEVEN: macro single-threshold precision/recall break-even on the val set
# (V1.1_onnx/pr_thresholds.json -> macro_single_threshold.support_ge_0.pr_breakeven).
# DEFAULT_THRESHOLD sits below that break-even (0.65 vs ~0.76) so the UI shows a
# reasonable number of tags by default; the slider lets users move it.
PR_BREAKEVEN = 0.7596
DEFAULT_THRESHOLD = 0.65


def _download_assets() -> tuple[Path, Path]:
    """Fetch the model artifacts from MODEL_REPO via `hf_hub_download`.

    The ONNX graph and the vocabulary are required; any download error for them
    propagates to the caller. The PR-threshold file is fetched best-effort only:
    nothing in this file reads it (PR_BREAKEVEN is hard-coded), so a failure is
    logged and otherwise ignored instead of being silently swallowed.

    Returns:
        (onnx_path, vocab_path) as local filesystem paths.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as exc:  # best-effort: optional asset, keep startup going
        print(f"Optional asset {PR_PATH_IN_REPO} could not be downloaded: {exc!r}")
    return onnx_path, vocab_path


def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Parse the exported vocabulary JSON.

    The file must hold an object with a "tag_to_index" mapping (tag name -> index).
    Indices are coerced with int(), so numeric strings are accepted too.

    Returns:
        (index_to_tag, pad_idx, unk_idx): the inverted mapping, plus the indices of
        "<PAD>" and "<UNK>" (falling back to 0 and 1 when those tags are absent).
    """
    with open(vocab_path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    mapping: dict[str, int] = payload["tag_to_index"]

    index_to_tag: dict[int, str] = {}
    for tag_name, raw_index in mapping.items():
        index_to_tag[int(raw_index)] = tag_name

    return index_to_tag, int(mapping.get("<PAD>", 0)), int(mapping.get("<UNK>", 1))


def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Build an ONNX Runtime session for *onnx_path*, degrading optimization on failure.

    Providers: CUDA first when this ORT build offers it, then CPU. Intra-op threads
    are capped at 8 (or os.cpu_count(), defaulting to 2 when unknown).

    Optimization levels are tried in order BASIC, then DISABLE_ALL. ORT_ENABLE_ALL is
    skipped on purpose: it synthesizes a MemcpyToHost node on the bool padding_mask
    path that some ORT/CPU builds reject with NOT_IMPLEMENTED. Each failure is printed;
    if every level fails, the last exception is re-raised.
    """
    available = set(ort.get_available_providers())
    providers = [
        name
        for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if name in available
    ]
    threads = min(os.cpu_count() or 2, 8)

    failure: Exception | None = None
    for opt_level in (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    ):
        options = ort.SessionOptions()
        options.graph_optimization_level = opt_level
        options.intra_op_num_threads = threads
        try:
            return ort.InferenceSession(str(onnx_path), sess_options=options, providers=providers)
        except Exception as exc:  # pragma: no cover - depends on ORT build
            failure = exc
            print(f"InferenceSession failed at opt level {opt_level}: {exc!r}")
    assert failure is not None
    raise failure


# Lazily-initialized inference state; all three are populated by _get_session().
_SESSION: ort.InferenceSession | None = None
_PRIMARY_INPUT: str | None = None  # name of the image input tensor
_HAS_MASK_INPUT: bool = False  # whether the graph accepts a "padding_mask" input


def _get_session() -> ort.InferenceSession:
    """Return the cached ORT session, building it on first use.

    On ZeroGPU this has to run inside a @spaces.GPU-decorated call. Only then is the
    GPU attached, so CUDAExecutionProvider can see it when the session is created.

    Side effects on first call: sets _PRIMARY_INPUT ("pixel_values" if present,
    otherwise the graph's first input) and _HAS_MASK_INPUT. It prints the active
    providers and input names, then publishes the session to _SESSION last.
    """
    global _SESSION, _PRIMARY_INPUT, _HAS_MASK_INPUT
    if _SESSION is None:
        built = _make_session(ONNX_PATH)
        graph_inputs = built.get_inputs()
        input_names = {node.name for node in graph_inputs}
        _PRIMARY_INPUT = "pixel_values" if "pixel_values" in input_names else graph_inputs[0].name
        _HAS_MASK_INPUT = "padding_mask" in input_names
        print("Providers:", built.get_providers())
        print("Inputs:", list(input_names))
        _SESSION = built
    return _SESSION


def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Letterbox and normalize *img* into model inputs.

    Steps:
      1. Apply the EXIF orientation.
      2. Flatten any transparency onto PAD_COLOR (otherwise convert to RGB).
      3. Downscale, never upscale, so the image fits inside IMAGE_SIZE x IMAGE_SIZE.
      4. Center it on a PAD_COLOR canvas.
      5. Scale to [0, 1], then normalize with _MEAN/_STD.

    Returns:
        pixels: float32 array of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
        mask: bool array of shape (1, IMAGE_SIZE, IMAGE_SIZE); True marks padding.
        was_composited: True when transparency was flattened onto PAD_COLOR.
            predict() uses this to drop a spurious "gray_background" tag.
    """
    img = ImageOps.exif_transpose(img)

    # Any mode with a straight alpha band (RGBA, LA, PA, ...) or a palette/color-key
    # transparency entry is composited. Previously only RGBA/LA were checked, so
    # palette+alpha ("PA") images went through convert("RGB") and lost their alpha.
    was_composited = "A" in img.getbands() or "transparency" in img.info
    if was_composited:
        bg = Image.new("RGB", img.size, PAD_COLOR)
        rgba = img.convert("RGBA")
        bg.paste(rgba.convert("RGB"), mask=rgba.getchannel("A"))
        img = bg
    else:
        img = img.convert("RGB")

    w, h = img.size
    # The 1.0 cap means small images keep their native size and are only padded.
    scale = min(IMAGE_SIZE / w, IMAGE_SIZE / h, 1.0)
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    if (new_w, new_h) != (w, h):
        img = img.resize((new_w, new_h), Image.BILINEAR)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    top = (IMAGE_SIZE - new_h) // 2
    left = (IMAGE_SIZE - new_w) // 2
    canvas[top : top + new_h, left : left + new_w] = np.asarray(img, dtype=np.uint8)

    # True = padding, False = real image content.
    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[top : top + new_h, left : left + new_w] = False

    x = canvas.astype(np.float32) / 255.0
    x = (x - _MEAN) / _STD
    x = x.transpose(2, 0, 1)  # HWC -> CHW
    return np.expand_dims(x, 0), np.expand_dims(mask, 0), was_composited


# Import-time setup: fetch the model artifacts from the Hub and load the vocabulary.
# The ORT session is deliberately NOT built here; _get_session() builds it lazily.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
# Index -> tag lookup, plus the special-token indices that predict() skips.
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)
# NOTE(review): _IS_HF_SPACES only records whether the `spaces` package imported;
# it is used here as a proxy for "running on ZeroGPU".
print("ZeroGPU mode:", _IS_HF_SPACES, "(session built lazily on first call)")


def _select_tags(
    probs: np.ndarray, threshold: float, cap: int, drop_gray_background: bool
) -> tuple[dict[str, float], str | None, float]:
    """Pick the highest-scoring tags at or above *threshold*, best first.

    <PAD>/<UNK> indices and indices missing from INDEX_TO_TAG are skipped. So is
    "gray_background" when *drop_gray_background* is set. Stops at the first score
    below *threshold* or once *cap* tags are collected.

    Returns:
        (tags, rating, rating_score): an ordered tag -> score dict; the suffix of the
        best selected "rating:*" tag (None if none was selected); and that tag's
        score (-1.0 if none).
    """
    selected: dict[str, float] = {}
    best_rating: str | None = None
    best_rating_score = -1.0
    skipped = (PAD_IDX, UNK_IDX)
    for idx in map(int, np.argsort(probs)[::-1]):
        if idx in skipped:
            continue
        score = float(probs[idx])
        if score < threshold:
            break  # scores are visited in descending order
        tag = INDEX_TO_TAG.get(idx)
        if tag is None or (drop_gray_background and tag == "gray_background"):
            continue
        if tag.startswith("rating:") and score > best_rating_score:
            best_rating = tag.split(":", 1)[1]
            best_rating_score = score
        selected[tag] = score
        if len(selected) >= cap:
            break
    return selected, best_rating, best_rating_score


# On ZeroGPU the GPU is attached only while a @spaces.GPU call is running. The ORT
# session is therefore created lazily inside the decorated function, so the CUDA
# provider initializes against a live device. The 120 s budget covers the one-time
# session build on the first call (~10 s for the ~1 GB graph) plus ~1 s of inference.
# Later calls reuse the cached session.
@spaces.GPU(duration=120)
def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag *image* and return Gradio-ready outputs.

    Args:
        image: PIL image from the UI, or None when the input is empty/cleared.
        threshold: minimum score for a tag to be returned.
        top_k: maximum number of tags (values below 1 are treated as 1).

    Returns:
        (label_scores, comma_separated_tags, info_markdown).
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""

    session = _get_session()
    input_name = _PRIMARY_INPUT or session.get_inputs()[0].name

    started = time.perf_counter()
    pixels, pad_mask, composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {input_name: pixels}
    if _HAS_MASK_INPUT:
        feed["padding_mask"] = pad_mask
    # Per the original author's note, the ONNX graph already applies the sigmoid.
    probs = session.run(None, feed)[0][0]
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    tags, rating, rating_score = _select_tags(probs, threshold, max(int(top_k), 1), composited)

    summary = ", ".join(tags) if tags else "No tags above threshold."

    details = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(tags)}"]
    if rating is not None:
        details.append(f"Predicted rating: {rating}  (score {rating_score:.3f})")
    return tags, summary, "  •  ".join(details)


# Markdown shown above and below the tagging UI.
_INTRO_MD = """# OppaiOracle V1.1 — anime tagger
Multi-label ViT fine-tuned at 448×448 (a fine-tune of the from-scratch V1 320×320
model, on the same cleaned ~5.9M-image corpus, 19,294 tags). Drop in an image to
see ranked tag predictions.

**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""

_FOOTER_MD = """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).

This Space runs the V1.1 (448×448) checkpoint. The V1 (320×320) checkpoint is also in the
model repo for users who specifically want the smaller native resolution — match input
resolution to the checkpoint you load.
"""

with gr.Blocks(title="OppaiOracle — anime tagger") as demo:
    gr.Markdown(_INTRO_MD)
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            threshold = gr.Slider(
                minimum=0.05,
                maximum=0.95,
                step=0.01,
                value=DEFAULT_THRESHOLD,
                label="Threshold",
                info=f"Tags below this score are dropped. Macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
            )
            top_k = gr.Slider(minimum=5, maximum=200, step=1, value=50, label="Max tags")
            run_btn = gr.Button("Tag", variant="primary")
        # Right column: outputs.
        with gr.Column(scale=1):
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # Run the tagger on an explicit click and on every image change.
    for trigger in (run_btn.click, image_in.change):
        trigger(
            predict,
            inputs=[image_in, threshold, top_k],
            outputs=[labels_out, tags_text, info_out],
        )

    gr.Markdown(_FOOTER_MD)


if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 2, then serve the app.
    demo.queue(default_concurrency_limit=2)
    demo.launch()