# Provenance (HuggingFace file view): uploaded by Grio43, commit 8741801 (verified),
# "web_interface: fix run.bat fallback + UTF-8 console handling", file size 25 kB.
"""
Oppai ONNX tagger — single-file launcher.
First run: creates `.venv/`, installs requirements, re-execs inside the venv,
then starts a local Gradio web UI on http://127.0.0.1:7860 .
Subsequent runs: skip install (marker file) and start the UI immediately.
For most users:
Just run `run.bat` (or `py app.py`). The launcher will show a numbered
menu so you can pick an existing model or download one — no flags
required. Press Enter to accept the highlighted default at any prompt.
Advanced flags:
py app.py --reinstall # force re-install of requirements
py app.py --model-dir <folder> # skip the menu, load a specific folder
Models live in folders next to this script. Any folder containing
`model.onnx`, `selected_tags.csv`, and `preprocessing.json` is treated as a
model and will appear in the launcher menu and the UI's model picker. You
can also download variants from https://huggingface.co/Grio43/OppaiOracle
directly from the menu or from the UI.
"""
from __future__ import annotations
import os
import subprocess
import sys
import venv
from pathlib import Path
# --- Filesystem layout: everything lives next to this script ----------------
ROOT = Path(__file__).resolve().parent
VENV_DIR = ROOT.joinpath(".venv")
# Written by the bootstrap step after the requirements install succeeds; while
# it exists, later runs skip the install.
MARKER = VENV_DIR.joinpath(".bootstrapped")

# Preferred model folder.
# NOTE(review): nothing in the visible source reads this constant; the
# preference for V1.1_onnx appears to come from the ordering of HF_VARIANTS.
DEFAULT_MODEL_DIR = ROOT.joinpath("V1.1_onnx")

# --- HuggingFace variants usable with this ONNX runtime ---------------------
# Order matters: the first entry is offered as the recommended default.
HF_REPO_ID = "Grio43/OppaiOracle"
HF_VARIANTS = [
    "V1.1_onnx",
    "V1_onnx",
]
HF_VARIANT_DESC = {
    "V1.1_onnx": "448×448, higher accuracy",
    "V1_onnx": "320×320, smaller and faster",
}

# --- Packages pip-installed into the private venv on first run --------------
REQUIREMENTS = [
    "onnxruntime>=1.20",
    "pillow>=10.0",
    "numpy>=1.26,<3",
    "gradio>=4.44",
    "huggingface_hub>=0.24",
]
# ---------------------------------------------------------------------------
# Bootstrap
# ---------------------------------------------------------------------------
def _venv_python() -> Path:
    """Return the path of the Python interpreter inside the private venv.

    Windows venvs keep it at ``Scripts/python.exe``; every other platform
    uses ``bin/python``. The path is computed only, not checked for existence.
    """
    subdir, exe_name = ("Scripts", "python.exe") if os.name == "nt" else ("bin", "python")
    return VENV_DIR / subdir / exe_name
def _in_target_venv() -> bool:
    """Report whether the running interpreter belongs to the private venv.

    Two independent signals are checked and either one is sufficient:
    ``sys.executable`` against the venv's python, and ``sys.prefix`` against
    the venv directory. Windows Store Python uses reparse points that can make
    the resolved executable path differ from the venv's python.exe even when
    running inside it, which is why ``sys.prefix`` is consulted as well.
    Any ``OSError`` while resolving a path simply disqualifies that signal.
    """

    def _resolved(raw) -> Path | None:
        # Path resolution can fail on odd filesystems; treat that as "unknown".
        try:
            return Path(raw).resolve()
        except OSError:
            return None

    signals = (
        (_venv_python(), sys.executable),
        (VENV_DIR, sys.prefix),
    )
    for expected_raw, actual_raw in signals:
        expected = _resolved(expected_raw)
        if expected is None:
            continue
        actual = _resolved(actual_raw)
        if actual is not None and actual == expected:
            return True
    return False
def _bootstrap(force_reinstall: bool) -> None:
    """Prepare the private venv, then re-run this script inside it.

    Steps:
      1. Create the venv if its interpreter is missing.
      2. Install ``REQUIREMENTS`` when forced, when the venv was just
         (re)created, or when the marker file is absent.
      3. Re-launch this script with the venv's Python, forwarding every CLI
         argument except ``--reinstall``.

    Args:
        force_reinstall: Install the requirements even if the marker exists.

    Raises:
        subprocess.CalledProcessError: If a pip invocation fails.

    This function does not return normally: it ends with ``sys.exit`` using
    the child process's exit code.
    """
    py = _venv_python()

    # Key on the interpreter rather than the directory: an interrupted first
    # run can leave `.venv/` behind without a usable python inside it.
    created = not py.exists()
    if created:
        print(f"[bootstrap] Creating virtualenv at {VENV_DIR} ...")
        venv.EnvBuilder(with_pip=True, clear=False, upgrade_deps=False).create(VENV_DIR)

    needs_install = force_reinstall or created or not MARKER.exists()
    if needs_install:
        # Drop any stale marker first so a failed (re)install is retried on
        # the next run instead of being reported as "already installed".
        MARKER.unlink(missing_ok=True)
        print("[bootstrap] Upgrading pip ...")
        subprocess.check_call([str(py), "-m", "pip", "install", "--upgrade", "pip"])
        print(f"[bootstrap] Installing: {', '.join(REQUIREMENTS)}")
        subprocess.check_call([str(py), "-m", "pip", "install", *REQUIREMENTS])
        MARKER.write_text("ok\n", encoding="utf-8")
    else:
        print("[bootstrap] Requirements already installed (delete .venv/.bootstrapped to redo).")

    # `--reinstall` has been handled here; do not forward it to the child.
    args = [a for a in sys.argv[1:] if a != "--reinstall"]
    print("[bootstrap] Re-launching inside venv ...\n")
    sys.exit(subprocess.call([str(py), str(Path(__file__).resolve()), *args]))
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# Files that must all be present in a folder for it to count as a loadable model.
REQUIRED_FILES = ("model.onnx", "selected_tags.csv", "preprocessing.json")
def _discover_model_dirs() -> list[Path]:
    """List the model folders that sit directly under ROOT.

    A folder qualifies when it contains every entry of ``REQUIRED_FILES``.
    Results are ordered case-insensitively by folder name; an empty list is
    returned when ROOT itself does not exist.
    """
    if not ROOT.exists():
        return []
    entries = sorted(ROOT.iterdir(), key=lambda entry: entry.name.lower())
    return [
        entry
        for entry in entries
        if entry.is_dir() and all((entry / required).exists() for required in REQUIRED_FILES)
    ]
def _variant_rank(name: str) -> int:
    """Return the sort rank of a folder name.

    Known variants rank by their position in ``HF_VARIANTS``; any other name
    ranks after all of them.
    """
    if name in HF_VARIANTS:
        return HF_VARIANTS.index(name)
    return len(HF_VARIANTS)
def _is_tty() -> bool:
try:
return sys.stdin.isatty() and sys.stdout.isatty()
except (AttributeError, OSError):
return False
def _prompt_choice(prompt: str, options: list[tuple[str, str]], default_idx: int = 0) -> str | None:
"""Show a numbered terminal menu. Returns the chosen option's value, or None on EOF.
options: list of (display_text, value).
"""
if not options:
return None
if not _is_tty():
return options[default_idx][1]
print()
print(prompt)
for i, (display, _) in enumerate(options, 1):
marker = " <- press Enter for this" if i - 1 == default_idx else ""
print(f" {i}) {display}{marker}")
while True:
try:
raw = input(f"Choice [1-{len(options)}, default {default_idx + 1}]: ").strip()
except EOFError:
return options[default_idx][1]
if not raw:
return options[default_idx][1]
try:
idx = int(raw) - 1
except ValueError:
print(f" Please enter a number 1-{len(options)}.")
continue
if 0 <= idx < len(options):
return options[idx][1]
print(f" Out of range. Pick 1-{len(options)}.")
def _download_variant(variant: str) -> Path | None:
    """Fetch one variant folder from the HuggingFace repo into ``ROOT/<variant>``.

    Only files matching ``<variant>/*`` are requested. Every failure mode
    (library not installed, download error, required files still missing
    afterwards) is reported on stdout and yields ``None``; on success the
    variant's folder is returned.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("[app] huggingface_hub is not installed — re-run with --reinstall.")
        return None

    print(f"[app] Downloading '{variant}' from huggingface.co/{HF_REPO_ID} ...")
    try:
        snapshot_download(
            repo_id=HF_REPO_ID,
            allow_patterns=[f"{variant}/*"],
            local_dir=str(ROOT),
        )
    except Exception as exc:  # noqa: BLE001
        print(f"[app] Download failed: {exc}")
        return None

    # Confirm the download actually produced a loadable model folder.
    dest = ROOT / variant
    absent = [required for required in REQUIRED_FILES if not (dest / required).exists()]
    if not absent:
        return dest
    print(f"[app] Download finished but {dest} is missing: {', '.join(absent)}")
    return None
def _interactive_pick_model() -> Path | None:
    """Let the user choose a model from a terminal menu.

    The menu lists every discovered model folder (known variants first), then
    each HuggingFace variant that is not present locally as a download entry,
    and finally an entry to open the web UI with nothing loaded.

    Without a terminal no menu is shown: the first discovered folder is
    returned, or ``None`` if there is none.

    Returns:
        The chosen (or freshly downloaded) model directory, or ``None`` when
        the user skips or the download fails.
    """
    found = sorted(
        _discover_model_dirs(),
        key=lambda folder: (_variant_rank(folder.name), folder.name.lower()),
    )

    if not _is_tty():
        return found[0] if found else None

    # Menu entries, plus a lookup from each entry's value to what it does.
    menu: list[tuple[str, str]] = []
    dispatch: dict[str, tuple[str, str]] = {}

    for folder in found:
        blurb = HF_VARIANT_DESC.get(folder.name, "model folder")
        key = str(folder)
        menu.append((f"Use {folder.name} ({blurb})", key))
        dispatch.setdefault(key, ("load", key))

    local_names = {folder.name for folder in found}
    for variant in HF_VARIANTS:
        if variant in local_names:
            continue
        blurb = HF_VARIANT_DESC.get(variant, "")
        tail = f" ({blurb})" if blurb else ""
        menu.append((f"Download {variant} from HuggingFace{tail}", variant))
        dispatch.setdefault(variant, ("download", variant))

    menu.append(("Open the web UI without loading anything (pick later from the page)", "skip"))
    dispatch.setdefault("skip", ("skip", ""))

    print()
    print("=" * 50)
    print(" Oppai ONNX Tagger")
    print("=" * 50)
    if found:
        print(f"Found {len(found)} model folder(s) next to app.py.")
    else:
        print("No model folders found yet next to app.py.")
        print(f"Pick a variant to download from huggingface.co/{HF_REPO_ID}.")

    selection = _prompt_choice("What would you like to do?", menu, default_idx=0)
    if selection is None:
        return None

    action, payload = dispatch[selection]
    if action == "download":
        return _download_variant(payload)
    if action == "load":
        return Path(payload)
    return None  # skip
def _resolve_initial_model(cli_dir: str | None) -> Path | None:
    """Decide which model folder, if any, to load at startup.

    With ``cli_dir`` given, that folder is validated (it must be a directory
    containing every required file) and returned, or ``None`` after printing
    the reason it was rejected. Without it, the interactive picker decides.
    """
    if not cli_dir:
        return _interactive_pick_model()

    folder = Path(cli_dir).expanduser().resolve()
    if not folder.is_dir():
        print(f"[app] --model-dir not a directory: {folder}")
        return None

    absent = [required for required in REQUIRED_FILES if not (folder / required).exists()]
    if absent:
        print(f"[app] --model-dir is missing required files: {', '.join(absent)}")
        return None
    return folder
def _run_app() -> None:
    """Load the initial model (if any), build the Gradio UI and serve it.

    The third-party imports live inside this function: ``main`` only calls it
    once the script is running inside the venv that the bootstrap step
    populated, so the launcher itself can start without them.

    The only CLI flag read here is ``--model-dir``; unknown flags are ignored.
    The server binds to 127.0.0.1:7860 and opens a browser tab.
    """
    import argparse
    import csv
    import json

    import numpy as np
    import onnxruntime as ort
    import gradio as gr
    from PIL import Image

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--model-dir", type=str, default=None)
    cli_args, _ = parser.parse_known_args()

    # Category ids as they appear in selected_tags.csv, and the reverse lookup
    # used to translate the UI's checkbox labels back into ids.
    cat_names = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "meta"}
    inv_cat_names = {v: k for k, v in cat_names.items()}

    # Mutable holder so the UI can swap models without restarting the process.
    state: dict = {
        "session": None,
        "tag_names": [],
        "categories": [],
        "skip_mask": None,
        "image_size": 0,
        "pad_color": (0, 0, 0),
        "mean": None,
        "std": None,
        "breakeven_threshold": None,
        "model_dir": None,
        "providers": [],
    }

    def _ort_providers() -> list[str]:
        """Pick execution providers: DirectML, then CUDA, else CPU only."""
        available = ort.get_available_providers()
        if "DmlExecutionProvider" in available:
            return ["DmlExecutionProvider", "CPUExecutionProvider"]
        if "CUDAExecutionProvider" in available:
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]

    def load_model(model_dir: Path) -> str:
        """Load a model folder into ``state`` and return the status markdown.

        Raises:
            FileNotFoundError: ``model_dir`` is not a directory or lacks one
                of the required files.

        ``state`` is only updated after the ONNX session was created, so a
        failed load leaves the previously loaded model in place.
        """
        model_dir = Path(model_dir).expanduser().resolve()
        if not model_dir.is_dir():
            raise FileNotFoundError(f"not a directory: {model_dir}")
        missing = [f for f in REQUIRED_FILES if not (model_dir / f).exists()]
        if missing:
            raise FileNotFoundError(
                f"{model_dir} is missing required files: {', '.join(missing)}"
            )

        tag_names: list[str] = []
        categories: list[int] = []
        # newline="" is what the csv module asks for so that quoted fields
        # containing line breaks are parsed correctly.
        with (model_dir / "selected_tags.csv").open(encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                tag_names.append(row["name"])
                categories.append(int(row["category"]))

        # Placeholder tags that must never be shown to the user.
        skip_mask = np.array(
            [name in ("<PAD>", "<UNK>") for name in tag_names], dtype=bool
        )

        with (model_dir / "preprocessing.json").open(encoding="utf-8") as f:
            preproc = json.load(f)
        image_size = int(preproc["image_size"])
        pad_color = tuple(int(c) for c in preproc["pad_color_rgb"])
        # Reshaped to (3, 1, 1) so they broadcast over a CHW image.
        mean = np.array(preproc["normalize_mean"], dtype=np.float32).reshape(3, 1, 1)
        std = np.array(preproc["normalize_std"], dtype=np.float32).reshape(3, 1, 1)

        # Calibrated breakeven (precision = recall) lives in pr_thresholds.json.
        # It is tuned for whole-eval-set precision and is far too strict for
        # interactive single-image tagging, so we surface it only as a hint.
        breakeven_threshold = None
        thr_path = model_dir / "pr_thresholds.json"
        if thr_path.exists():
            try:
                with thr_path.open(encoding="utf-8") as f:
                    thr_data = json.load(f)
                breakeven_threshold = float(thr_data["micro"]["pr_breakeven"]["threshold"])
            except (OSError, KeyError, TypeError, ValueError, json.JSONDecodeError):
                # The file is optional; an unexpected shape (TypeError covers a
                # non-dict level or a null threshold) must not block the load.
                pass

        providers = _ort_providers()
        sess_opts = ort.SessionOptions()
        sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        print(f"[app] Loading {model_dir / 'model.onnx'} ({image_size}×{image_size}) ...")
        print(f"[app] Providers: {providers}")
        session = ort.InferenceSession(
            str(model_dir / "model.onnx"), sess_options=sess_opts, providers=providers
        )

        state.update(
            session=session,
            tag_names=tag_names,
            categories=categories,
            skip_mask=skip_mask,
            image_size=image_size,
            pad_color=pad_color,
            mean=mean,
            std=std,
            breakeven_threshold=breakeven_threshold,
            model_dir=model_dir,
            providers=providers,
        )
        return _status_md()

    def _status_md() -> str:
        """Return a one-line markdown summary of the currently loaded model."""
        if state["session"] is None:
            return (
                "**No model loaded.** Drop an ONNX model folder next to "
                "`app.py`, or use the **Download from HuggingFace** section below."
            )
        # Show the path relative to ROOT when the model lives under it.
        try:
            display = state["model_dir"].relative_to(ROOT)
        except ValueError:
            display = state["model_dir"]
        parts = [
            f"**Loaded:** `{display}`",
            f"{state['image_size']}×{state['image_size']}",
            f"{len(state['tag_names'])} tags",
            f"providers: {', '.join(state['providers'])}",
        ]
        if state["breakeven_threshold"] is not None:
            parts.append(f"P=R breakeven: {state['breakeven_threshold']:.3f}")
        return " — ".join(parts)

    def _dropdown_choices() -> list[tuple[str, str]]:
        """Return ``(label, path)`` pairs for every discovered model folder."""
        out = []
        for p in _discover_model_dirs():
            try:
                label = str(p.relative_to(ROOT))
            except ValueError:
                label = p.name
            out.append((label, str(p)))
        return out

    def _current_value() -> str | None:
        """Return the loaded model's path as a dropdown value, or None."""
        return str(state["model_dir"]) if state["model_dir"] else None

    # Initial load: --model-dir when given; otherwise the launcher menu (or,
    # without a terminal, the first discovered folder).
    initial = _resolve_initial_model(cli_args.model_dir)
    if initial is not None:
        try:
            load_model(initial)
        except Exception as e:  # noqa: BLE001
            print(f"[app] Initial model load failed: {e!r}")
    else:
        print("[app] No model folder found yet — pick or download one in the UI.")

    def letterbox(img: Image.Image):
        """Scale ``img`` to fit a square canvas, centred on the pad colour.

        Returns the canvas and a ``(size, size)`` boolean mask that is True
        on padded pixels and False where the image was pasted.
        """
        img = img.convert("RGB")
        w, h = img.size
        size = state["image_size"]
        scale = min(size / w, size / h)
        # Clamp to 1 px so extreme aspect ratios never round down to zero.
        nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
        resized = img.resize((nw, nh), Image.BICUBIC)
        canvas = Image.new("RGB", (size, size), state["pad_color"])
        x0 = (size - nw) // 2
        y0 = (size - nh) // 2
        canvas.paste(resized, (x0, y0))
        mask = np.ones((size, size), dtype=bool)  # True = padded
        mask[y0:y0 + nh, x0:x0 + nw] = False
        return canvas, mask

    def preprocess(img: Image.Image):
        """Return the normalised float32 CHW array and the padding mask."""
        canvas, mask = letterbox(img)
        arr = np.asarray(canvas, dtype=np.float32) / 255.0
        arr = arr.transpose(2, 0, 1)  # CHW
        arr = (arr - state["mean"]) / state["std"]
        return arr.astype(np.float32), mask

    def predict(image, threshold: float, max_tags, category_filter):
        """Tag one image.

        Returns ``(comma_separated_tags, markdown_table)``. Problems are
        reported through the second element instead of raising, so the UI
        shows a message rather than an error toast.
        """
        if state["session"] is None:
            return "", "*no model loaded — pick or download one above*"
        if image is None:
            return "", "*upload an image to start*"
        try:
            max_tags_i = int(max_tags) if max_tags is not None else 0
            if max_tags_i <= 0:
                return "", "*no tags above threshold*"
            # An empty list means "no categories selected" -> show nothing.
            # `None` (event before component initialized) means "no filter".
            if category_filter is None:
                keep_cats = None
            else:
                keep_cats = {inv_cat_names[c] for c in category_filter if c in inv_cat_names}
                if not keep_cats:
                    return "", "*no tags above threshold*"

            pixel_values, padding_mask = preprocess(image)
            outputs = state["session"].run(
                ["probabilities"],
                {
                    "pixel_values": pixel_values[None, ...],
                    "padding_mask": padding_mask[None, ...],
                },
            )
            probs = outputs[0][0].astype(np.float32)
            probs[state["skip_mask"]] = -1.0  # never surface PAD/UNK
            order = np.argsort(-probs)

            results = []
            tag_names = state["tag_names"]
            categories = state["categories"]
            # `order` is descending, so the first score under the threshold
            # ends the scan.
            for idx in order:
                p = float(probs[idx])
                if p < threshold:
                    break
                cat = categories[idx]
                if keep_cats is not None and cat not in keep_cats:
                    continue
                results.append((tag_names[idx], p, cat))
                if len(results) >= max_tags_i:
                    break

            if not results:
                return "", "*no tags above threshold*"

            comma = ", ".join(name.replace("_", " ") for name, _, _ in results)
            lines = ["| # | Tag | Confidence | Category |", "|---|---|---|---|"]
            for i, (name, p, cat) in enumerate(results, 1):
                lines.append(f"| {i} | `{name}` | {p:.3f} | {cat_names.get(cat, str(cat))} |")
            return comma, "\n".join(lines)
        except Exception as e:  # noqa: BLE001 — keep Gradio toast away
            print(f"[app] predict() error: {e!r}")
            return "", f"*error during inference: {e}*"

    # --- UI callbacks ------------------------------------------------------
    def on_refresh():
        """Re-scan for model folders; returns (dropdown update, status)."""
        choices = _dropdown_choices()
        return gr.update(choices=choices, value=_current_value()), _status_md()

    def on_load(dropdown_value: str | None, custom_path: str):
        """Load the pasted path if given, else the dropdown selection.

        Returns (dropdown update, status markdown, action message).
        """
        target = (custom_path or "").strip() or dropdown_value
        if not target:
            return gr.update(), _status_md(), "Pick a model folder or paste a path first."
        try:
            load_model(Path(target))
        except Exception as e:  # noqa: BLE001
            return gr.update(), _status_md(), f"Load failed: {e}"
        choices = _dropdown_choices()
        return (
            gr.update(choices=choices, value=_current_value()),
            _status_md(),
            f"Loaded `{Path(target).name}`.",
        )

    def on_download(variant: str, progress=gr.Progress(track_tqdm=True)):
        """Download a variant into ROOT and load it when it is complete.

        Returns (dropdown update, status markdown, action message).
        """
        if not variant:
            return gr.update(), _status_md(), "Pick a variant first."
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            return (
                gr.update(),
                _status_md(),
                "huggingface_hub is not installed — re-run `app.py --reinstall`.",
            )
        progress(0, desc=f"Downloading {variant} from {HF_REPO_ID} ...")
        try:
            snapshot_download(
                repo_id=HF_REPO_ID,
                allow_patterns=[f"{variant}/*"],
                local_dir=str(ROOT),
            )
        except Exception as e:  # noqa: BLE001
            return gr.update(), _status_md(), f"Download failed: {e}"
        target = ROOT / variant
        msg = f"Downloaded `{variant}`."
        if all((target / f).exists() for f in REQUIRED_FILES):
            try:
                load_model(target)
                msg += f" Loaded `{variant}`."
            except Exception as e:  # noqa: BLE001
                msg += f" Load failed: {e}"
        choices = _dropdown_choices()
        return gr.update(choices=choices, value=_current_value()), _status_md(), msg

    # --- UI layout ---------------------------------------------------------
    # NOTE(review): the original indentation of this layout was not available;
    # the nesting below (status/message lines inside the "Model" accordion) is
    # a reconstruction — confirm against the intended page layout.
    with gr.Blocks(title="Oppai ONNX Tagger") as demo:
        gr.Markdown(
            "# Oppai ONNX Tagger\n"
            "Upload an image and tweak the threshold / max tags. "
            "Pick a model below or download one from "
            "[Grio43/OppaiOracle](https://huggingface.co/Grio43/OppaiOracle)."
        )
        with gr.Accordion("Model", open=True):
            with gr.Row():
                model_dd = gr.Dropdown(
                    choices=_dropdown_choices(),
                    value=_current_value(),
                    label="Detected model folders (next to app.py)",
                    interactive=True,
                    # The loaded model may come from --model-dir or the custom
                    # path box and then is not one of the discovered choices;
                    # without this the dropdown rejects its own value on the
                    # next event that reads it.
                    allow_custom_value=True,
                    scale=3,
                )
                refresh_btn = gr.Button("Refresh", scale=1)
            with gr.Row():
                custom_path = gr.Textbox(
                    label="…or paste a custom model folder path (overrides dropdown)",
                    placeholder=r"e.g. C:\models\my_onnx_folder",
                    scale=4,
                )
                load_btn = gr.Button("Load", variant="primary", scale=1)
            with gr.Row():
                hf_dd = gr.Dropdown(
                    choices=HF_VARIANTS,
                    value=HF_VARIANTS[0],
                    label=f"Download a variant from {HF_REPO_ID}",
                    scale=3,
                )
                download_btn = gr.Button("Download", scale=1)
            status_md = gr.Markdown(_status_md())
            action_msg = gr.Markdown("")

        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Image(type="pil", label="Image", height=448)
                threshold = gr.Slider(
                    0.0, 1.0,
                    value=0.35,
                    step=0.005,
                    label="Threshold (interactive default 0.35; calibrated breakeven shown above)",
                )
                max_tags = gr.Slider(1, 200, value=50, step=1, label="Max tags")
                cats = gr.CheckboxGroup(
                    choices=list(cat_names.values()),
                    value=list(cat_names.values()),
                    label="Categories to include",
                )
                btn = gr.Button("Tag image", variant="primary")
            with gr.Column(scale=1):
                tags_out = gr.Textbox(
                    label="Tags (comma-separated, underscores → spaces)",
                    lines=5,
                )
                table_out = gr.Markdown(label="Per-tag detail")

        refresh_btn.click(on_refresh, outputs=[model_dd, status_md])
        load_btn.click(on_load, inputs=[model_dd, custom_path], outputs=[model_dd, status_md, action_msg])
        download_btn.click(on_download, inputs=[hf_dd], outputs=[model_dd, status_md, action_msg])

        # Every control that affects the result re-runs the prediction.
        ev_inputs = [inp, threshold, max_tags, cats]
        ev_outputs = [tags_out, table_out]
        btn.click(predict, ev_inputs, ev_outputs)
        inp.change(predict, ev_inputs, ev_outputs)
        threshold.release(predict, ev_inputs, ev_outputs)
        max_tags.release(predict, ev_inputs, ev_outputs)
        cats.change(predict, ev_inputs, ev_outputs)

    # CPU inference is ~1-3s per image; cap concurrency so spammed slider
    # changes queue serially instead of fighting for the same model session.
    demo.queue(default_concurrency_limit=1).launch(
        server_name="127.0.0.1", server_port=7860, inbrowser=True
    )
# ---------------------------------------------------------------------------
# Entrypoint
# ---------------------------------------------------------------------------
def main() -> None:
    """Entry point: bootstrap the venv when needed, otherwise run the app.

    The bootstrap path is taken when the script is not running inside the
    private venv, and also whenever ``--reinstall`` is passed. The second case
    matters because the app's own messages tell users to re-run with
    ``--reinstall``; if they do so from inside the venv the flag must still
    trigger a reinstall. ``_bootstrap`` drops the flag before re-launching,
    so the relaunched process goes straight to the app.
    """
    # On Windows, the console codepage is often cp1252/cp932/etc., not UTF-8.
    # Our messages contain em-dashes and × — `errors="replace"` keeps them from
    # crashing the bootstrap with UnicodeEncodeError on those consoles.
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.reconfigure(errors="replace")
        except (AttributeError, OSError):
            # Stream is missing or not reconfigurable; keep it as it is.
            pass

    force = "--reinstall" in sys.argv[1:]
    if force or not _in_target_venv():
        _bootstrap(force_reinstall=force)
        return  # _bootstrap re-execs and exits
    _run_app()


if __name__ == "__main__":
    main()