"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger.

The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle`
and downloaded on first launch via `hf_hub_download`. Preprocessing
(letterbox + normalize) matches the training pipeline.
"""
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
# The HF Spaces runtime ships the `spaces` package. When it is missing
# (typically local development) install a stand-in whose `GPU` decorator
# does nothing, so the rest of the module imports and runs unchanged.
try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - local dev path
    _IS_HF_SPACES = False

    class _SpacesShim:
        """Minimal replacement for the `spaces` module exposing only `GPU`."""

        @staticmethod
        def GPU(*dargs, **dkwargs):  # noqa: N802 - matches HF API
            """Return the decorated function untouched.

            Works both as a bare decorator (`@spaces.GPU`) and as a
            decorator factory (`@spaces.GPU(duration=...)`).
            """
            used_bare = dargs and callable(dargs[0]) and not dkwargs
            if used_bare:
                # Bare form: the sole positional argument is the function.
                return dargs[0]
            # Factory form: hand back an identity decorator.
            return lambda fn: fn

    spaces = _SpacesShim()  # type: ignore
else:
    _IS_HF_SPACES = True
# Hub repository holding the exported model, and the artifact paths inside it.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

# Square model input resolution and the RGB fill used for letterbox padding.
IMAGE_SIZE = 448
PAD_COLOR = (114, 114, 114)

# Per-channel normalization constants, shaped (1, 1, 3) so they broadcast
# over an H x W x 3 image.
_MEAN = np.full((1, 1, 3), 0.5, dtype=np.float32)
_STD = np.full((1, 1, 3), 0.5, dtype=np.float32)

# PR_BREAKEVEN is the macro single-threshold PR break-even measured on the val set
# (see V1.1_onnx/pr_thresholds.json -> macro_single_threshold.support_ge_0.pr_breakeven).
PR_BREAKEVEN = 0.7596
# DEFAULT_THRESHOLD sits below the break-even so the UI surfaces a reasonable
# number of tags out of the box; users can adjust the slider.
DEFAULT_THRESHOLD = 0.65
def _download_assets() -> tuple[Path, Path]:
    """Fetch the model artifacts from the Hub and return their local paths.

    The ONNX model and the vocabulary are required: any download error for
    either one propagates to the caller. The PR-threshold file is fetched on
    a best-effort basis only; its path is not returned, and a failure is
    reported on stdout instead of aborting startup.

    Returns:
        ``(onnx_path, vocab_path)`` as local filesystem paths.
    """
    onnx_path = Path(hf_hub_download(repo_id=MODEL_REPO, filename=ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(repo_id=MODEL_REPO, filename=VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(repo_id=MODEL_REPO, filename=PR_PATH_IN_REPO)
    except Exception as exc:
        # Optional artifact: keep going, but leave a trace in the logs rather
        # than hiding the failure completely.
        print(f"Optional asset {PR_PATH_IN_REPO} could not be downloaded: {exc!r}")
    return onnx_path, vocab_path
def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Read the vocabulary JSON and invert its tag -> index table.

    Args:
        vocab_path: UTF-8 JSON file containing a ``"tag_to_index"`` object.

    Returns:
        ``(index_to_tag, pad_idx, unk_idx)``. The special indices come from
        the ``"<PAD>"`` and ``"<UNK>"`` entries and fall back to 0 and 1
        when those entries are missing.
    """
    payload = json.loads(Path(vocab_path).read_text(encoding="utf-8"))
    tag_to_index: dict[str, int] = payload["tag_to_index"]

    index_to_tag: dict[int, str] = {}
    for tag, raw_index in tag_to_index.items():
        index_to_tag[int(raw_index)] = tag

    pad_idx = int(tag_to_index.get("<PAD>", 0))
    unk_idx = int(tag_to_index.get("<UNK>", 1))
    return index_to_tag, pad_idx, unk_idx
def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime session for the model at ``onnx_path``.

    CUDA is listed ahead of CPU, and only providers that this ORT build
    reports as available are requested. Intra-op threads are capped at 8.

    Graph optimization at ORT_ENABLE_ALL synthesizes a MemcpyToHost node on
    the bool padding_mask path that some ORT/CPU builds reject with
    NOT_IMPLEMENTED, so that level is never attempted: the session is built
    at BASIC first and with optimizations disabled as a fallback.

    Raises:
        Exception: the error from the last attempt when every level fails.
    """
    thread_cap = min(os.cpu_count() or 2, 8)
    available = ort.get_available_providers()
    wanted = ("CUDAExecutionProvider", "CPUExecutionProvider")
    preferred = [name for name in wanted if name in available]

    fallback_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )
    failure: Exception | None = None
    for level in fallback_levels:
        options = ort.SessionOptions()
        options.graph_optimization_level = level
        options.intra_op_num_threads = thread_cap
        try:
            session = ort.InferenceSession(
                str(onnx_path), sess_options=options, providers=preferred
            )
        except Exception as exc:  # pragma: no cover - depends on ORT build
            failure = exc
            print(f"InferenceSession failed at opt level {level}: {exc!r}")
        else:
            return session

    # Every level failed; surface the most recent error.
    assert failure is not None
    raise failure
# Module-level cache filled in by _get_session() on its first call.
# Cached ORT session; None until the first call builds it.
_SESSION: ort.InferenceSession | None = None
# Name of the image input: "pixel_values" when the graph declares it,
# otherwise the graph's first input. None until the session exists.
_PRIMARY_INPUT: str | None = None
# True when the graph declares a "padding_mask" input.
_HAS_MASK_INPUT: bool = False
def _get_session() -> ort.InferenceSession:
    """Return the shared ORT session, building it on the first call.

    Must be called from inside @spaces.GPU on ZeroGPU so that
    CUDAExecutionProvider can see the just-attached device. The first call
    also records which graph input receives the image and whether the graph
    takes a ``padding_mask`` input.
    """
    global _SESSION, _PRIMARY_INPUT, _HAS_MASK_INPUT
    if _SESSION is None:
        built = _make_session(ONNX_PATH)
        input_names = {inp.name for inp in built.get_inputs()}
        if "pixel_values" in input_names:
            _PRIMARY_INPUT = "pixel_values"
        else:
            # No conventional name present: use whatever input comes first.
            _PRIMARY_INPUT = built.get_inputs()[0].name
        _HAS_MASK_INPUT = "padding_mask" in input_names
        print("Providers:", built.get_providers())
        print("Inputs:", list(input_names))
        _SESSION = built
    return _SESSION
def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Turn a PIL image into the model's input tensors.

    Steps: apply the EXIF orientation, flatten any transparency onto
    PAD_COLOR, shrink (never enlarge) the image to fit inside
    IMAGE_SIZE x IMAGE_SIZE while keeping its aspect ratio, centre it on a
    PAD_COLOR canvas, then normalize with _MEAN / _STD.

    Returns:
        ``(pixels, padding_mask, was_composited)`` where ``pixels`` is
        float32 of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE), ``padding_mask`` is
        bool of shape (1, IMAGE_SIZE, IMAGE_SIZE) with True on padded pixels,
        and ``was_composited`` reports whether transparency was flattened.
    """
    img = ImageOps.exif_transpose(img)

    was_composited = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if was_composited:
        # Alpha-blend onto the padding colour so transparent regions look
        # like letterbox padding.
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", img.size, PAD_COLOR)
        flattened.paste(rgba.convert("RGB"), mask=rgba.getchannel("A"))
        img = flattened
    else:
        img = img.convert("RGB")

    src_w, src_h = img.size
    # The 1.0 term caps the factor so small images are never upscaled.
    factor = min(IMAGE_SIZE / src_w, IMAGE_SIZE / src_h, 1.0)
    dst_w = max(1, round(src_w * factor))
    dst_h = max(1, round(src_h * factor))
    if (dst_w, dst_h) != (src_w, src_h):
        img = img.resize((dst_w, dst_h), Image.BILINEAR)

    # Centre the image; any odd leftover pixel goes to the bottom/right.
    y0 = (IMAGE_SIZE - dst_h) // 2
    x0 = (IMAGE_SIZE - dst_w) // 2
    rows = slice(y0, y0 + dst_h)
    cols = slice(x0, x0 + dst_w)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    canvas[rows, cols] = np.asarray(img, dtype=np.uint8)

    padding_mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    padding_mask[rows, cols] = False

    pixels = canvas.astype(np.float32) / 255.0
    pixels = (pixels - _MEAN) / _STD
    pixels = pixels.transpose(2, 0, 1)  # HWC -> CHW
    return np.expand_dims(pixels, 0), np.expand_dims(padding_mask, 0), was_composited
# Import-time setup: fetch the artifacts and load the vocabulary. The ORT
# session itself is created later, on the first prediction.
print(f"Downloading model and vocabulary from {MODEL_REPO}")
ONNX_PATH, VOCAB_PATH = _download_assets()
print(f"ONNX model: {ONNX_PATH}")
print(f"Vocabulary: {VOCAB_PATH}")
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)
print(f"ZeroGPU mode: {_IS_HF_SPACES} (session built lazily on first call)")
# ZeroGPU: GPU is only attached to this process *during* a @spaces.GPU call.
# Build the ORT session lazily inside the decorated function so the
# CUDAExecutionProvider initializes against an attached device. The
# `duration` budget covers session build (~10s for ~1 GB ONNX) on the first
# call plus ~1s of actual inference; subsequent calls reuse the cached session.
@spaces.GPU(duration=120)
def _predict_on_gpu(image: Image.Image, threshold: float, top_k: int):
    """Run the tagger on ``image`` and select the tags to display.

    Tags are visited in descending score order. Selection stops at the first
    score below ``threshold`` or once ``top_k`` tags (at least one) have been
    collected. The PAD/UNK indices and indices missing from the vocabulary
    are skipped, as is ``gray_background`` when the image had transparency
    flattened onto the padding colour.

    Returns:
        ``(label_dict, text_summary, info)``: tag -> score mapping, the
        comma-separated tag names, and a one-line run summary.
    """
    session = _get_session()
    primary_input = _PRIMARY_INPUT or session.get_inputs()[0].name

    # Timing covers preprocessing and inference, not session creation.
    t0 = time.perf_counter()
    x, mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {primary_input: x}
    if _HAS_MASK_INPUT:
        feed["padding_mask"] = mask
    probs = session.run(None, feed)[0][0]  # sigmoid already applied inside the ONNX graph
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    cap = max(int(top_k), 1)
    order = np.argsort(probs)[::-1]
    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0
    for raw_idx in order:
        idx = int(raw_idx)
        if idx in (PAD_IDX, UNK_IDX):
            continue
        score = float(probs[idx])
        if score < threshold:
            # Scores are sorted descending, so nothing later can qualify.
            break
        name = INDEX_TO_TAG.get(idx)
        if name is None:
            continue
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:"):
            # Remember the best-scoring rating tag for the summary line.
            if score > rating_score:
                rating = name.split(":", 1)[1]
                rating_score = score
        label_dict[name] = score
        if len(label_dict) >= cap:
            break

    if not label_dict:
        text_summary = "No tags above threshold."
    else:
        text_summary = ", ".join(label_dict.keys())
    info_lines = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_lines.append(f"Predicted rating: {rating} (score {rating_score:.3f})")
    info = " • ".join(info_lines)
    return label_dict, text_summary, info


def predict(image: Image.Image, threshold: float, top_k: int):
    """Gradio handler: tag ``image`` and return the three output values.

    The empty-input case is answered here, outside the @spaces.GPU function,
    so that calls with no image (for example after the image is cleared) do
    not request a GPU.

    Args:
        image: Uploaded PIL image, or None when the input is empty.
        threshold: Minimum score a tag needs to be returned.
        top_k: Maximum number of tags to return (values below 1 act as 1).

    Returns:
        ``(label_dict, text_summary, info)`` for the label, textbox and
        markdown outputs.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""
    return _predict_on_gpu(image, threshold, top_k)
# UI definition. Components are created in layout order: header, a two-column
# row (inputs on the left, outputs on the right), then the footer.
with gr.Blocks(title="OppaiOracle — anime tagger") as demo:
    gr.Markdown(
        """# OppaiOracle V1.1 — anime tagger
Multi-label ViT fine-tuned at 448×448 (a fine-tune of the from-scratch V1 320×320
model, on the same cleaned ~5.9M-image corpus, 19,294 tags). Drop in an image to
see ranked tag predictions.
**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )

    with gr.Row():
        # Left column: image input and the two tuning sliders.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            threshold = gr.Slider(
                label="Threshold",
                info=f"Tags below this score are dropped. Macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
                minimum=0.05,
                maximum=0.95,
                step=0.01,
                value=DEFAULT_THRESHOLD,
            )
            top_k = gr.Slider(label="Max tags", minimum=5, maximum=200, step=1, value=50)
            run_btn = gr.Button("Tag", variant="primary")

        # Right column: ranked labels, copyable tag string, run summary.
        with gr.Column(scale=1):
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # The button click and any change of the image run the same handler with
    # the same inputs and outputs.
    for register in (run_btn.click, image_in.change):
        register(
            predict,
            inputs=[image_in, threshold, top_k],
            outputs=[labels_out, tags_text, info_out],
        )

    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).
This Space runs the V1.1 (448×448) checkpoint. The V1 (320×320) checkpoint is also in the
model repo for users who specifically want the smaller native resolution — match input
resolution to the checkpoint you load.
"""
    )
if __name__ == "__main__":
    # Enable the request queue (default concurrency limit of 2 per event),
    # then start the server.
    queued_app = demo.queue(default_concurrency_limit=2)
    queued_app.launch()
|