File size: 8,767 Bytes
a1baa27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger, CPU build.

The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle`
(`V1.1_onnx/model.onnx`) and downloaded on first launch via `hf_hub_download`.
Preprocessing (letterbox + normalize) matches the training pipeline. This
Space targets CPU hardware — inference of the 448x448 ViT on the standard
HF CPU tier takes ~10-30s per image.
"""
from __future__ import annotations

import json
import os
import time
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps

MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

IMAGE_SIZE = 448
PAD_COLOR = (114, 114, 114)
_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3)
_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3)

# V1.1 macro PR break-even on val (pr_thresholds.json -> macro_single_threshold.support_ge_0).
PR_BREAKEVEN = 0.7596
DEFAULT_THRESHOLD = 0.50


def _download_assets() -> tuple[Path, Path]:
    """Fetch the ONNX model and the vocabulary file from ``MODEL_REPO``.

    The PR-threshold JSON is requested as well, but strictly on a best-effort
    basis: its local path is discarded here, and a failure to fetch it is
    reported on stdout instead of being raised.

    Returns:
        ``(onnx_path, vocab_path)`` — the local paths that
        ``hf_hub_download`` returned for the model and the vocabulary.

    Raises:
        Any exception ``hf_hub_download`` raises for the model or the
        vocabulary download; those two assets are mandatory.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as exc:
        # Optional asset: keep starting up, but leave a trace in the logs
        # rather than hiding the failure completely.
        print(f"Optional asset {PR_PATH_IN_REPO} not downloaded: {exc!r}")
    return onnx_path, vocab_path


def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    with open(vocab_path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    tag_to_index: dict[str, int] = obj["tag_to_index"]
    index_to_tag: dict[int, str] = {int(idx): tag for tag, idx in tag_to_index.items()}
    pad_idx = int(tag_to_index.get("<PAD>", 0))
    unk_idx = int(tag_to_index.get("<UNK>", 1))
    return index_to_tag, pad_idx, unk_idx


def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create a CPU-only ONNX Runtime session for the model at ``onnx_path``.

    Session creation is attempted with progressively weaker graph
    optimization: first ``ORT_ENABLE_BASIC``, then ``ORT_DISABLE_ALL``. The
    first session that builds is returned. Intra-op threads are capped at 8.

    Raises:
        The exception from the last attempt if every level fails.
    """
    thread_budget = os.cpu_count() or 2
    cpu_only = ["CPUExecutionProvider"]

    # On HF Spaces CPU, ORT_ENABLE_ALL synthesizes a MemcpyToHost node on the
    # bool padding_mask path that the CPU build rejects with NOT_IMPLEMENTED.
    # So start at BASIC and fall back to no optimization at all.
    fallback_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )

    failure: Exception | None = None
    for opt_level in fallback_levels:
        options = ort.SessionOptions()
        options.graph_optimization_level = opt_level
        options.intra_op_num_threads = min(thread_budget, 8)
        try:
            session = ort.InferenceSession(
                str(onnx_path), sess_options=options, providers=cpu_only
            )
        except Exception as exc:  # pragma: no cover - depends on ORT build
            # Remember the error and try the next (weaker) level.
            failure = exc
            print(f"InferenceSession failed at opt level {opt_level}: {exc!r}")
        else:
            return session

    # Every level failed; surface the most recent error to the caller.
    assert failure is not None
    raise failure


def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Turn a PIL image into the model's input tensor and padding mask.

    Steps: apply the EXIF orientation, flatten any transparency onto a
    ``PAD_COLOR`` background, shrink (never enlarge) the image so it fits in
    an ``IMAGE_SIZE`` square, centre it on a ``PAD_COLOR`` canvas, then scale
    to [0, 1] and normalize with ``_MEAN`` / ``_STD``.

    Returns:
        A 3-tuple of
        * float32 array of shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)``,
        * bool array of shape ``(1, IMAGE_SIZE, IMAGE_SIZE)`` that is ``True``
          on padded pixels and ``False`` where the image was placed,
        * ``True`` when the input carried transparency and was composited.
    """
    upright = ImageOps.exif_transpose(img)

    # Images with an alpha channel (or palette transparency) are pasted onto
    # a solid PAD_COLOR background, using alpha as the paste mask.
    was_composited = upright.mode in ("RGBA", "LA") or "transparency" in upright.info
    if was_composited:
        with_alpha = upright.convert("RGBA")
        rgb = Image.new("RGB", upright.size, PAD_COLOR)
        rgb.paste(with_alpha.convert("RGB"), mask=with_alpha.getchannel("A"))
    else:
        rgb = upright.convert("RGB")

    # Letterbox: the 1.0 term in min() means small images are not upscaled.
    src_w, src_h = rgb.size
    factor = min(IMAGE_SIZE / src_w, IMAGE_SIZE / src_h, 1.0)
    new_w = max(1, round(src_w * factor))
    new_h = max(1, round(src_h * factor))
    if (new_w, new_h) != (src_w, src_h):
        rgb = rgb.resize((new_w, new_h), Image.BILINEAR)

    # Centre the (possibly resized) image on the padded canvas.
    top = (IMAGE_SIZE - new_h) // 2
    left = (IMAGE_SIZE - new_w) // 2
    rows = slice(top, top + new_h)
    cols = slice(left, left + new_w)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    canvas[rows, cols] = np.asarray(rgb, dtype=np.uint8)

    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[rows, cols] = False

    # HWC uint8 -> normalized float32 -> CHW, then add the batch axis.
    pixels = canvas.astype(np.float32) / 255.0
    pixels = (pixels - _MEAN) / _STD
    pixels = pixels.transpose(2, 0, 1)
    return np.expand_dims(pixels, 0), np.expand_dims(mask, 0), was_composited


# --- Module-level startup: runs once at import time ------------------------
# Download (or reuse the cached copy of) the model assets and load the
# vocabulary so the globals below are ready before any request is served.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
# INDEX_TO_TAG maps output index -> tag name; PAD_IDX / UNK_IDX are the
# indices of the "<PAD>" / "<UNK>" entries that predict() skips.
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)

print("Building ORT CPU session...")
SESSION = _make_session(ONNX_PATH)
# Inspect the graph's declared inputs so predict() can build a matching feed.
INPUT_NAMES = {i.name for i in SESSION.get_inputs()}
# Prefer the input named "pixel_values"; otherwise fall back to the first
# declared input of the graph.
PRIMARY_INPUT = "pixel_values" if "pixel_values" in INPUT_NAMES else SESSION.get_inputs()[0].name
# The padding mask is only fed when the graph declares a "padding_mask" input.
HAS_MASK_INPUT = "padding_mask" in INPUT_NAMES
print("Providers:", SESSION.get_providers())
print("Inputs:", list(INPUT_NAMES))


def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag one image and format the result for the Gradio outputs.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.
        threshold: minimum score a tag needs to be kept.
        top_k: maximum number of tags to return (values below 1 act as 1).

    Returns:
        ``(label_dict, text_summary, info)`` where ``label_dict`` maps tag
        name -> score in descending score order, ``text_summary`` is the
        comma-separated tag list (or a placeholder message), and ``info`` is
        a one-line status string with timing, tag count and, if a
        ``rating:`` tag was kept, the predicted rating.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""

    started = time.perf_counter()
    pixels, pad_mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {PRIMARY_INPUT: pixels}
    if HAS_MASK_INPUT:
        feed["padding_mask"] = pad_mask

    outputs = SESSION.run(None, feed)
    probs = outputs[0][0]  # already sigmoid in the V1.1 graph
    # The timer covers preprocessing as well as the session run.
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    limit = max(int(top_k), 1)
    special_indices = (PAD_IDX, UNK_IDX)
    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0

    # Walk the indices from highest to lowest score; because the order is
    # descending, the first score under the threshold ends the scan.
    for raw_idx in np.argsort(probs)[::-1]:
        idx = int(raw_idx)
        if idx in special_indices:
            continue
        score = float(probs[idx])
        if score < threshold:
            break
        name = INDEX_TO_TAG.get(idx)
        # Skip indices missing from the vocabulary, and skip
        # "gray_background" when the image was composited onto PAD_COLOR.
        if name is None or (was_composited and name == "gray_background"):
            continue
        if name.startswith("rating:") and score > rating_score:
            rating = name.split(":", 1)[1]
            rating_score = score
        label_dict[name] = score
        if len(label_dict) >= limit:
            break

    text_summary = ", ".join(label_dict.keys()) if label_dict else "No tags above threshold."

    info_parts = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_parts.append(f"Predicted rating: {rating}  (score {rating_score:.3f})")

    return label_dict, text_summary, "  •  ".join(info_parts)


# Gradio UI definition. Components are laid out in the order they are created
# inside the nested context managers; `demo` is the app object launched at the
# bottom of the file.
with gr.Blocks(title="OppaiOracle V1.1 (CPU) — anime tagger") as demo:
    # Header: short description, CPU-latency warning and link to the model card.
    gr.Markdown(
        """# OppaiOracle V1.1 (CPU) — anime tagger
Multi-label ViT trained from scratch at 448×448 on a cleaned ~5.9M-image corpus
(19,294 tags). Drop in an image to see ranked tag predictions.

**This Space runs on CPU.** Each prediction takes ~10–30 s on the HF CPU tier
(the 448² ViT is ~250M params). For faster turnaround, see the GPU Space at
[Grio43/OppaiOracle](https://huggingface.co/spaces/Grio43/OppaiOracle).

**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )
    with gr.Row():
        # Left column: the inputs passed to predict().
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand predict() a PIL image.
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            # Score cut-off; starts at DEFAULT_THRESHOLD and shows the
            # PR break-even value as a hint in the help text.
            threshold = gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=DEFAULT_THRESHOLD,
                step=0.01,
                label="Threshold",
                info=f"Tags below this score are dropped. V1.1 macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
            )
            # Upper bound on the number of tags returned by predict().
            top_k = gr.Slider(
                minimum=5,
                maximum=200,
                value=50,
                step=1,
                label="Max tags",
            )
            run_btn = gr.Button("Tag", variant="primary")
        # Right column: the three outputs returned by predict().
        with gr.Column(scale=1):
            # num_top_classes uses the same value (200) as the "Max tags"
            # slider maximum.
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # Clicking "Tag" calls predict(image, threshold, top_k) and routes its
    # 3-tuple to the label widget, the text box and the info line, in order.
    run_btn.click(
        predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )

    # Footer: model link and a note about matching input resolution.
    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).

This Space runs the V1.1 (448×448) checkpoint on CPU. Match input resolution to the
checkpoint you load — feeding 320 to V1.1 (or 448 to V1) hurts accuracy because the ViT
position-embedding grid is fixed at load time.
"""
    )


# Script entry point: enable Gradio's request queue with a default
# concurrency limit of 1, then launch the app.
# NOTE(review): presumably the limit of 1 keeps CPU inferences from
# overlapping — confirm against the Space's hardware tier.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch()