Spaces:
Running on Zero
Running on Zero
| """OppaiOracle — Gradio Space for V1.1 (448x448) anime tagger. | |
| The ONNX model is hosted in the sibling model repo `Grio43/OppaiOracle` | |
| and downloaded on first launch via `hf_hub_download`. Preprocessing | |
| (letterbox + normalize) matches the training pipeline. | |
| """ | |
from __future__ import annotations

import json
import os
import threading
import time
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
# `spaces` is provided by the HF Spaces runtime. Locally it may not be
# installed; fall back to a no-op decorator so the app still imports.
try:
    import spaces  # type: ignore
    _IS_HF_SPACES = True
except ImportError:  # pragma: no cover - local dev path
    _IS_HF_SPACES = False

    class _SpacesShim:
        """Local stand-in for the `spaces` module exposing a pass-through `GPU`."""

        # staticmethod is required: `spaces` below is an *instance*, and a plain
        # method would receive the shim itself as dargs[0]. The bare
        # `@spaces.GPU` form would then return `_decorator` in place of the
        # decorated function.
        @staticmethod
        def GPU(*dargs, **dkwargs):  # noqa: N802 - matches HF API
            """No-op replacement for `spaces.GPU`.

            Supports both `@spaces.GPU` (called with the function itself) and
            `@spaces.GPU(duration=...)` (called with options, returns a
            decorator). In both cases the decorated function is returned
            unchanged.
            """
            def _decorator(fn):
                return fn

            # Support both @spaces.GPU and @spaces.GPU(duration=...)
            if dargs and callable(dargs[0]) and not dkwargs:
                return dargs[0]
            return _decorator

    spaces = _SpacesShim()  # type: ignore
# Hub location of the model and the files this app pulls from it.
MODEL_REPO = "Grio43/OppaiOracle"
ONNX_PATH_IN_REPO = "V1.1_onnx/model.onnx"
VOCAB_PATH_IN_REPO = "V1.1_onnx/vocabulary.json"
PR_PATH_IN_REPO = "V1.1_onnx/pr_thresholds.json"

# Letterbox geometry: images are fit into an IMAGE_SIZE x IMAGE_SIZE square and
# padded with this gray, which is also the backdrop for transparent images.
IMAGE_SIZE = 448
PAD_COLOR = (114, 114, 114)

# Per-channel normalization (x - mean) / std in HWC broadcast shape; with
# 0.5 / 0.5 it maps [0, 1] pixel values to [-1, 1].
_MEAN = np.full((1, 1, 3), 0.5, dtype=np.float32)
_STD = np.full((1, 1, 3), 0.5, dtype=np.float32)

# PR_BREAKEVEN is the macro single-threshold PR break-even measured on the val
# set (see V1.1_onnx/pr_thresholds.json ->
# macro_single_threshold.support_ge_0.pr_breakeven). DEFAULT_THRESHOLD sits
# about 0.11 below it, trading some precision for recall so the UI surfaces a
# reasonable number of tags out of the box; users can adjust the slider.
DEFAULT_THRESHOLD = 0.65
PR_BREAKEVEN = 0.7596
def _download_assets() -> tuple[Path, Path]:
    """Download the ONNX model and vocabulary from MODEL_REPO.

    Returns:
        (onnx_path, vocab_path): local paths returned by `hf_hub_download`.

    Raises:
        Whatever `hf_hub_download` raises for the model or vocabulary; the app
        cannot run without either.

    The PR-threshold file is also fetched, best-effort only: its local path is
    discarded here. A failure there is reported and then ignored instead of
    being silently swallowed.
    """
    onnx_path = Path(hf_hub_download(MODEL_REPO, ONNX_PATH_IN_REPO))
    vocab_path = Path(hf_hub_download(MODEL_REPO, VOCAB_PATH_IN_REPO))
    try:
        hf_hub_download(MODEL_REPO, PR_PATH_IN_REPO)
    except Exception as e:  # best-effort: nothing below depends on this file
        print(f"Optional download of {PR_PATH_IN_REPO} failed (continuing): {e!r}")
    return onnx_path, vocab_path
def _load_vocab(vocab_path: Path) -> tuple[dict[int, str], int, int]:
    """Read the vocabulary JSON and build an index -> tag lookup.

    The file must contain a top-level "tag_to_index" object mapping tag names
    to integer (or int-convertible) indices.

    Returns:
        (index_to_tag, pad_idx, unk_idx). The special indices are taken from
        the "<PAD>" / "<UNK>" entries, defaulting to 0 and 1 when absent.
    """
    # NOTE(review): if a vocabulary ever lacks <PAD>/<UNK>, the 0/1 defaults
    # may name real tags, which predict() would then skip. Confirm the
    # vocab always includes both.
    with open(vocab_path, "r", encoding="utf-8") as handle:
        mapping: dict[str, int] = json.load(handle)["tag_to_index"]
    lookup = {int(position): tag for tag, position in mapping.items()}
    return lookup, int(mapping.get("<PAD>", 0)), int(mapping.get("<UNK>", 1))
def _make_session(onnx_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime inference session for `onnx_path`.

    Providers are CUDA (if this ORT build offers it), then CPU. Graph
    optimization is attempted at ORT_ENABLE_BASIC and then ORT_DISABLE_ALL;
    the first level that builds is returned. Each failure is printed, and if
    every level fails the last exception is re-raised.
    """
    thread_cap = min(os.cpu_count() or 2, 8)
    available = ort.get_available_providers()
    providers = [
        name
        for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if name in available
    ]
    # ORT_ENABLE_ALL is deliberately skipped: it synthesizes a MemcpyToHost node
    # on the bool padding_mask path that some ORT/CPU builds reject with
    # NOT_IMPLEMENTED. Fall back from BASIC to no optimization at all.
    fallback_levels = (
        ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
    )
    failure: Exception | None = None
    for level in fallback_levels:
        options = ort.SessionOptions()
        options.graph_optimization_level = level
        options.intra_op_num_threads = thread_cap
        try:
            return ort.InferenceSession(str(onnx_path), sess_options=options, providers=providers)
        except Exception as e:  # pragma: no cover - depends on ORT build
            failure = e
            print(f"InferenceSession failed at opt level {level}: {e!r}")
    assert failure is not None
    raise failure
# Lazily-built ORT session plus facts about its input signature. These are
# populated by _get_session() on first use and reused by every later call.
_SESSION: ort.InferenceSession | None = None
_PRIMARY_INPUT: str | None = None
_HAS_MASK_INPUT: bool = False
# Serializes the one-time build. The app queues with default_concurrency_limit=2,
# so two requests can reach _get_session() together and would otherwise both
# load the (~1 GB) model.
_SESSION_LOCK = threading.Lock()


def _get_session() -> ort.InferenceSession:
    """Return the shared ORT session, building it on first use.

    Must be called from inside @spaces.GPU on ZeroGPU so
    CUDAExecutionProvider can see the just-attached device.

    On the first call this also sets `_PRIMARY_INPUT` (the "pixel_values"
    input if the graph has one, else its first input) and `_HAS_MASK_INPUT`
    (whether the graph accepts "padding_mask").
    """
    global _SESSION, _PRIMARY_INPUT, _HAS_MASK_INPUT
    # Fast path: no locking once the session exists.
    if _SESSION is not None:
        return _SESSION
    with _SESSION_LOCK:
        # Re-check: another request may have finished the build while we waited.
        if _SESSION is not None:
            return _SESSION
        sess = _make_session(ONNX_PATH)
        input_names = [i.name for i in sess.get_inputs()]  # graph order
        _PRIMARY_INPUT = "pixel_values" if "pixel_values" in input_names else input_names[0]
        _HAS_MASK_INPUT = "padding_mask" in input_names
        print("Providers:", sess.get_providers())
        print("Inputs:", input_names)
        # Publish the session last, so a lock-free reader that sees it also
        # sees the input metadata assigned above.
        _SESSION = sess
        return sess
def _preprocess(img: Image.Image) -> tuple[np.ndarray, np.ndarray, bool]:
    """Turn a PIL image into the model's input arrays.

    Steps: apply EXIF orientation; flatten any transparency onto PAD_COLOR;
    downscale (never upscale) to fit IMAGE_SIZE x IMAGE_SIZE, keeping the
    aspect ratio; center on a PAD_COLOR canvas (letterbox); normalize with
    _MEAN / _STD.

    Returns:
        x: float32, shape (1, 3, IMAGE_SIZE, IMAGE_SIZE), C-contiguous NCHW.
        mask: bool, shape (1, IMAGE_SIZE, IMAGE_SIZE); True on letterbox
            padding and False where image pixels were placed.
        was_composited: True if the image had transparency and was flattened
            onto the gray PAD_COLOR background.
    """
    img = ImageOps.exif_transpose(img)
    was_composited = False
    # "PA" (palette + alpha) also carries an alpha band. A plain convert("RGB")
    # would drop it and expose whatever color sits under transparent pixels.
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        was_composited = True
        bg = Image.new("RGB", img.size, PAD_COLOR)
        rgba = img.convert("RGBA")
        alpha = rgba.getchannel("A")
        bg.paste(rgba.convert("RGB"), mask=alpha)
        img = bg
    else:
        img = img.convert("RGB")

    w, h = img.size
    scale = min(IMAGE_SIZE / w, IMAGE_SIZE / h, 1.0)  # 1.0 cap: never upscale
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    if (new_w, new_h) != (w, h):
        img = img.resize((new_w, new_h), Image.BILINEAR)

    canvas = np.full((IMAGE_SIZE, IMAGE_SIZE, 3), PAD_COLOR, dtype=np.uint8)
    top = (IMAGE_SIZE - new_h) // 2
    left = (IMAGE_SIZE - new_w) // 2
    canvas[top : top + new_h, left : left + new_w] = np.asarray(img, dtype=np.uint8)
    mask = np.ones((IMAGE_SIZE, IMAGE_SIZE), dtype=bool)
    mask[top : top + new_h, left : left + new_w] = False

    x = canvas.astype(np.float32) / 255.0
    x = (x - _MEAN) / _STD
    # HWC -> CHW. transpose() returns a strided view, so make the memory
    # layout explicit before handing it to ONNX Runtime.
    x = np.ascontiguousarray(x.transpose(2, 0, 1))
    return np.expand_dims(x, 0), np.expand_dims(mask, 0), was_composited
# Import-time setup: fetch the model assets and load the vocabulary eagerly.
# The ORT session itself is deferred to the first predict() call.
print("Downloading model and vocabulary from", MODEL_REPO)
ONNX_PATH, VOCAB_PATH = _download_assets()
for _caption, _where in (("ONNX model:", ONNX_PATH), ("Vocabulary:", VOCAB_PATH)):
    print(_caption, _where)
INDEX_TO_TAG, PAD_IDX, UNK_IDX = _load_vocab(VOCAB_PATH)
print("ZeroGPU mode:", _IS_HF_SPACES, "(session built lazily on first call)")
# ZeroGPU: the GPU is only attached to this process *during* a @spaces.GPU call,
# so the ORT session is built lazily inside this decorated function, letting
# CUDAExecutionProvider initialize against an attached device. The `duration`
# budget (seconds) must cover session build (~10s for a ~1 GB ONNX) on the first
# call plus ~1s of actual inference; subsequent calls reuse the cached session.
# The parenthesized form is used so the local shim also passes the function through.
@spaces.GPU(duration=60)
def predict(image: Image.Image, threshold: float, top_k: int):
    """Tag one image for the Gradio UI.

    Args:
        image: Uploaded PIL image, or None when the input was cleared.
        threshold: Minimum score for a tag to be listed.
        top_k: Maximum number of tags to return (values below 1 count as 1).

    Returns:
        (label_dict, text_summary, info):
        - label_dict maps tag -> score, highest score first. It skips the
          PAD/UNK indices, and skips `gray_background` when the input was
          composited onto the gray pad color.
        - text_summary is those tags comma-joined, or a placeholder message.
        - info is a one-line Markdown status: timing, tag count and, if any
          `rating:*` tag made the list, the highest-scoring one.
    """
    if image is None:
        return {}, "Upload an image to get tag predictions.", ""
    session = _get_session()
    primary_input = _PRIMARY_INPUT or session.get_inputs()[0].name
    # The timed span covers preprocessing + the ORT run (not the session build).
    t0 = time.perf_counter()
    x, mask, was_composited = _preprocess(image)
    feed: dict[str, np.ndarray] = {primary_input: x}
    if _HAS_MASK_INPUT:
        feed["padding_mask"] = mask
    probs = session.run(None, feed)[0][0]  # sigmoid already applied inside the ONNX graph
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    cap = max(int(top_k), 1)
    order = np.argsort(probs)[::-1]  # highest score first
    label_dict: dict[str, float] = {}
    rating = None
    rating_score = -1.0
    for raw_idx in order:
        idx = int(raw_idx)
        if idx in (PAD_IDX, UNK_IDX):
            continue
        score = float(probs[idx])
        # Scores are visited in descending order: everything after this is lower.
        if score < threshold:
            break
        name = INDEX_TO_TAG.get(idx)
        if name is None:
            continue
        # A composited image sits on the PAD_COLOR gray backdrop, which
        # presumably triggers this tag spuriously, so drop it.
        if was_composited and name == "gray_background":
            continue
        if name.startswith("rating:"):
            if score > rating_score:
                rating = name.split(":", 1)[1]
                rating_score = score
        label_dict[name] = score
        if len(label_dict) >= cap:
            break

    if not label_dict:
        text_summary = "No tags above threshold."
    else:
        text_summary = ", ".join(label_dict.keys())
    info_lines = [f"Inference: {elapsed_ms} ms", f"Tags returned: {len(label_dict)}"]
    if rating is not None:
        info_lines.append(f"Predicted rating: {rating} (score {rating_score:.3f})")
    info = " • ".join(info_lines)
    return label_dict, text_summary, info
with gr.Blocks(title="OppaiOracle — anime tagger") as demo:
    gr.Markdown(
        """# OppaiOracle V1.1 — anime tagger
Multi-label ViT fine-tuned at 448×448 (a fine-tune of the from-scratch V1 320×320
model, on the same cleaned ~5.9M-image corpus, 19,294 tags). Drop in an image to
see ranked tag predictions.
**Read first:** the model card on the [model repo](https://huggingface.co/Grio43/OppaiOracle)
documents known noise patterns (color tags, hair-length boundaries, neckwear, missing-tag bias).
Predictions are best treated as a fast first pass that a human reviews — not as ground truth.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # image_mode=None keeps the uploaded file's own mode (e.g. RGBA).
            # With Gradio's default "RGB", the image is converted before
            # predict() runs. That drops alpha, so _preprocess could never
            # composite transparent images onto PAD_COLOR as intended.
            image_in = gr.Image(
                type="pil",
                label="Image",
                sources=["upload", "clipboard"],
                image_mode=None,
            )
            threshold = gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=DEFAULT_THRESHOLD,
                step=0.01,
                label="Threshold",
                info=f"Tags below this score are dropped. Macro PR break-even on val ≈ {PR_BREAKEVEN:.3f}.",
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=200,
                value=50,
                step=1,
                label="Max tags",
            )
            run_btn = gr.Button("Tag", variant="primary")
        with gr.Column(scale=1):
            # num_top_classes matches the "Max tags" slider's maximum (200).
            labels_out = gr.Label(label="Predictions", num_top_classes=200)
            tags_text = gr.Textbox(label="Comma-separated tags", lines=4)
            info_out = gr.Markdown()
    # Run on explicit click and whenever the image changes (including clears,
    # which predict() answers with a placeholder message).
    run_btn.click(
        predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )
    image_in.change(
        predict,
        inputs=[image_in, threshold, top_k],
        outputs=[labels_out, tags_text, info_out],
    )
    gr.Markdown(
        """---
Model: [Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle) · Resolution: 448×448 ·
Tags: 19,294 · Activation: sigmoid (already applied inside the ONNX graph).
This Space runs the V1.1 (448×448) checkpoint. The V1 (320×320) checkpoint is also in the
model repo for users who specifically want the smaller native resolution — match input
resolution to the checkpoint you load.
"""
    )
if __name__ == "__main__":
    # default_concurrency_limit=2: each event listener may process at most two
    # requests at the same time; further requests wait in the queue.
    app = demo.queue(default_concurrency_limit=2)
    app.launch()