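# Web-page header captured with the file (not part of the program); kept as
# comments so the module parses.
# OppaiOracle / app.py
# Grio43's picture
# Switch Space to V1.1 ONNX (448x448)
# 93bb3dd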
"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger.
The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle`
and downloaded on first launch via `hf_hub_download`. Preprocessing
(letterbox + normalize) matches the training pipeline.
"""
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
# The HF Spaces runtime ships a `spaces` package. When it is missing (local
# development), install a stand-in whose `GPU` decorator leaves the wrapped
# function untouched so the rest of the module imports unchanged.
try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - local dev path
    _IS_HF_SPACES = False

    class _SpacesShim:
        """Minimal stand-in exposing a no-op ``GPU`` decorator."""

        @staticmethod
        def GPU(*dargs, **dkwargs):  # noqa: N802 - matches HF API
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            decorator factory (``@spaces.GPU(duration=...)``).
            """
            used_bare = bool(dargs) and callable(dargs[0]) and not dkwargs
            if used_bare:
                # @spaces.GPU applied directly: the sole argument is the function.
                return dargs[0]

            # @spaces.GPU(...): hand back an identity decorator.
            def _identity(fn):
                return fn

            return _identity

    spaces = _SpacesShim()  # type: ignore
else:
    _IS_HF_SPACES = True
# Hub repository and the in-repo file paths handed to `hf_hub_download`.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"
# Side length (pixels) of the square canvas built by `_preprocess`.
IMAGE_SIZE = 448
# RGB fill used for the letterbox padding and as the background when
# compositing images that carry transparency.
PAD_COLOR = (114, 114, 114)
# Per-channel normalization constants, shaped (1, 1, 3) so they broadcast over
# the H x W x 3 canvas in `_preprocess`.
_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3)
_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3)
# PR_BREAKEVEN is the macro single-threshold PR break-even measured on the val set
# (see V1.1_onnx/pr_thresholds.json -> macro_single_threshold.support_ge_0.pr_breakeven).
# DEFAULT_THRESHOLD is set a touch below the break-even so the UI surfaces a
# reasonable number of tags out of the box; users can adjust the slider.
# DEFAULT_THRESHOLD is the slider's initial value; PR_BREAKEVEN is only shown
# in the slider's help text.
DEFAULT_THRESHOLD = 0.65
PR_BREAKEVEN = 0.7596
def _download_assets() -> tuple[Path, Path]:
    """Download the model files from the Hub and return their local paths.

    The ONNX graph and the vocabulary are required: any error raised while
    fetching them propagates to the caller. The PR-thresholds file is fetched
    on a best-effort basis; a failure there is reported on stdout instead of
    being silently discarded, and does not abort startup.

    Returns:
        ``(onnx_path, vocab_path)`` as returned by ``hf_hub_download``,
        wrapped in ``Path``.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as exc:
        # Deliberately non-fatal: the returned paths do not include this file.
        # Log it so a missing or unreachable asset is visible in the Space logs.
        print(f"Optional asset {PR_PATH_IN_REPO!r} could not be downloaded: {exc!r}")
    return onnx_path, vocab_path
def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Read the vocabulary JSON and build the index -> tag lookup.

    The file must contain a ``"tag_to_index"`` mapping of tag name to index.

    Returns:
        ``(index_to_tag, pad_idx, unk_idx)`` where ``pad_idx`` / ``unk_idx``
        are the indices of ``"<PAD>"`` / ``"<UNK>"``, defaulting to 0 and 1
        when those tags are absent.
    """
    with open(vocab_path, "r", encoding="utf-8") as handle:
        tag_to_index: dict[str, int] = json.load(handle)["tag_to_index"]

    # Invert the mapping, coercing indices to int.
    index_to_tag: dict[int, str] = {}
    for tag, raw_index in tag_to_index.items():
        index_to_tag[int(raw_index)] = tag

    pad_idx = int(tag_to_index.get("<PAD>", 0))
    unk_idx = int(tag_to_index.get("<UNK>", 1))
    return index_to_tag, pad_idx, unk_idx
def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime session, retrying at lower optimization levels.

    CUDA is preferred over CPU, restricted to the providers this ORT build
    reports as available. Intra-op threads are capped at 8.

    Raises:
        The exception from the last attempt when every optimization level fails.
    """
    available = ort.get_available_providers()
    candidates = ("CUDAExecutionProvider", "CPUExecutionProvider")
    chosen = [name for name in candidates if name in available]
    n_threads = min(os.cpu_count() or 2, 8)

    # Graph optimization at ORT_ENABLE_ALL synthesizes a MemcpyToHost node on the
    # bool padding_mask path that some ORT/CPU builds reject with NOT_IMPLEMENTED.
    # Start at BASIC and fall back to fully disabled before giving up.
    fallback_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )

    last_err: Exception | None = None
    for level in fallback_levels:
        options = ort.SessionOptions()
        options.graph_optimization_level = level
        options.intra_op_num_threads = n_threads
        try:
            return ort.InferenceSession(str(onnx_path), sess_options=options, providers=chosen)
        except Exception as err:  # pragma: no cover - depends on ORT build
            last_err = err
            print(f"InferenceSession failed at opt level {level}: {err!r}")

    # Every level failed; surface the most recent error.
    assert last_err is not None
    raise last_err
# Module-level cache filled in by `_get_session()` on its first call.
# The session itself; None until it has been built.
_SESSION: ort.InferenceSession | None = None
# Name of the graph input that receives the image tensor; None until the
# session has been built.
_PRIMARY_INPUT: str | None = None
# True when the loaded graph declares an input named "padding_mask".
_HAS_MASK_INPUT: bool = False
def _get_session() -> ort.InferenceSession:
    """Return the cached ORT session, building it on the first call.

    Must be called from inside @spaces.GPU on ZeroGPU so that
    CUDAExecutionProvider can see the just-attached device. Building the
    session also records the primary input name and whether the graph takes a
    ``padding_mask`` input.
    """
    global _SESSION, _PRIMARY_INPUT, _HAS_MASK_INPUT
    if _SESSION is None:
        session = _make_session(ONNX_PATH)
        input_names = {entry.name for entry in session.get_inputs()}
        # Prefer the conventional image input name; otherwise use the first input.
        if "pixel_values" in input_names:
            _PRIMARY_INPUT = "pixel_values"
        else:
            _PRIMARY_INPUT = session.get_inputs()[0].name
        _HAS_MASK_INPUT = "padding_mask" in input_names
        print("Providers:", session.get_providers())
        print("Inputs:", list(input_names))
        _SESSION = session
    return _SESSION
def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Letterbox and normalize an image for the model.

    Steps: apply the EXIF orientation, flatten any transparency onto a
    ``PAD_COLOR`` background, shrink (never enlarge) to fit within
    ``IMAGE_SIZE``, center on a ``PAD_COLOR`` square canvas, then scale to
    [0, 1] and normalize with ``_MEAN`` / ``_STD``.

    Returns:
        ``(pixels, padding_mask, was_composited)`` where ``pixels`` is a
        float32 array of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE),
        ``padding_mask`` is a bool array of shape (1, IMAGE_SIZE, IMAGE_SIZE)
        that is True on padding and False over the image, and
        ``was_composited`` tells whether transparency was flattened.
    """
    img = ImageOps.exif_transpose(img)

    was_composited = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if was_composited:
        # Paste the colour channels onto the pad colour using alpha as the mask.
        with_alpha = img.convert("RGBA")
        flattened = Image.new("RGB", img.size, PAD_COLOR)
        flattened.paste(with_alpha.convert("RGB"), mask=with_alpha.getchannel("A"))
        img = flattened
    else:
        img = img.convert("RGB")

    src_w, src_h = img.size
    # The 1.0 cap means small images are padded rather than upscaled.
    ratio = min(IMAGE_SIZE / src_w, IMAGE_SIZE / src_h, 1.0)
    fit_w = max(1, round(src_w * ratio))
    fit_h = max(1, round(src_h * ratio))
    if (fit_w, fit_h) != (src_w, src_h):
        img = img.resize((fit_w, fit_h), Image.BILINEAR)

    # Center the resized image on the canvas.
    y0 = (IMAGE_SIZE - fit_h) // 2
    x0 = (IMAGE_SIZE - fit_w) // 2
    rows = slice(y0, y0 + fit_h)
    cols = slice(x0, x0 + fit_w)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    canvas[rows, cols] = np.asarray(img, dtype=np.uint8)

    padding = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    padding[rows, cols] = False

    # HWC uint8 -> normalized float32 -> CHW, then add the batch axis.
    normalized = (canvas.astype(np.float32) / 255.0 - _MEAN) / _STD
    chw = normalized.transpose(2, 0, 1)
    return np.expand_dims(chw, 0), np.expand_dims(padding, 0), was_composited
# Import-time setup: fetch the model files and load the vocabulary once, so
# `predict` and `_get_session` can read them from module globals.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
# INDEX_TO_TAG maps output index -> tag name; PAD_IDX / UNK_IDX are the
# indices `predict` skips.
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)
# The ORT session is intentionally NOT created here; see `_get_session`.
print("ZeroGPU mode:", _IS_HF_SPACES, "(session built lazily on first call)")
# On ZeroGPU the device is attached to this process only while a @spaces.GPU
# function is running, so the ORT session is created lazily inside the
# decorated call; that way CUDAExecutionProvider initializes against an
# attached device. The `duration` budget covers the session build (~10s for a
# ~1 GB ONNX file) on the first call plus ~1s of inference; later calls reuse
# the cached session.
@spaces.GPU(duration=120)
def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag one image and format the result for the UI.

    Args:
        image: PIL image from the UI, or None when nothing is loaded.
        threshold: Minimum score for a tag to be kept.
        top_k: Maximum number of tags to return (values below 1 are treated as 1).

    Returns:
        ``(label_dict, text_summary, info)``: tag -> score in descending score
        order, the tags joined with ``", "`` (or a placeholder message), and a
        one-line status string with timing, tag count and, when a
        ``rating:*`` tag was kept, the best-scoring rating.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""

    session = _get_session()
    input_name = _PRIMARY_INPUT or session.get_inputs()[0].name

    # Timing covers preprocessing plus the forward pass, not the session build.
    started = time.perf_counter()
    pixels, padding_mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {input_name: pixels}
    if _HAS_MASK_INPUT:
        feed["padding_mask"] = padding_mask
    # sigmoid already applied inside the ONNX graph
    probs = session.run(None, feed)[0][0]
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    max_tags = max(int(top_k), 1)
    special_indices = (PAD_IDX, UNK_IDX)
    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0

    # Walk indices from highest to lowest score.
    for idx in map(int, np.argsort(probs)[::-1]):
        if idx in special_indices:
            continue
        score = float(probs[idx])
        if score < threshold:
            # Scores only decrease from here, so nothing further can qualify.
            break
        name = INDEX_TO_TAG.get(idx)
        if name is None:
            continue
        # Skip "gray_background" when the image was flattened onto PAD_COLOR.
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:") and score > rating_score:
            rating = name.partition(":")[2]
            rating_score = score
        label_dict[name] = score
        if len(label_dict) >= max_tags:
            break

    text_summary = ", ".join(label_dict) if label_dict else "No tags above threshold."

    info_lines = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_lines.append(f"Predicted rating: {rating} (score {rating_score:.3f})")
    return label_dict, text_summary, " • ".join(info_lines)
# UI: inputs (image, threshold, max tags, button) in the left column, outputs
# (ranked labels, comma-separated tags, status line) in the right column.
with gr.Blocks(title="OppaiOracle — anime tagger") as demo:
    # Page header.
    gr.Markdown(
        """# OppaiOracle V1.1 — anime tagger
Multi-label ViT fine-tuned at 448×448 (a fine-tune of the from-scratch V1 320×320
model, on the same cleaned ~5.9M-image corpus, 19,294 tags). Drop in an image to
see ranked tag predictions.
**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            threshold = gr.Slider(
                label="Threshold",
                info=f"Tags below this score are dropped. Macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
                minimum=0.05,
                maximum=0.95,
                step=0.01,
                value=DEFAULT_THRESHOLD,
            )
            top_k = gr.Slider(label="Max tags", minimum=5, maximum=200, step=1, value=50)
            run_btn = gr.Button("Tag", variant="primary")

        # Right column: outputs.
        with gr.Column(scale=1):
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # Run `predict` both on the button click and whenever the image changes;
    # the two listeners use identical inputs and outputs.
    for _register in (run_btn.click, image_in.change):
        _register(
            predict,
            inputs=[image_in, threshold, top_k],
            outputs=[labels_out, tags_text, info_out],
        )

    # Page footer.
    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).
This Space runs the V1.1 (448×448) checkpoint. The V1 (320×320) checkpoint is also in the
model repo for users who specifically want the smaller native resolution — match input
resolution to the checkpoint you load.
"""
    )

if __name__ == "__main__":
    # Enable the request queue (at most 2 concurrent runs per event by
    # default), then start the server.
    demo.queue(default_concurrency_limit=2)
    demo.launch()