| """OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger, CPU build. |
| |
| The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle` |
| (`V1.1_onnx/model.onnx`) and downloaded on first launch via `hf_hub_download`. |
| Preprocessing (letterbox + normalize) matches the training pipeline. This |
| Space targets CPU hardware — inference of the 448x448 ViT on the standard |
| HF CPU tier takes ~10-30s per image. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import time |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import onnxruntime as ort |
| from huggingface_hub import hf_hub_download |
| from PIL import Image, ImageOps |
|
|
# Location of the exported model and its companion files on the Hugging Face Hub.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

# Side length (pixels) of the square canvas fed to the model.
IMAGE_SIZE = 448
# RGB fill used both for letterbox padding and as the backdrop when flattening
# transparent images.
PAD_COLOR = (114, 114, 114)
# Per-channel normalization constants, shaped (1, 1, 3) so they broadcast over
# an H x W x 3 array.
_MEAN = np.full((1, 1, 3), 0.5, dtype=np.float32)
_STD = np.full((1, 1, 3), 0.5, dtype=np.float32)

# Shown in the threshold slider's help text as the macro PR break-even point.
PR_BREAKEVEN = 0.7596
# Initial position of the threshold slider.
DEFAULT_THRESHOLD = 0.50
|
|
|
|
def _download_assets() -> tuple[Path, Path]:
    """Download the ONNX model and vocabulary from the model repo.

    The PR-thresholds file is requested as well, but only on a best-effort
    basis: its local path is not returned, and a failure to fetch it is
    reported on stdout instead of aborting startup.

    Returns:
        ``(onnx_path, vocab_path)`` — local filesystem paths of the downloaded
        model and vocabulary files.

    Raises:
        Any exception raised by ``hf_hub_download`` for the model or the
        vocabulary file propagates to the caller.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as e:
        # Deliberately non-fatal: nothing at startup depends on this file, but
        # leave a trace in the logs rather than hiding the failure entirely.
        print(f"Optional asset {PR_PATH_IN_REPO} could not be downloaded: {e!r}")
    return onnx_path, vocab_path
|
|
|
|
def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Read the vocabulary JSON and build the index -> tag lookup.

    The file must contain a ``"tag_to_index"`` object mapping tag names to
    indices.

    Returns:
        ``(index_to_tag, pad_idx, unk_idx)`` where ``pad_idx`` / ``unk_idx``
        are the indices of ``"<PAD>"`` / ``"<UNK>"``, falling back to 0 / 1
        when those tags are absent from the vocabulary.
    """
    payload = json.loads(Path(vocab_path).read_text(encoding="utf-8"))
    tag_to_index: dict[str, int] = payload["tag_to_index"]

    # Invert the mapping, coercing indices to int in case they were stored
    # in another form.
    index_to_tag: dict[int, str] = {}
    for tag, idx in tag_to_index.items():
        index_to_tag[int(idx)] = tag

    return (
        index_to_tag,
        int(tag_to_index.get("<PAD>", 0)),
        int(tag_to_index.get("<UNK>", 1)),
    )
|
|
|
|
def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create a CPU-only ONNX Runtime session for the model at *onnx_path*.

    Session creation is attempted first with basic graph optimizations and,
    if that raises, once more with graph optimizations disabled. Each failure
    is printed.

    Raises:
        The exception from the last failed attempt when both attempts fail.
    """
    # Use the available cores, but never more than 8 intra-op threads.
    n_threads = min(os.cpu_count() or 2, 8)
    opt_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )

    last_err: Exception | None = None
    for level in opt_levels:
        # Fresh options object for every attempt.
        opts = ort.SessionOptions()
        opts.graph_optimization_level = level
        opts.intra_op_num_threads = n_threads
        try:
            session = ort.InferenceSession(
                str(onnx_path),
                sess_options=opts,
                providers=["CPUExecutionProvider"],
            )
        except Exception as e:
            last_err = e
            print(f"InferenceSession failed at opt level {level}: {e!r}")
        else:
            return session

    # Only reached when every optimization level failed.
    assert last_err is not None
    raise last_err
|
|
|
|
def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Turn a PIL image into the model's input tensor and padding mask.

    Steps: apply the EXIF orientation, flatten any transparency onto
    ``PAD_COLOR``, shrink the image (never enlarge it) so it fits inside an
    ``IMAGE_SIZE`` square, center it on a ``PAD_COLOR`` canvas, then normalize
    with ``_MEAN`` / ``_STD`` and move channels first.

    Returns:
        ``(pixels, mask, was_composited)`` where ``pixels`` is float32 with
        shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)``, ``mask`` is a boolean array
        of shape ``(1, IMAGE_SIZE, IMAGE_SIZE)`` that is True on padding and
        False where image content was pasted, and ``was_composited`` tells
        whether the input carried transparency that was flattened.
    """
    img = ImageOps.exif_transpose(img)

    # Images with an alpha channel or palette transparency are pasted onto a
    # solid PAD_COLOR background, using their alpha as the paste mask.
    was_composited = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if was_composited:
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", img.size, PAD_COLOR)
        flattened.paste(rgba.convert("RGB"), mask=rgba.getchannel("A"))
        img = flattened
    else:
        img = img.convert("RGB")

    # Aspect-preserving downscale; the 1.0 cap keeps small images at their
    # native size.
    src_w, src_h = img.size
    scale = min(IMAGE_SIZE / src_w, IMAGE_SIZE / src_h, 1.0)
    new_w = max(1, round(src_w * scale))
    new_h = max(1, round(src_h * scale))
    if (new_w, new_h) != (src_w, src_h):
        img = img.resize((new_w, new_h), Image.BILINEAR)

    # Region of the square canvas that the (centered) image occupies.
    top = (IMAGE_SIZE - new_h) // 2
    left = (IMAGE_SIZE - new_w) // 2
    rows = slice(top, top + new_h)
    cols = slice(left, left + new_w)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    canvas[rows, cols] = np.asarray(img, dtype=np.uint8)

    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[rows, cols] = False

    # Scale to [0, 1], normalize per channel, then HWC -> CHW.
    x = canvas.astype(np.float32) / 255.0
    x = (x - _MEAN) / _STD
    x = x.transpose(2, 0, 1)
    return x[np.newaxis], mask[np.newaxis], was_composited
|
|
|
|
# ---- Import-time startup: fetch assets, load the vocabulary, build the session.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)

print("Building ORT CPU session...")
SESSION = _make_session(ONNX_PATH)
INPUT_NAMES = {inp.name for inp in SESSION.get_inputs()}
# Feed the image under "pixel_values" when the graph declares that input;
# otherwise fall back to whatever the graph's first input is called.
if "pixel_values" in INPUT_NAMES:
    PRIMARY_INPUT = "pixel_values"
else:
    PRIMARY_INPUT = SESSION.get_inputs()[0].name
# The padding mask is only supplied when the graph declares an input for it.
HAS_MASK_INPUT = "padding_mask" in INPUT_NAMES
print("Providers:", SESSION.get_providers())
print("Inputs:", list(INPUT_NAMES))
|
|
|
|
def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag one image and format the results for the UI.

    Args:
        image: Uploaded PIL image, or None when nothing was provided.
        threshold: Minimum score a tag needs in order to be reported.
        top_k: Maximum number of tags to report (values below 1 act as 1).

    Returns:
        ``(label_dict, text_summary, info)``: a tag -> score dict in
        descending score order, the same tags as a comma-separated string
        (or a placeholder message), and a one-line status string.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""

    # The reported time covers preprocessing plus the model run.
    started = time.perf_counter()
    pixels, pad_mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {PRIMARY_INPUT: pixels}
    if HAS_MASK_INPUT:
        feed["padding_mask"] = pad_mask
    probs = SESSION.run(None, feed)[0][0]
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    cap = max(int(top_k), 1)
    special = (PAD_IDX, UNK_IDX)
    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0

    # Walk tags from highest to lowest score.
    for idx in map(int, np.argsort(probs)[::-1]):
        if idx in special:
            continue
        score = float(probs[idx])
        if score < threshold:
            # Scores only go down from here, so nothing further can qualify.
            break
        name = INDEX_TO_TAG.get(idx)
        if name is None:
            continue
        # NOTE(review): presumably dropped because flattening transparency
        # onto the gray PAD_COLOR could trigger this tag spuriously — confirm.
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:") and score > rating_score:
            rating, rating_score = name.split(":", 1)[1], score
        label_dict[name] = score
        if len(label_dict) >= cap:
            break

    text_summary = ", ".join(label_dict) if label_dict else "No tags above threshold."

    info_lines = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_lines.append(f"Predicted rating: {rating} (score {rating_score:.3f})")

    return label_dict, text_summary, " • ".join(info_lines)
|
|
|
|
# ---- UI layout: inputs on the left, results on the right, notes above and below.
with gr.Blocks(title="OppaiOracle V1.1 (CPU) — anime tagger") as demo:
    # Intro text shown above the controls.
    gr.Markdown(
        """# OppaiOracle V1.1 (CPU) — anime tagger
Multi-label ViT trained from scratch at 448×448 on a cleaned ~5.9M-image corpus
(19,294 tags). Drop in an image to see ranked tag predictions.

**This Space runs on CPU.** Each prediction takes ~10–30 s on the HF CPU tier
(the 448² ViT is ~250M params). For faster turnaround, see the GPU Space at
[Grio43/OppaiOracle](https://huggingface.co/spaces/Grio43/OppaiOracle).

**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )

    with gr.Row():
        # Left column: image input and the two tuning sliders.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            threshold = gr.Slider(
                label="Threshold",
                info=f"Tags below this score are dropped. V1.1 macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
                minimum=0.05,
                maximum=0.95,
                step=0.01,
                value=DEFAULT_THRESHOLD,
            )
            top_k = gr.Slider(label="Max tags", minimum=5, maximum=200, step=1, value=50)
            run_btn = gr.Button("Tag", variant="primary")

        # Right column: ranked labels, copy-friendly tag string, status line.
        with gr.Column(scale=1):
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # Tagging runs only when the button is pressed.
    run_btn.click(
        fn=predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )

    # Footer with model facts and a note on input resolution.
    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).

This Space runs the V1.1 (448×448) checkpoint on CPU. Match input resolution to the
checkpoint you load — feeding 320 to V1.1 (or 448 to V1) hurts accuracy because the ViT
position-embedding grid is fixed at load time.
"""
    )
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 1, then
    # start the app.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
|
|