cl_tagger_v2 / app.py
cella110n's picture
Upload 3 files
5b24cc3 verified
"""
SigLIP2 Tagger — Gradio ONNX Inference App
Runs ONNX-exported SigLIP2 tagger models with category-grouped results.
Designed for HuggingFace Spaces deployment.
"""
from __future__ import annotations
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
import spaces
from PIL import Image
# ---------------------------------------------------------------------------
# Category colours (SushiUI)
# ---------------------------------------------------------------------------
# Hex colour per tag category, used for bar fills and section headings.
CATEGORY_COLORS: Dict[str, str] = {
    "Quality": "#eab308",
    "Rating": "#f97316",
    "Character": "#3b82f6",
    "Copyright": "#a855f7",
    "General": "#22c55e",
    "Meta": "#9ca3af",
    "Model": "#06b6d4",
}

# Display order of the threshold-filtered sections. Quality and Rating are not
# listed: they are pinned (their top-1 tag is always shown).
CATEGORY_ORDER: List[str] = [
    "Character",
    "Copyright",
    "General",
    "Meta",
    "Model",
]

# In "Per Category" mode these categories are cut by the character slider ...
CHAR_THRESHOLD_CATS = {"Character", "Copyright"}
# ... and these by the general slider.
GENERAL_THRESHOLD_CATS = {"General", "Meta", "Model"}
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Hugging Face repo holding one sub-folder per exported model version.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "celstk/cl-SigLIP2-lora-onnx")
# Passed as ``cache_dir`` to snapshot_download (None when HF_HOME is unset).
CACHE_DIR = os.getenv("HF_HOME", None)

# Filled at startup by _setup_versions() from _fetch_available_versions():
# HF folder name (e.g. "v1_03") -> display label (e.g. "v1.03").
MODEL_VERSIONS: Dict[str, str] = {}

# Sentinel default; replaced by the newest discovered version at startup.
DEFAULT_VERSION: str = "v1_03"

# Recommended thresholds for versions whose probability distribution is
# unusually skewed; any version absent here (including future ones) uses 0.5.
# v1_02 was trained with CS-ASL (bimodal outputs) and needs a much higher cut.
_VERSION_DEFAULT_THRESHOLD: Dict[str, float] = dict(
    v1_00=0.6,
    v1_01=0.6,
    v1_02=0.9,
)
def _default_threshold(version: str) -> float:
    """Return the recommended starting threshold for model *version*.

    Versions with an explicit entry in ``_VERSION_DEFAULT_THRESHOLD`` (those
    whose training loss skews the output probabilities) get that value; every
    other version, including ones released later, gets 0.5, a neutral
    midpoint the user can adjust with the sliders.
    """
    try:
        return _VERSION_DEFAULT_THRESHOLD[version]
    except KeyError:
        return 0.5
# ---------------------------------------------------------------------------
# Dynamic version discovery
# ---------------------------------------------------------------------------
import re as _re
def _fetch_available_versions() -> Dict[str, str]:
    """List the model versions published in ``MODEL_REPO_ID``.

    A version is any top-level folder named ``v<digits>_<digits>``. The result
    maps folder name -> display label (first ``_`` replaced by ``.``), ordered
    from oldest to newest by the numeric parts of the name.

    Discovery is best-effort. If the Hub query raises, or finds no version
    folders, a hard-coded list of known versions is returned so the app can
    still start.
    """
    known_versions: Dict[str, str] = {
        "v1_00": "v1.00",
        "v1_01": "v1.01",
        "v1_02": "v1.02",
        "v1_03": "v1.03",
    }
    hf_token = os.environ.get("HF_TOKEN")
    try:
        from huggingface_hub import list_repo_files

        folder_pattern = _re.compile(r'^(v\d+_\d+)/')
        # Only the leading "vX_Y/" component of each repo file path matters.
        matches = (
            folder_pattern.match(path)
            for path in list_repo_files(repo_id=MODEL_REPO_ID, token=hf_token)
        )
        folders = {m.group(1) for m in matches if m}
        if not folders:
            return known_versions

        def _numeric_key(name: str):
            # "v1_03" -> (1, 3) so "v1_10" sorts after "v1_9".
            return tuple(int(part) for part in _re.findall(r'\d+', name))

        discovered = {
            name: name.replace("_", ".", 1)
            for name in sorted(folders, key=_numeric_key)
        }
        print(f"[Tagger] Discovered versions: {list(discovered.keys())}")
        return discovered
    except Exception as exc:
        print(f"[Tagger] WARNING: version discovery failed ({exc}); using fallback list")
        return known_versions
# --- Mutable state describing the active model (shared by loaders/predict) ---
_onnx_path: Optional[str] = None       # path of the active .onnx file
_session: Any = None                   # onnxruntime.InferenceSession, built lazily
_processor: Any = None                 # processor from AutoProcessor.from_pretrained
_is_naflex: bool = True                # False -> fixed-size "pixel_values" input only
_idx_to_tag: Dict[int, str] = {}       # logit index -> tag name
_tag_to_category: Dict[str, str] = {}  # tag name -> category (e.g. "General")
_loaded_version: Optional[str] = None  # HF folder name of the active model, if any
def _download_from_hub(version: str) -> Tuple[str, str, str]:
    """Download the *version* folder of ``MODEL_REPO_ID`` and locate its files.

    Only ``<version>/*`` is requested from ``snapshot_download`` (with
    ``cache_dir=CACHE_DIR`` and the optional ``HF_TOKEN``).

    File selection inside the folder:

    * ONNX: the first ``*.onnx`` file in sorted order. A warning is printed
      when there are several.
    * Vocabulary: ``<onnx_base>_vocabulary.json`` if present, otherwise the
      first ``*vocabulary.json`` in sorted order (with a warning).
    * Metadata: ``<onnx_base>_metadata.json`` if present, otherwise the first
      ``*_metadata.json`` in sorted order, otherwise ``""``.

    Returns:
        ``(onnx_path, vocab_path, meta_path)``.

    Raises:
        FileNotFoundError: if the folder has no ``.onnx`` file or no vocabulary.
        Exceptions from ``snapshot_download`` propagate unchanged.
    """
    import glob
    from huggingface_hub import snapshot_download

    token = os.environ.get("HF_TOKEN")
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        allow_patterns=[f"{version}/*"],
        cache_dir=CACHE_DIR,
        token=token,
    )
    version_dir = os.path.join(local_dir, version)

    # glob() returns entries in filesystem order; sort so the pick is
    # reproducible. The "*.onnx" pattern cannot match "*.onnx.data"
    # external-weight files, so no extra filtering is needed.
    onnx_files = sorted(glob.glob(os.path.join(version_dir, "*.onnx")))
    if not onnx_files:
        raise FileNotFoundError(f"No .onnx file found in {version_dir}")
    onnx_path = onnx_files[0]
    if len(onnx_files) > 1:
        print(f"[Tagger] WARNING: {len(onnx_files)} .onnx files in {version_dir}; "
              f"using '{os.path.basename(onnx_path)}'")
    onnx_base = os.path.splitext(os.path.basename(onnx_path))[0]

    # Prefer the per-ONNX vocabulary snapshot ``<onnx_base>_vocabulary.json``
    # (frozen at export time, immune to overwrites). Fall back to any
    # ``*vocabulary.json`` in the same directory with a warning.
    per_ckpt_vocab = os.path.join(version_dir, f"{onnx_base}_vocabulary.json")
    if os.path.isfile(per_ckpt_vocab):
        vocab_path = per_ckpt_vocab
    else:
        vocab_files = sorted(glob.glob(os.path.join(version_dir, "*vocabulary.json")))
        if not vocab_files:
            raise FileNotFoundError(f"No *vocabulary.json found in {version_dir}")
        vocab_path = vocab_files[0]
        print(f"[Spaces] WARNING: per-ONNX vocabulary "
              f"'{onnx_base}_vocabulary.json' not found; falling back to "
              f"'{os.path.basename(vocab_path)}'. Tag→idx alignment with "
              f"'{onnx_base}.onnx' cannot be verified.")

    # Same preference for metadata, so the processor repo and NaFlex mode are
    # taken from the export that matches the chosen ONNX file.
    per_ckpt_meta = os.path.join(version_dir, f"{onnx_base}_metadata.json")
    if os.path.isfile(per_ckpt_meta):
        meta_path = per_ckpt_meta
    else:
        meta_files = sorted(glob.glob(os.path.join(version_dir, "*_metadata.json")))
        meta_path = meta_files[0] if meta_files else ""
    return onnx_path, vocab_path, meta_path
def _vocab_get(vocab: dict, key: str) -> dict:
    """Look up *key* in a vocabulary dict, tolerating a path-style prefix.

    Some exported vocab files nest keys under a version prefix, e.g.
    ``"v1_03/idx_to_tag"`` instead of ``"idx_to_tag"``. The exact key wins;
    otherwise the first string key ending in ``/<key>`` (dict order) is used
    and a note is printed. Returns ``{}`` when nothing matches.
    """
    if key in vocab:
        return vocab[key]
    wanted_suffix = "/" + key
    match = next(
        (
            (name, value)
            for name, value in vocab.items()
            if isinstance(name, str) and name.endswith(wanted_suffix)
        ),
        None,
    )
    if match is None:
        return {}
    name, value = match
    print(f"[Tagger] vocab: resolved '{key}' via prefixed key '{name}'")
    return value
def _load_processor_and_vocab(vocab_path: str, meta_path: str = "") -> None:
    """Load the image processor and tag vocabulary, then make them active.

    Args:
        vocab_path: JSON file containing ``idx_to_tag`` (dict of index strings,
            or a plain list) and ``tag_to_category``. Either key may carry a
            version prefix; see ``_vocab_get``.
        meta_path: Optional export metadata JSON. ``vision_encoder_repo``
            selects the processor (default: SigLIP2 so400m NaFlex) and
            ``is_naflex`` the input mode (default True). Ignored if empty or
            missing.

    Everything is loaded into locals first. The module globals
    ``_processor``, ``_idx_to_tag``, ``_tag_to_category`` and ``_is_naflex``
    are replaced only after all steps succeed, so an exception leaves the
    previously active processor and vocabulary untouched.

    Raises:
        ValueError: if the vocabulary has no ``idx_to_tag`` entry.
        Errors from ``json``/file I/O and ``AutoProcessor`` propagate.
    """
    global _processor, _idx_to_tag, _tag_to_category, _is_naflex
    from transformers import AutoProcessor

    DEFAULT_REPO = "google/siglip2-so400m-patch16-naflex"
    processor_repo = DEFAULT_REPO
    is_naflex = True
    if meta_path and os.path.isfile(meta_path):
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        # `or` also covers an explicit JSON null for the repo id.
        processor_repo = meta.get("vision_encoder_repo") or DEFAULT_REPO
        is_naflex = bool(meta.get("is_naflex", True))
    print(f"[Tagger] Loading processor from {processor_repo} (mode: {'NaFlex' if is_naflex else 'standard'})...")
    try:
        # Prefer the local cache; avoids a network round-trip when available.
        processor = AutoProcessor.from_pretrained(processor_repo, local_files_only=True)
    except Exception:
        processor = AutoProcessor.from_pretrained(processor_repo)

    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    raw_idx = _vocab_get(vocab, "idx_to_tag")
    if not raw_idx:
        raise ValueError(f"'idx_to_tag' not found in {vocab_path} (top-level keys: {list(vocab.keys())[:8]})")
    if isinstance(raw_idx, list):
        # Positional form: the list index is the logit index.
        idx_to_tag = dict(enumerate(raw_idx))
    else:
        # JSON object keys are strings; predict() looks tags up by int index.
        idx_to_tag = {int(k): v for k, v in raw_idx.items()}
    tag_to_category = _vocab_get(vocab, "tag_to_category")
    if not tag_to_category:
        print(f"[Tagger] WARNING: 'tag_to_category' not found in {vocab_path}; "
              f"all tags will be treated as 'Unknown' and hidden")

    # Commit only now that every step has succeeded.
    _processor = processor
    _is_naflex = is_naflex
    _idx_to_tag = idx_to_tag
    _tag_to_category = tag_to_category
    print(f"[Tagger] Vocabulary loaded | {len(_idx_to_tag)} tags")
def _ensure_session() -> None:
    """Create the ONNX Runtime session for ``_onnx_path`` if none exists yet.

    The session is cached in the module-global ``_session``; loaders reset it
    to ``None`` to force a rebuild. CUDA is requested first, with CPU as the
    fallback provider.
    """
    global _session
    if _session is None:
        import onnxruntime as ort

        options = ort.SessionOptions()
        options.log_severity_level = 2
        wanted_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        _session = ort.InferenceSession(
            _onnx_path,
            sess_options=options,
            providers=wanted_providers,
        )
        print(f"[Tagger] ONNX session ready | provider={_session.get_providers()[0]}")
def load_version(version: str) -> str:
    """Download (if needed) and activate model *version*.

    Versions that are not (yet) in ``MODEL_VERSIONS`` are accepted too; their
    folder name doubles as the label.

    Returns:
        A status string for the UI: ``"Already loaded: …"``,
        ``"Download failed: …"``, ``"Load failed: …"`` or ``"Loaded: …"``.

    The active ONNX path, session and ``_loaded_version`` are switched only
    after the new vocabulary has loaded. A failed load therefore does not pair
    the new ONNX file with the previous vocabulary.
    """
    global _onnx_path, _session, _loaded_version
    label = MODEL_VERSIONS.get(version, version)
    if version == _loaded_version:
        return f"Already loaded: {label}"
    print(f"[Tagger] Switching to version '{label}'...")
    try:
        onnx_path, vocab_path, meta_path = _download_from_hub(version)
    except Exception as e:
        return f"Download failed: {e}"
    try:
        _load_processor_and_vocab(vocab_path, meta_path)
    except Exception as e:
        # _onnx_path / _session still point at the previous model.
        return f"Load failed: {e}"
    _onnx_path = onnx_path
    # Drop the old session; _ensure_session() rebuilds it for the new file.
    _session = None
    _loaded_version = version
    mode = "NaFlex" if _is_naflex else "Standard"
    return f"Loaded: {label} ({mode}, {len(_idx_to_tag)} tags)"
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@spaces.GPU
def predict(
    image: Image.Image | None,
    threshold_mode: str,
    unified_threshold: float,
    general_threshold: float,
    char_threshold: float,
    max_num_patches: int = 256,
) -> str:
    """Tag *image* with the active model and return the results as HTML.

    Args:
        image: Input image, converted to RGB. ``None`` returns a placeholder.
        threshold_mode: ``"Unified"`` applies *unified_threshold* to every
            thresholded category. Any other value applies *char_threshold* to
            ``CHAR_THRESHOLD_CATS`` and *general_threshold* to everything else.
        unified_threshold, general_threshold, char_threshold: Probability
            cut-offs (inclusive).
        max_num_patches: Patch budget passed to the processor in NaFlex mode;
            unused for standard-input models.

    Quality and Rating are never thresholded; their top-1 tag is always shown.
    Tags with no category (or category ``"Unknown"``) are skipped.
    """
    if image is None:
        return "<p style='color:#888'>Upload an image to start.</p>"
    if _onnx_path is None:
        return "<p style='color:#f87171'>Model not loaded.</p>"
    _ensure_session()
    image = image.convert("RGB")
    if _is_naflex:
        inputs = _processor(images=image, return_tensors="pt", max_num_patches=max_num_patches)
        pv = inputs["pixel_values"].float().numpy()
        pam = inputs["pixel_attention_mask"].float().numpy()
        ss = inputs["spatial_shapes"].numpy().astype(np.int64)
        outputs = _session.run(
            ["logits"],
            {"pixel_values": pv, "pixel_attention_mask": pam, "spatial_shapes": ss},
        )
    else:
        inputs = _processor(images=image, return_tensors="pt")
        pv = inputs["pixel_values"].float().numpy()
        outputs = _session.run(["logits"], {"pixel_values": pv})
    # First (only) image of the batch.
    logits = outputs[0][0].astype(np.float64)
    # exp(-x) overflows to inf for very negative logits; 1/(1+inf) is the
    # correct probability 0.0, so only the RuntimeWarning needs silencing.
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-logits))

    # Build per-tag results; skip Unknown-category tags entirely.
    all_items: List[Dict[str, Any]] = []
    for i, prob in enumerate(probs):
        tag = _idx_to_tag.get(i, f"__unk_{i}__")
        cat = _tag_to_category.get(tag, "Unknown")
        if cat == "Unknown":
            continue
        all_items.append({"tag": tag, "prob": float(prob), "category": cat})

    # Quality / Rating: always show top-1
    quality_top = _pick_top(all_items, "Quality")
    rating_top = _pick_top(all_items, "Rating")

    # Build threshold map
    if threshold_mode == "Unified":
        thr_map: Dict[str, float] = {c: unified_threshold for c in CATEGORY_ORDER}
        fallback_thr = unified_threshold
    else:
        thr_map = {
            **{c: char_threshold for c in CHAR_THRESHOLD_CATS},
            **{c: general_threshold for c in GENERAL_THRESHOLD_CATS},
        }
        # The unified slider is hidden in this mode, so categories outside
        # both groups (e.g. ones introduced by a newer vocabulary) follow the
        # visible general slider instead of an invisible value.
        fallback_thr = general_threshold
    filtered = [
        it for it in all_items
        if it["category"] not in ("Quality", "Rating")
        and it["prob"] >= thr_map.get(it["category"], fallback_thr)
    ]
    filtered.sort(key=lambda x: x["prob"], reverse=True)

    # Build plaintext tags for the copy button
    tag_parts: List[str] = []
    if quality_top:
        tag_parts.append(quality_top["tag"])
    if rating_top:
        tag_parts.append(rating_top["tag"])
    tag_parts.extend(it["tag"] for it in filtered)
    tags_text = ", ".join(tag_parts)
    return _build_html(quality_top, rating_top, filtered, tags_text)
def _pick_top(items: List[Dict], category: str) -> Optional[Dict]:
    """Return the highest-``prob`` item of *category* (first wins ties), or None."""
    best: Optional[Dict] = None
    for item in items:
        if item["category"] != category:
            continue
        if best is None or item["prob"] > best["prob"]:
            best = item
    return best
# ---------------------------------------------------------------------------
# HTML result builder
# ---------------------------------------------------------------------------
def _bar_html(tag: str, prob: float, color: str) -> str:
    """Render one tag as a compact, labelled probability bar.

    Args:
        tag: Tag name. It is HTML-escaped before insertion; characters such as
            ``<``, ``>``, ``&`` and ``"`` occur in real tags (e.g. ``<o>_<o>``)
            and would otherwise be parsed as markup.
        prob: Probability; shown as a percentage and used as the bar width.
        color: CSS colour of the filled part of the bar.
    """
    import html as _html

    safe_tag = _html.escape(tag, quote=True)
    pct = prob * 100
    # Compact bar designed for multi-column grid cells (~300px wide each)
    return (
        f'<div style="display:flex;align-items:center;gap:5px;margin:2px 0;min-width:0">'
        f'<span style="width:130px;flex-shrink:0;font-size:12px;color:#e5e7eb;'
        f'overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="{safe_tag}">{safe_tag}</span>'
        f'<div style="flex:1;height:12px;background:#374151;border-radius:2px;overflow:hidden;min-width:0">'
        f'<div style="width:{pct:.1f}%;height:100%;background:{color};border-radius:2px"></div>'
        f'</div>'
        f'<span style="width:40px;flex-shrink:0;text-align:right;font-size:11px;'
        f'color:#9ca3af;font-family:monospace">{pct:.1f}%</span>'
        f'</div>'
    )
def _copy_button_html(tags_text: str, n_tags: int) -> str:
    """Build the header row: a "Copy Tags" button plus a tag counter.

    The tag string is stored HTML-escaped in a ``data-tags`` attribute, and
    the inline handler copies it with ``navigator.clipboard``. The label shows
    "Copied!" for 1.5 s on success; clipboard errors are ignored.
    """
    import html as _html

    safe_tags = _html.escape(tags_text, quote=True)
    header_style = (
        "display:flex;align-items:center;gap:10px;margin-bottom:10px;"
        "padding-bottom:8px;border-bottom:1px solid #374151"
    )
    on_click = (
        "navigator.clipboard.writeText(this.dataset.tags)"
        ".then(()=>{this.textContent='Copied!';"
        "setTimeout(()=>this.textContent='Copy Tags',1500)})"
        ".catch(()=>{})"
    )
    button_style = (
        "background:#3b82f6;color:#fff;border:none;padding:5px 14px;"
        "border-radius:4px;cursor:pointer;font-size:13px;font-weight:600"
    )
    pieces = [
        f'<div style="{header_style}">',
        f'<button  data-tags="{safe_tags}"  onclick="{on_click}"  style="{button_style}">',
        "Copy Tags",
        "</button>",
        f'<span style="font-size:12px;color:#6b7280">{n_tags} tags</span>',
        "</div>",
    ]
    return "".join(pieces)
def _build_html(
    quality_top: Optional[Dict],
    rating_top: Optional[Dict],
    filtered: List[Dict],
    tags_text: str,
) -> str:
    """Assemble the results panel.

    Layout: copy-button header, pinned Quality/Rating rows, then one
    collapsible multi-column section per category.

    Args:
        quality_top, rating_top: Top-1 item dicts (``tag``/``prob``/
            ``category``) or ``None``.
        filtered: Thresholded, non-pinned items. The order within each section
            follows this list (predict() sorts it by descending probability).
        tags_text: Comma-separated tags placed on the copy button.

    Sections appear in ``CATEGORY_ORDER`` first, then any other category
    present in *filtered* in alphabetical order. Those tags are already
    counted and copied, so they are shown as well.
    """
    import html as _html

    parts: List[str] = ['<div style="font-family:sans-serif">']

    # Copy button + tag count header
    n_tags = len(filtered) + (1 if quality_top else 0) + (1 if rating_top else 0)
    parts.append(_copy_button_html(tags_text, n_tags))

    # Pinned: Quality + Rating (single row each, no multi-column)
    if quality_top or rating_top:
        parts.append('<div style="margin-bottom:12px;padding-bottom:8px;border-bottom:1px solid #374151">')
        if quality_top:
            c = CATEGORY_COLORS["Quality"]
            parts.append(f'<div style="font-size:11px;color:{c};font-weight:600;margin-bottom:2px">Quality</div>')
            parts.append(_bar_html(quality_top["tag"], quality_top["prob"], c))
        if rating_top:
            c = CATEGORY_COLORS["Rating"]
            parts.append(f'<div style="font-size:11px;color:{c};font-weight:600;margin-top:6px;margin-bottom:2px">Rating</div>')
            parts.append(_bar_html(rating_top["tag"], rating_top["prob"], c))
        parts.append('</div>')

    # Grouped categories — multi-column grid
    grouped: Dict[str, List[Dict]] = {}
    for it in filtered:
        grouped.setdefault(it["category"], []).append(it)
    extra_cats = sorted(c for c in grouped if c not in CATEGORY_ORDER)
    for cat in [*CATEGORY_ORDER, *extra_cats]:
        items = grouped.get(cat)
        if not items:
            continue
        c = CATEGORY_COLORS.get(cat, "#888")
        # Grid uses auto-fill: >=2 columns when the container is wide enough.
        parts.append(
            f'<details open style="margin-bottom:10px">'
            f'<summary style="cursor:pointer;font-size:12px;font-weight:600;'
            f'color:{c};margin-bottom:4px;user-select:none">'
            f'{_html.escape(cat)} ({len(items)})</summary>'
            f'<div style="display:grid;grid-template-columns:repeat(auto-fill,minmax(280px,1fr));gap:0 12px">'
        )
        for it in items:
            parts.append(_bar_html(it["tag"], it["prob"], c))
        parts.append('</div></details>')

    if not filtered and not quality_top and not rating_top:
        parts.append('<p style="color:#888;font-size:13px">No tags above threshold.</p>')
    parts.append('</div>')
    return "\n".join(parts)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
CSS = """
body, .gradio-container { background:#111827 !important; color:#e5e7eb !important; }
button.primary { background:#3b82f6 !important; }
/* Prevent columns from reflowing when image loads */
.tagger-row { align-items: flex-start !important; }
.tagger-left { min-width: 0; flex-shrink: 0; }
.tagger-right { min-width: 0; flex: 1 1 0; overflow: hidden; }
"""
def create_ui() -> gr.Blocks:
    """Build the Gradio app (with queueing enabled) and return it.

    Reads ``MODEL_VERSIONS`` / ``DEFAULT_VERSION`` (set by
    ``_setup_versions()``) and ``_loaded_version``, so call it after version
    discovery and the initial model load.
    """
    version_choices = [(label, key) for key, label in MODEL_VERSIONS.items()]
    # Pre-select the model that is actually active (e.g. one chosen with
    # --version) so the dropdown and threshold defaults describe it.
    initial_version = (
        _loaded_version if _loaded_version in MODEL_VERSIONS else DEFAULT_VERSION
    )
    default_thr = _default_threshold(initial_version)
    with gr.Blocks(css=CSS, title="SigLIP2 Tagger", theme=gr.themes.Base()) as demo:
        gr.Markdown("<h2 style='color:#e5e7eb;margin:8px 0'>SigLIP2 Tagger</h2>")
        with gr.Row(elem_classes="tagger-row"):
            # --- Left panel: image + controls ---
            with gr.Column(scale=1, elem_classes="tagger-left"):
                image_input = gr.Image(type="pil", label="Image", height=380)
                with gr.Row():
                    version_dropdown = gr.Dropdown(
                        choices=version_choices,
                        value=initial_version,
                        label="Model Version",
                        scale=3,
                    )
                    load_btn = gr.Button("Load", scale=1)
                model_status = gr.Textbox(
                    value="",
                    label="Status",
                    interactive=False,
                    lines=3,
                    max_lines=8,
                )
                threshold_mode = gr.Radio(
                    ["Unified", "Per Category"],
                    value="Unified",
                    label="Threshold Mode",
                )
                # Unified mode
                unified_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Threshold",
                    visible=True,
                )
                # Per-category mode (groups match CHAR_/GENERAL_THRESHOLD_CATS)
                general_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="General / Meta / Model Threshold",
                    visible=False,
                )
                char_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Character / Copyright Threshold",
                    visible=False,
                )
                predict_btn = gr.Button("Run", variant="primary")
            # --- Right panel: results ---
            with gr.Column(scale=2, elem_classes="tagger-right"):
                html_output = gr.HTML()

        def _load_and_reset(version: str):
            """Load-button generator yielding (status, unified, general, char).

            Streams step-by-step progress text. On success, all three sliders
            are reset to the version's recommended threshold. The active ONNX
            path, session and version are switched only after the vocabulary
            has loaded, so a failed load leaves them pointing at the previous
            model.
            """
            global _onnx_path, _session, _loaded_version
            _no_update = (gr.update(), gr.update(), gr.update())
            label = MODEL_VERSIONS.get(version, version)
            if version == _loaded_version:
                yield (f"Already loaded: {label}", *_no_update)
                return
            yield (f"[1/3] Downloading {label} from Hub…", *_no_update)
            try:
                onnx_path, vocab_path, meta_path = _download_from_hub(version)
            except Exception as exc:
                yield (f"Download failed:\n{exc}", *_no_update)
                return
            yield ("[2/3] Loading processor and vocabulary…", *_no_update)
            try:
                _load_processor_and_vocab(vocab_path, meta_path)
            except Exception as exc:
                yield (f"Load failed:\n{exc}", *_no_update)
                return
            _onnx_path = onnx_path
            # Drop the old session; _ensure_session() rebuilds it on next run.
            _session = None
            _loaded_version = version
            mode = "NaFlex" if _is_naflex else "Standard"
            thr = _default_threshold(version)
            status = f"[3/3] Ready: {label} ({mode}, {len(_idx_to_tag):,} tags)"
            yield (status, gr.update(value=thr), gr.update(value=thr), gr.update(value=thr))

        load_btn.click(
            fn=_load_and_reset,
            inputs=[version_dropdown],
            outputs=[model_status, unified_slider, general_slider, char_slider],
        )

        # Toggle slider visibility on mode change
        def _toggle_mode(mode: str):
            unified = mode == "Unified"
            return gr.update(visible=unified), gr.update(visible=not unified), gr.update(visible=not unified)

        threshold_mode.change(
            fn=_toggle_mode,
            inputs=[threshold_mode],
            outputs=[unified_slider, general_slider, char_slider],
        )
        predict_btn.click(
            fn=predict,
            inputs=[image_input, threshold_mode, unified_slider, general_slider, char_slider],
            outputs=[html_output],
        )
    demo.queue()
    return demo
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def _init_from_hub(version: str) -> None:
    """Download *version* from ``MODEL_REPO_ID`` and make it the active model.

    Used at startup. Exceptions from downloading or loading are not caught
    here.
    """
    global _onnx_path, _loaded_version
    print(f"[Tagger] Downloading model version '{version}' from {MODEL_REPO_ID}...")
    onnx_file, vocab_file, meta_file = _download_from_hub(version)
    _onnx_path = onnx_file
    _load_processor_and_vocab(vocab_file, meta_file)
    _loaded_version = version
def _init_local(onnx_path: str, vocab_path: str) -> None:
    """Activate a model from local files (the ``--onnx`` / ``--vocab`` flags).

    Optional metadata is read from ``<onnx path without extension>_metadata.json``.
    Because the model does not correspond to a Hub version, ``_loaded_version``
    is cleared so that "Load" in the UI really loads the selected version.
    """
    global _onnx_path, _session, _loaded_version
    # splitext strips only the final extension. The previous
    # str.replace(".onnx", ...) also rewrote ".onnx" inside directory names.
    # For paths without ".onnx" (e.g. "model.ONNX") it yielded the ONNX file
    # itself, which was then parsed as JSON.
    meta_path = os.path.splitext(onnx_path)[0] + "_metadata.json" if onnx_path else ""
    _load_processor_and_vocab(vocab_path, meta_path)
    _onnx_path = onnx_path
    # Force _ensure_session() to build a session for this file.
    _session = None
    _loaded_version = None
def _setup_versions(requested_version: str = "") -> str:
    """Discover available versions and resolve which one to load.

    Populates ``MODEL_VERSIONS`` and sets ``DEFAULT_VERSION`` to the newest
    discovered version. If discovery yields nothing, an explicit request
    becomes the default; otherwise the built-in sentinel stays.

    Returns:
        *requested_version* if given and available (or if there is no
        discovered list to check it against), else ``DEFAULT_VERSION``. A
        warning is printed when an explicit request has to be ignored.
    """
    global MODEL_VERSIONS, DEFAULT_VERSION
    MODEL_VERSIONS = _fetch_available_versions()
    if MODEL_VERSIONS:
        # Ordered oldest -> newest, so the last key is the newest version.
        DEFAULT_VERSION = next(reversed(MODEL_VERSIONS))
    elif requested_version:
        DEFAULT_VERSION = requested_version

    if not requested_version:
        return DEFAULT_VERSION
    # If an explicit version was requested (CLI --version), honour it.
    if requested_version in MODEL_VERSIONS or not MODEL_VERSIONS:
        return requested_version
    print(f"[Tagger] WARNING: requested version '{requested_version}' not found; "
          f"using '{DEFAULT_VERSION}'")
    return DEFAULT_VERSION
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="SigLIP2 Tagger Gradio App")
parser.add_argument("--onnx", default="", help="Local ONNX path")
parser.add_argument("--vocab", default="", help="Local vocabulary JSON path")
parser.add_argument("--version", default="", help="Model version to load (default: latest)")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
if args.onnx and args.vocab:
_setup_versions()
_init_local(args.onnx, args.vocab)
else:
version = _setup_versions(args.version)
_init_from_hub(version)
demo = create_ui()
demo.launch(server_port=args.port, share=args.share)
else:
# HuggingFace Spaces: discover versions, then auto-load the latest.
_init_from_hub(_setup_versions())