"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger, CPU build.
The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle`
(`V1.1_onnx/model.onnx`) and downloaded on first launch via `hf_hub_download`.
Preprocessing (letterbox + normalize) matches the training pipeline. This
Space targets CPU hardware — inference of the 448x448 ViT on the standard
HF CPU tier takes ~10-30s per image.
"""
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
# --- Hugging Face Hub locations --------------------------------------------
# Repository that hosts the exported model, and the files fetched from it.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

# --- Preprocessing ----------------------------------------------------------
# Side length of the square model input, in pixels.
IMAGE_SIZE = 448
# RGB fill used for letterbox padding and behind transparent pixels.
PAD_COLOR = (114, 114, 114)
# Per-channel normalisation constants, shaped (1, 1, 3) so they broadcast over
# an H x W x 3 image; (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
_MEAN = np.full((1, 1, 3), 0.5, dtype=np.float32)
_STD = np.full((1, 1, 3), 0.5, dtype=np.float32)

# --- Thresholds -------------------------------------------------------------
# V1.1 macro PR break-even on val (pr_thresholds.json -> macro_single_threshold.support_ge_0).
PR_BREAKEVEN = 0.7596
# Initial position of the threshold slider in the UI.
DEFAULT_THRESHOLD = 0.50
def _download_assets() -> tuple[Path, Path]:
    """Fetch the model files from the Hub and return their local paths.

    The ONNX model and the vocabulary are required; a failure to download
    either propagates to the caller. The PR-threshold file is optional: it is
    requested on a best-effort basis and its path is not returned.

    Returns:
        ``(onnx_path, vocab_path)`` as local filesystem paths.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as exc:
        # Deliberately non-fatal, but leave a trace in the logs instead of
        # discarding the error silently.
        print(f"Optional asset {PR_PATH_IN_REPO} could not be downloaded: {exc!r}")
    return onnx_path, vocab_path
def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Read the vocabulary JSON and build the index -> tag lookup.

    The file must contain a ``"tag_to_index"`` object mapping tag names to
    indices.

    Returns:
        ``(index_to_tag, pad_idx, unk_idx)``. ``pad_idx`` and ``unk_idx`` are
        the indices of ``"<PAD>"`` and ``"<UNK>"``, defaulting to 0 and 1 when
        those tags are absent from the vocabulary.
    """
    payload = json.loads(Path(vocab_path).read_text(encoding="utf-8"))
    tag_to_index: dict[str, int] = payload["tag_to_index"]

    # Invert the mapping, coercing indices to int along the way.
    index_to_tag: dict[int, str] = {}
    for tag_name, tag_index in tag_to_index.items():
        index_to_tag[int(tag_index)] = tag_name

    pad_idx = int(tag_to_index.get("<PAD>", 0))
    unk_idx = int(tag_to_index.get("<UNK>", 1))
    return index_to_tag, pad_idx, unk_idx
def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create a CPU-only ONNX Runtime session for the model at *onnx_path*.

    Graph optimization levels are tried from most to least aggressive; the
    first session that constructs successfully is returned.

    Raises:
        Exception: the error from the last attempted level when every level
            fails to build a session.
    """
    # Use every visible core, capped at 8 intra-op threads.
    thread_budget = min(os.cpu_count() or 2, 8)

    # On HF Spaces CPU, ORT_ENABLE_ALL synthesizes a MemcpyToHost node on the
    # bool padding_mask path that the CPU build rejects with NOT_IMPLEMENTED.
    # Step down to BASIC, then to disabled, before giving up.
    fallback_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )

    last_err: Exception | None = None
    for level in fallback_levels:
        so = ort.SessionOptions()
        so.graph_optimization_level = level
        so.intra_op_num_threads = thread_budget
        try:
            session = ort.InferenceSession(
                str(onnx_path),
                sess_options=so,
                providers=["CPUExecutionProvider"],
            )
        except Exception as e:  # pragma: no cover - depends on ORT build
            last_err = e
            print(f"InferenceSession failed at opt level {level}: {e!r}")
        else:
            return session

    # Only reachable when every level failed, so an error has been recorded.
    assert last_err is not None
    raise last_err
def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Letterbox and normalize *img* into model inputs.

    Steps: apply the EXIF orientation, flatten any transparency onto
    ``PAD_COLOR``, shrink (never enlarge) to fit ``IMAGE_SIZE`` while keeping
    the aspect ratio, center on a ``PAD_COLOR`` canvas, then normalize with
    ``_MEAN`` / ``_STD``.

    Returns:
        ``(pixels, padding_mask, was_composited)`` where ``pixels`` is float32
        with shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)``, ``padding_mask`` is bool
        with shape ``(1, IMAGE_SIZE, IMAGE_SIZE)`` and is True on padded
        pixels, and ``was_composited`` tells whether transparency was
        flattened onto the pad colour.
    """
    img = ImageOps.exif_transpose(img)

    was_composited = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if was_composited:
        # Paste the colour channels through the alpha mask so transparent
        # areas show the pad colour.
        with_alpha = img.convert("RGBA")
        flattened = Image.new("RGB", img.size, PAD_COLOR)
        flattened.paste(with_alpha.convert("RGB"), mask=with_alpha.getchannel("A"))
        img = flattened
    else:
        img = img.convert("RGB")

    src_w, src_h = img.size
    # The 1.0 cap means small images are padded rather than upscaled.
    ratio = min(IMAGE_SIZE / src_w, IMAGE_SIZE / src_h, 1.0)
    fit_w = max(1, round(src_w * ratio))
    fit_h = max(1, round(src_h * ratio))
    if (fit_w, fit_h) != (src_w, src_h):
        img = img.resize((fit_w, fit_h), Image.BILINEAR)

    # Window of the canvas that the (centered) image occupies.
    y0 = (IMAGE_SIZE - fit_h) // 2
    x0 = (IMAGE_SIZE - fit_w) // 2
    region = (slice(y0, y0 + fit_h), slice(x0, x0 + fit_w))

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    canvas[region] = np.asarray(img, dtype=np.uint8)

    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[region] = False

    # Scale to [0, 1], normalize, then reorder HWC -> CHW and add a batch axis.
    x = (canvas.astype(np.float32) / 255.0 - _MEAN) / _STD
    x = x.transpose(2, 0, 1)
    return x[np.newaxis], mask[np.newaxis], was_composited
# --- One-time start-up: fetch assets, load the vocabulary, build the session.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)

print("Building ORT CPU session...")
SESSION = _make_session(ONNX_PATH)

# Inspect the graph inputs once so predict() can build its feed dict.
INPUT_NAMES = {graph_input.name for graph_input in SESSION.get_inputs()}
if "pixel_values" in INPUT_NAMES:
    PRIMARY_INPUT = "pixel_values"
else:
    # Fall back to whatever the graph declares as its first input.
    PRIMARY_INPUT = SESSION.get_inputs()[0].name
HAS_MASK_INPUT = "padding_mask" in INPUT_NAMES

print("Providers:", SESSION.get_providers())
print("Inputs:", list(INPUT_NAMES))
def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag *image* and return the outputs for the three result widgets.

    Args:
        image: The uploaded picture, or ``None`` when nothing was provided.
        threshold: Minimum score a tag needs to be reported.
        top_k: Maximum number of tags to report (values below 1 act as 1).

    Returns:
        ``(label_dict, text_summary, info)``: a tag -> score mapping in
        descending score order, the same tags as a comma-separated string,
        and a one-line summary with timing, tag count and, if any rating tag
        was kept, the best-scoring rating.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""

    started = time.perf_counter()
    pixels, pad_mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {PRIMARY_INPUT: pixels}
    if HAS_MASK_INPUT:
        feed["padding_mask"] = pad_mask
    # already sigmoid in the V1.1 graph
    probs = SESSION.run(None, feed)[0][0]
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    max_tags = max(int(top_k), 1)
    special_ids = (PAD_IDX, UNK_IDX)

    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0
    # Walk tags from highest to lowest score, so the first score below the
    # threshold ends the scan.
    for tag_id in map(int, np.argsort(probs)[::-1]):
        if tag_id in special_ids:
            continue
        score = float(probs[tag_id])
        if score < threshold:
            break
        name = INDEX_TO_TAG.get(tag_id)
        if name is None:
            continue
        # The pad colour used to flatten transparency is gray, so this tag is
        # suppressed for composited images.
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:") and score > rating_score:
            rating, rating_score = name.split(":", 1)[1], score
        label_dict[name] = score
        if len(label_dict) >= max_tags:
            break

    text_summary = ", ".join(label_dict) if label_dict else "No tags above threshold."

    info_parts = [
        f"Inference: {elapsed_ms} ms",
        f"Tags returned: {len(label_dict)}",
    ]
    if rating is not None:
        info_parts.append(f"Predicted rating: {rating} (score {rating_score:.3f})")
    return label_dict, text_summary, " • ".join(info_parts)
# --- UI layout: inputs on the left, results on the right. -------------------
with gr.Blocks(title="OppaiOracle V1.1 (CPU) — anime tagger") as demo:
    # Header / usage notes.
    gr.Markdown(
        """# OppaiOracle V1.1 (CPU) — anime tagger
Multi-label ViT trained from scratch at 448×448 on a cleaned ~5.9M-image corpus
(19,294 tags). Drop in an image to see ranked tag predictions.
**This Space runs on CPU.** Each prediction takes ~10–30 s on the HF CPU tier
(the 448² ViT is ~250M params). For faster turnaround, see the GPU Space at
[Grio43/OppaiOracle](https://huggingface.co/spaces/Grio43/OppaiOracle).
**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )

    with gr.Row():
        # Left column: image upload and the two tuning sliders.
        with gr.Column(scale=1):
            image_in = gr.Image(
                label="Image",
                type="pil",
                sources=["upload", "clipboard"],
            )
            threshold = gr.Slider(
                label="Threshold",
                info=f"Tags below this score are dropped. V1.1 macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
                minimum=0.05,
                maximum=0.95,
                step=0.01,
                value=DEFAULT_THRESHOLD,
            )
            top_k = gr.Slider(
                label="Max tags",
                minimum=5,
                maximum=200,
                step=1,
                value=50,
            )
            run_btn = gr.Button("Tag", variant="primary")

        # Right column: ranked labels, copy-friendly tag string, run info.
        with gr.Column(scale=1):
            labels_out = gr.Label(num_top_classes=200, label="Predictions")
            tags_text = gr.Textbox(lines=4, label="Comma-separated tags")
            info_out = gr.Markdown()

    # The button feeds the three inputs to predict() and fans its three
    # return values out to the result widgets.
    run_btn.click(
        fn=predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )

    # Footer.
    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).
This Space runs the V1.1 (448×448) checkpoint on CPU. Match input resolution to the
checkpoint you load — feeding 320 to V1.1 (or 448 to V1) hurts accuracy because the ViT
position-embedding grid is fixed at load time.
"""
    )
if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 1, then
    # start serving.
    queued_demo = demo.queue(default_concurrency_limit=1)
    queued_demo.launch()
|