"""OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger. The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle` and downloaded on first launch via `hf_hub_download`. Preprocessing (letterbox + normalize) matches the training pipeline. """ from __future__ import annotations import json import os import time from pathlib import Path import gradio as gr import numpy as np import onnxruntime as ort from huggingface_hub import hf_hub_download from PIL import Image, ImageOps # `spaces` is provided by the HF Spaces runtime. Locally it may not be # installed; fall back to a no-op decorator so the app still imports. try: import spaces # type: ignore _IS_HF_SPACES = True except ImportError: # pragma: no cover - local dev path _IS_HF_SPACES = False class _SpacesShim: @staticmethod def GPU(*dargs, **dkwargs): # noqa: N802 - matches HF API def _decorator(fn): return fn # Support both @spaces.GPU and @spaces.GPU(duration=...) if dargs and callable(dargs[0]) and not dkwargs: return dargs[0] return _decorator spaces = _SpacesShim() # type: ignore MODEL_REPO = "Grio43/OppaiOracle" ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx" VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json" PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json" IMAGE_SIZE = 448 PAD_COLOR = (114, 114, 114) _MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3) _STD = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 1, 3) # PR_BREAKEVEN is the macro single-threshold PR break-even measured on the val set # (see V1.1_onnx/pr_thresholds.json -> macro_single_threshold.support_ge_0.pr_breakeven). # DEFAULT_THRESHOLD is set a touch below the break-even so the UI surfaces a # reasonable number of tags out of the box; users can adjust the slider. 
DEFAULT_THRESHOLD = 0.65
PR_BREAKEVEN = 0.7596


def _download_assets() -> tuple[Path, Path]:
    """Fetch the ONNX model and the vocabulary from the model repo.

    The PR-thresholds JSON is fetched as well, but its path is not returned
    and nothing visible in this file reads it, so a failure there is reported
    and otherwise ignored.

    Returns:
        ``(onnx_path, vocab_path)`` as local filesystem paths.

    Raises:
        Whatever ``hf_hub_download`` raises when the model or the vocabulary
        cannot be fetched; only the thresholds download is best-effort.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as e:  # best-effort: the app works without this file
        # Report instead of silently swallowing so a broken repo layout or an
        # auth problem is visible in the Space logs.
        print(f"Optional asset {PR_PATH_IN_REPO} not downloaded: {e!r}")
    return onnx_path, vocab_path


def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Load the tag vocabulary and resolve the special-token indices.

    Args:
        vocab_path: Path to a JSON file with a top-level ``"tag_to_index"``
            mapping of tag name to integer index.

    Returns:
        ``(index_to_tag, pad_idx, unk_idx)`` where ``index_to_tag`` is the
        inverted mapping. ``pad_idx`` / ``unk_idx`` default to 0 / 1 when the
        corresponding special token is not present in the vocabulary.

    Raises:
        KeyError: If the JSON has no ``"tag_to_index"`` entry.
    """
    with open(vocab_path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    tag_to_index: dict[str, int] = obj["tag_to_index"]
    index_to_tag: dict[int, str] = {int(idx): tag for tag, idx in tag_to_index.items()}
    # The two special indices must come from two *different* keys; looking
    # both up under the same key can never tell PAD and UNK apart.
    # NOTE(review): assumes the vocabulary names them "<PAD>" and "<UNK>" —
    # confirm against vocabulary.json. The 0 / 1 fallbacks apply otherwise.
    pad_idx = int(tag_to_index.get("<PAD>", 0))
    unk_idx = int(tag_to_index.get("<UNK>", 1))
    return index_to_tag, pad_idx, unk_idx


def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime session, retrying at a lower optimization level.

    Execution providers are CUDA then CPU, restricted to those this ORT build
    reports as available. Intra-op threads are capped at 8.

    Args:
        onnx_path: Local path of the ONNX model file.

    Returns:
        The first session that constructs successfully.

    Raises:
        Exception: The error from the last attempt if every optimization
            level fails.
    """
    intra_threads = min(os.cpu_count() or 2, 8)
    available = ort.get_available_providers()
    preferred = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
    # Graph optimization at ORT_ENABLE_ALL synthesizes a MemcpyToHost node on the
    # bool padding_mask path that some ORT/CPU builds reject with NOT_IMPLEMENTED.
    # ENABLE_ALL is therefore never attempted: start at BASIC, then fall back
    # to disabled, before giving up.
    last_err: Exception | None = None
    for level in (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    ):
        so = ort.SessionOptions()
        so.graph_optimization_level = level
        so.intra_op_num_threads = intra_threads
        try:
            return ort.InferenceSession(str(onnx_path), sess_options=so, providers=preferred)
        except Exception as e:  # pragma: no cover - depends on ORT build
            last_err = e
            print(f"InferenceSession failed at opt level {level}: {e!r}")
    # Every iteration either returned or recorded an error, so reaching this
    # point means all levels failed.
    assert last_err is not None
    raise last_err


# Process-wide session cache, filled by `_get_session` on first use.
_SESSION: ort.InferenceSession | None = None
_PRIMARY_INPUT: str | None = None
_HAS_MASK_INPUT: bool = False


def _get_session() -> ort.InferenceSession:
    """Lazy-init the ORT session.

    Must be called from inside @spaces.GPU on ZeroGPU so CUDAExecutionProvider
    can see the just-attached device.

    On the first call this builds the session from ``ONNX_PATH`` and records
    the image input name (``"pixel_values"`` if the graph has it, otherwise
    the graph's first input) and whether a ``"padding_mask"`` input exists.
    Later calls return the cached session.

    Returns:
        The cached (or newly created) inference session.
    """
    global _SESSION, _PRIMARY_INPUT, _HAS_MASK_INPUT
    if _SESSION is not None:
        return _SESSION
    sess = _make_session(ONNX_PATH)
    # Keep graph order so the log line below is deterministic.
    input_names = [i.name for i in sess.get_inputs()]
    _PRIMARY_INPUT = "pixel_values" if "pixel_values" in input_names else input_names[0]
    _HAS_MASK_INPUT = "padding_mask" in input_names
    print("Providers:", sess.get_providers())
    print("Inputs:", input_names)
    # Publish the session last: anyone who observes a non-None _SESSION also
    # sees the input metadata assigned above.
    _SESSION = sess
    return sess


def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Letterbox an image to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalize it.

    Steps: apply the EXIF orientation, flatten any transparency onto
    ``PAD_COLOR``, shrink (never enlarge) to fit the square, centre on a
    ``PAD_COLOR`` canvas, scale to [0, 1] and normalize with ``_MEAN`` /
    ``_STD``.

    Args:
        img: Input image in any PIL mode.

    Returns:
        ``(pixels, mask, was_composited)``:

        * ``pixels`` — float32 array of shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)``.
        * ``mask`` — bool array of shape ``(1, IMAGE_SIZE, IMAGE_SIZE)``;
          ``True`` marks padding, ``False`` marks image content.
        * ``was_composited`` — ``True`` when transparency was flattened onto
          the pad colour.
    """
    img = ImageOps.exif_transpose(img)
    was_composited = False
    if img.mode in ("RGBA", "LA") or "transparency" in img.info:
        # Flatten alpha onto the same colour used for letterbox padding.
        was_composited = True
        bg = Image.new("RGB", img.size, PAD_COLOR)
        rgba = img.convert("RGBA")
        alpha = rgba.getchannel("A")
        bg.paste(rgba.convert("RGB"), mask=alpha)
        img = bg
    else:
        img = img.convert("RGB")

    w, h = img.size
    # The 1.0 cap means images smaller than the target are padded, not upscaled.
    scale = min(IMAGE_SIZE / w, IMAGE_SIZE / h, 1.0)
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    if (new_w, new_h) != (w, h):
        img = img.resize((new_w, new_h), Image.BILINEAR)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    top = (IMAGE_SIZE - new_h) // 2
    left = (IMAGE_SIZE - new_w) // 2
    canvas[top : top + new_h, left : left + new_w] = np.asarray(img, dtype=np.uint8)

    # Start fully masked, then clear the mask over the pasted image region.
    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[top : top + new_h, left : left + new_w] = False

    x = canvas.astype(np.float32) / 255.0
    x = (x - _MEAN) / _STD
    x = x.transpose(2, 0, 1)  # HWC -> CHW
    return np.expand_dims(x, 0), np.expand_dims(mask, 0), was_composited


# Assets are fetched at import time; the ORT session itself is created lazily.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
print("ONNX model:", ONNX_PATH)
print("Vocabulary:", VOCAB_PATH)
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)
print("ZeroGPU mode:", _IS_HF_SPACES, "(session built lazily on first call)")


# ZeroGPU: GPU is only attached to this process *during* a @spaces.GPU call.
# Build the ORT session lazily inside the decorated function so the
# CUDAExecutionProvider initializes against an attached device. The
# `duration` budget covers session build (~10s for ~1 GB ONNX) on the first
# call plus ~1s of actual inference; subsequent calls reuse the cached session.
@spaces.GPU(duration=120)
def _predict_on_gpu(image: Image.Image, threshold: float, top_k: int):
    """Run the model on ``image`` and format the results for the UI.

    This is the only part of prediction that needs the GPU, so it is the part
    wrapped in ``@spaces.GPU``. Arguments and return value are the same as
    for `predict`, except that ``image`` must not be ``None``.
    """
    session = _get_session()
    primary_input = _PRIMARY_INPUT or session.get_inputs()[0].name

    t0 = time.perf_counter()
    x, mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {primary_input: x}
    if _HAS_MASK_INPUT:
        feed["padding_mask"] = mask
    probs = session.run(None, feed)[0][0]  # sigmoid already applied inside the ONNX graph
    # The reported time covers preprocessing as well as the session run.
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    cap = max(int(top_k), 1)
    label_dict, rating, rating_score = _select_tags(probs, threshold, cap, was_composited)

    if not label_dict:
        text_summary = "No tags above threshold."
    else:
        text_summary = ", ".join(label_dict.keys())

    info_lines = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_lines.append(f"Predicted rating: {rating} (score {rating_score:.3f})")
    info = " • ".join(info_lines)
    return label_dict, text_summary, info


def _select_tags(
    probs: np.ndarray,
    threshold: float,
    cap: int,
    was_composited: bool,
) -> tuple[dict[str, float], str | None, float]:
    """Pick the highest-scoring tags at or above ``threshold``, at most ``cap``.

    Args:
        probs: 1-D array of per-tag scores, indexed by vocabulary index.
        threshold: Minimum score for a tag to be kept.
        cap: Maximum number of tags to return (callers pass a value >= 1).
        was_composited: Whether the input had transparency flattened onto the
            pad colour; if so the ``gray_background`` tag is suppressed.

    Returns:
        ``(label_dict, rating, rating_score)``. ``label_dict`` maps tag name
        to score in descending score order. ``rating`` is the suffix of the
        best-scoring selected ``rating:*`` tag, or ``None`` (with
        ``rating_score`` of ``-1.0``) when no rating tag was selected.
    """
    label_dict: dict[str, float] = {}
    rating: str | None = None
    rating_score = -1.0
    for raw_idx in np.argsort(probs)[::-1]:  # best score first
        idx = int(raw_idx)
        if idx in (PAD_IDX, UNK_IDX):
            continue
        score = float(probs[idx])
        if score < threshold:
            # Scores are descending, so nothing later can pass either.
            break
        name = INDEX_TO_TAG.get(idx)
        if name is None:
            continue
        # Presumably the gray fill used when flattening alpha would otherwise
        # trigger this tag, so it is dropped for composited inputs.
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:") and score > rating_score:
            rating = name.split(":", 1)[1]
            rating_score = score
        label_dict[name] = score
        if len(label_dict) >= cap:
            break
    return label_dict, rating, rating_score


def predict(image: Image.Image, threshold: float, top_k: int):
    """Gradio handler: tag ``image`` and return the three UI outputs.

    Args:
        image: The uploaded PIL image, or ``None`` when the input is empty
            (for example after the user clears it).
        threshold: Minimum score for a tag to be shown.
        top_k: Maximum number of tags to show; values below 1 are treated as 1.

    Returns:
        ``(label_dict, text_summary, info)`` — a tag-to-score mapping for the
        label widget, the comma-separated tag names, and a one-line status
        string.
    """
    # Answer the empty-input case here, outside the @spaces.GPU function, so
    # that clearing the image does not request a GPU slot just to show a hint.
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""
    return _predict_on_gpu(image, threshold, top_k)


with gr.Blocks(title="OppaiOracle — anime tagger") as demo:
    # Markdown bodies start at column 0: the heading and the rule need their
    # own lines, and indented text would render as a code block.
    gr.Markdown(
        """# OppaiOracle V1.1 — anime tagger

Multi-label ViT fine-tuned at 448×448 (a fine-tune of the from-scratch V1
320×320 model, on the same cleaned ~5.9M-image corpus, 19,294 tags).

Drop in an image to see ranked tag predictions.

**Read first:** the model card on the
[model repo](https://huggingface.co/Grio43/OppaiOracle) documents known noise
patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not
as ground truth.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard"])
            threshold = gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=DEFAULT_THRESHOLD,
                step=0.01,
                label="Threshold",
                info=f"Tags below this score are dropped. Macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=200,
                value=50,
                step=1,
                label="Max tags",
            )
            run_btn = gr.Button("Tag", variant="primary")
        with gr.Column(scale=1):
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()

    # Run on explicit click and whenever the image changes (including clears,
    # which `predict` answers without touching the GPU).
    run_btn.click(
        predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )
    image_in.change(
        predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )

    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) ·
Resolution: 448×448 · Tags: 19,294 · Activation: sigmoid (already applied
inside the ONNX graph).

This Space runs the V1.1 (448×448) checkpoint. The V1 (320×320) checkpoint is
also in the model repo for users who specifically want the smaller native
resolution — match input resolution to the checkpoint you load.
"""
    )


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=2).launch()