"""
SigLIP2 Tagger — Gradio ONNX Inference App

Runs ONNX-exported SigLIP2 tagger models with category-grouped results.
Designed for HuggingFace Spaces deployment.
"""

from __future__ import annotations

import json
import os
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import spaces
from PIL import Image

# ---------------------------------------------------------------------------
# Category colours (SushiUI)
# ---------------------------------------------------------------------------

# Hex colour used for each tag category's result bars and headers.
CATEGORY_COLORS: Dict[str, str] = {
    "Quality": "#eab308",
    "Rating": "#f97316",
    "Character": "#3b82f6",
    "Copyright": "#a855f7",
    "General": "#22c55e",
    "Meta": "#9ca3af",
    "Model": "#06b6d4",
}

# Quality and Rating are pinned (always shown); the rest use threshold filtering.
CATEGORY_ORDER = ["Character", "Copyright", "General", "Meta", "Model"]

# Categories that use the "character" threshold in per-category mode.
CHAR_THRESHOLD_CATS = {"Character", "Copyright"}
# Categories that use the "general" threshold in per-category mode.
GENERAL_THRESHOLD_CATS = {"General", "Meta", "Model"}

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

# Both values can be overridden through the environment.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "celstk/cl-SigLIP2-lora-onnx")
CACHE_DIR = os.environ.get("HF_HOME", None)

# MODEL_VERSIONS is populated at runtime by _fetch_available_versions().
# Keys are HF folder names (e.g. "v1_03"), values are display labels ("v1.03").
MODEL_VERSIONS: Dict[str, str] = {}
# Resolved after version discovery.
DEFAULT_VERSION: str = "v1_03"

# Per-version threshold overrides. Unknown / future versions default to 0.5.
# v1_02 used CS-ASL (bimodal distribution) — needs a much higher cut.
_VERSION_DEFAULT_THRESHOLD: Dict[str, float] = {
    "v1_00": 0.6,
    "v1_01": 0.6,
    "v1_02": 0.9,
}


def _default_threshold(version: str) -> float:
    """Return the recommended inference threshold for *version*.

    Explicit entries exist for versions whose loss function produces an
    unusually skewed probability distribution. Everything else — including
    all future versions — falls back to 0.5, which is a safe middle ground
    that can be nudged up or down by the user.
    """
    return _VERSION_DEFAULT_THRESHOLD.get(version, 0.5)


# ---------------------------------------------------------------------------
# Dynamic version discovery
# ---------------------------------------------------------------------------

import html as _html
import re as _re


def _fetch_available_versions() -> Dict[str, str]:
    """Query the HF repo for version subdirectories (e.g. v1_03, v2_00).

    Returns a dict keyed by folder name → display label, inserted
    oldest-to-newest by numeric version tuple. Falls back to the last-known
    list if the API call fails or finds no version folders, so the app
    still loads.
    """
    fallback: Dict[str, str] = {
        "v1_00": "v1.00",
        "v1_01": "v1.01",
        "v1_02": "v1.02",
        "v1_03": "v1.03",
    }
    token = os.environ.get("HF_TOKEN")
    try:
        from huggingface_hub import list_repo_files

        folder_pat = _re.compile(r'^(v\d+_\d+)/')
        dirs: set = set()
        for repo_file in list_repo_files(repo_id=MODEL_REPO_ID, token=token):
            match = folder_pat.match(repo_file)
            if match:
                dirs.add(match.group(1))
        if not dirs:
            return fallback

        def _vkey(v: str) -> Tuple[int, ...]:
            # Numeric ordering: "v1_10" must sort after "v1_09".
            return tuple(int(x) for x in _re.findall(r'\d+', v))

        result = {v: v.replace("_", ".", 1) for v in sorted(dirs, key=_vkey)}
        print(f"[Tagger] Discovered versions: {list(result.keys())}")
        return result
    except Exception as exc:
        # Deliberate best-effort: any failure (network, auth, missing
        # package) degrades to the static list instead of crashing startup.
        print(f"[Tagger] WARNING: version discovery failed ({exc}); using fallback list")
        return fallback


# Module-level model state. Written by _load_processor_and_vocab(),
# _activate_model() and _ensure_session(); read by predict() and the UI.
_onnx_path: Optional[str] = None
_session = None
_processor = None
_is_naflex: bool = True
_idx_to_tag: Dict[int, str] = {}
_tag_to_category: Dict[str, str] = {}
_loaded_version: Optional[str] = None


def _download_from_hub(version: str) -> Tuple[str, str, str]:
    """Download the files of *version* and locate its ONNX, vocab and metadata.

    Returns:
        ``(onnx_path, vocab_path, meta_path)``; ``meta_path`` is ``""`` when
        the version folder has no ``*_metadata.json``.

    Raises:
        FileNotFoundError: if no ``.onnx`` or no ``*vocabulary.json`` file is
            present in the downloaded version folder.
    """
    import glob

    from huggingface_hub import snapshot_download

    token = os.environ.get("HF_TOKEN")
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        allow_patterns=[f"{version}/*"],
        cache_dir=CACHE_DIR,
        token=token,
    )
    version_dir = os.path.join(local_dir, version)

    # glob() returns files in arbitrary filesystem order; sort so the same
    # file is picked on every run. ("*.onnx" cannot match "*.onnx.data"
    # sidecars, so no extra filtering is needed.)
    onnx_files = sorted(glob.glob(os.path.join(version_dir, "*.onnx")))
    if not onnx_files:
        raise FileNotFoundError(f"No .onnx file found in {version_dir}")
    onnx_path = onnx_files[0]
    onnx_base = os.path.splitext(os.path.basename(onnx_path))[0]

    # Prefer the per-ONNX vocabulary snapshot ``<onnx_base>_vocabulary.json``
    # (frozen at export time, immune to overwrites). Fall back to any
    # ``*vocabulary.json`` in the same directory with a warning.
    per_ckpt_vocab = os.path.join(version_dir, f"{onnx_base}_vocabulary.json")
    if os.path.isfile(per_ckpt_vocab):
        vocab_path = per_ckpt_vocab
    else:
        vocab_files = sorted(glob.glob(os.path.join(version_dir, "*vocabulary.json")))
        if not vocab_files:
            raise FileNotFoundError(f"No *vocabulary.json found in {version_dir}")
        vocab_path = vocab_files[0]
        print(f"[Spaces] WARNING: per-ONNX vocabulary "
              f"'{onnx_base}_vocabulary.json' not found; falling back to "
              f"'{os.path.basename(vocab_path)}'. Tag→idx alignment with "
              f"'{onnx_base}.onnx' cannot be verified.")

    # Same preference for metadata: the file named after the chosen ONNX
    # first, then any ``*_metadata.json`` in the folder.
    per_ckpt_meta = os.path.join(version_dir, f"{onnx_base}_metadata.json")
    if os.path.isfile(per_ckpt_meta):
        meta_path = per_ckpt_meta
    else:
        meta_files = sorted(glob.glob(os.path.join(version_dir, "*_metadata.json")))
        meta_path = meta_files[0] if meta_files else ""
    return onnx_path, vocab_path, meta_path


def _vocab_get(vocab: dict, key: str) -> dict:
    """Extract *key* from a vocabulary dict with prefix-fallback.

    Some exported vocab files store keys under a version prefix, e.g.
    ``"v1_03/idx_to_tag"`` instead of plain ``"idx_to_tag"``. The prefix may
    change between versions, so the bare key is tried first, then any key
    ending in ``/{key}``. Returns ``{}`` when nothing matches.
    """
    if key in vocab:
        return vocab[key]
    suffix = f"/{key}"
    for k, v in vocab.items():
        if isinstance(k, str) and k.endswith(suffix):
            print(f"[Tagger] vocab: resolved '{key}' via prefixed key '{k}'")
            return v
    return {}


def _load_processor_and_vocab(vocab_path: str, meta_path: str = "") -> None:
    """Load the image processor and tag vocabulary into module state.

    The processor repo and NaFlex flag come from the metadata JSON when
    *meta_path* points to an existing file; otherwise the defaults below are
    used. All values are built in locals and assigned to the module globals
    only once everything has loaded, so a failure leaves the previously
    loaded processor/vocabulary untouched.

    Raises:
        ValueError: if the vocabulary file has no ``idx_to_tag`` mapping.
        Exception: anything raised while reading the JSON files or loading
            the processor propagates to the caller.
    """
    global _processor, _idx_to_tag, _tag_to_category, _is_naflex

    from transformers import AutoProcessor

    default_repo = "google/siglip2-so400m-patch16-naflex"
    processor_repo = default_repo
    is_naflex = True
    if meta_path and os.path.isfile(meta_path):
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        processor_repo = meta.get("vision_encoder_repo", default_repo)
        is_naflex = bool(meta.get("is_naflex", True))

    print(f"[Tagger] Loading processor from {processor_repo} (mode: {'NaFlex' if is_naflex else 'standard'})...")
    try:
        # Try the local cache first; only go to the network if that fails.
        processor = AutoProcessor.from_pretrained(processor_repo, local_files_only=True)
    except Exception:
        processor = AutoProcessor.from_pretrained(processor_repo)

    with open(vocab_path, "r", encoding="utf-8") as fh:
        vocab = json.load(fh)
    raw_idx = _vocab_get(vocab, "idx_to_tag")
    if not raw_idx:
        raise ValueError(f"'idx_to_tag' not found in {vocab_path} (top-level keys: {list(vocab.keys())[:8]})")
    # JSON object keys are strings; predict() looks tags up by integer index.
    idx_to_tag = {int(k): v for k, v in raw_idx.items()}
    tag_to_category = _vocab_get(vocab, "tag_to_category")

    # Commit everything together.
    _processor = processor
    _is_naflex = is_naflex
    _idx_to_tag = idx_to_tag
    _tag_to_category = tag_to_category
    print(f"[Tagger] Vocabulary loaded | {len(_idx_to_tag)} tags")


def _ensure_session() -> None:
    """Create the ONNX Runtime session for ``_onnx_path`` if none exists yet."""
    global _session
    if _session is not None:
        return
    import onnxruntime as ort

    opts = ort.SessionOptions()
    opts.log_severity_level = 2
    _session = ort.InferenceSession(
        _onnx_path,
        sess_options=opts,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    print(f"[Tagger] ONNX session ready | provider={_session.get_providers()[0]}")


def _activate_model(
    onnx_path: str,
    vocab_path: str,
    meta_path: str,
    version: Optional[str],
) -> None:
    """Make the given model files the active model.

    The processor and vocabulary are loaded first; the ONNX path, cached
    session and loaded-version marker are switched only if that succeeds.
    On failure the exception propagates and the previously active model
    (path, session, vocabulary) stays fully in place, so inference never
    pairs a new ONNX file with an old vocabulary.
    """
    global _onnx_path, _session, _loaded_version
    _load_processor_and_vocab(vocab_path, meta_path)
    # Drop the session so _ensure_session() recreates it for the new file.
    _session = None
    _onnx_path = onnx_path
    _loaded_version = version


def load_version(version: str) -> str:
    """Download (if needed) and activate the selected model version.

    Returns a status string shown in the UI. Download and load errors are
    reported in that string rather than raised.
    """
    # Accept versions not yet in MODEL_VERSIONS (e.g. discovered mid-session).
    label = MODEL_VERSIONS.get(version, version)
    if version == _loaded_version:
        return f"Already loaded: {label}"

    print(f"[Tagger] Switching to version '{label}'...")
    try:
        onnx_path, vocab_path, meta_path = _download_from_hub(version)
    except Exception as e:
        return f"Download failed: {e}"

    try:
        _activate_model(onnx_path, vocab_path, meta_path, version)
    except Exception as e:
        return f"Load failed: {e}"

    mode = "NaFlex" if _is_naflex else "Standard"
    return f"Loaded: {label} ({mode}, {len(_idx_to_tag)} tags)"


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

# NOTE(review): the element/attribute text of every inline HTML snippet in
# this file was missing from the reviewed copy and has been reconstructed.
# Confirm the markup and styling against the deployed app.


def _placeholder_html(message: str) -> str:
    """Return the centred grey notice shown when there is nothing to display."""
    return (
        "<div style='padding:24px;text-align:center;color:#9ca3af;'>"
        f"<p>{_html.escape(message)}</p>"
        "</div>"
    )


@spaces.GPU
def predict(
    image: Image.Image | None,
    threshold_mode: str,
    unified_threshold: float,
    general_threshold: float,
    char_threshold: float,
    max_num_patches: int = 256,
) -> str:
    """Run inference and return results HTML.

    Args:
        image: Input image, or ``None`` when nothing has been uploaded.
        threshold_mode: ``"Unified"`` applies *unified_threshold* to every
            category; any other value selects per-category thresholds.
        unified_threshold: Cut-off in unified mode, and the fallback for
            categories without a per-category threshold.
        general_threshold: Cut-off for ``GENERAL_THRESHOLD_CATS``.
        char_threshold: Cut-off for ``CHAR_THRESHOLD_CATS``.
        max_num_patches: Passed to the processor in NaFlex mode only.

    Returns:
        An HTML fragment: a placeholder notice, or the grouped tag bars.
    """
    if image is None:
        return _placeholder_html("Upload an image to start.")
    if _onnx_path is None or _processor is None:
        return _placeholder_html("Model not loaded.")

    _ensure_session()
    image = image.convert("RGB")

    if _is_naflex:
        inputs = _processor(images=image, return_tensors="pt", max_num_patches=max_num_patches)
        pv = inputs["pixel_values"].float().numpy()
        pam = inputs["pixel_attention_mask"].float().numpy()
        ss = inputs["spatial_shapes"].numpy().astype(np.int64)
        outputs = _session.run(
            ["logits"],
            {"pixel_values": pv, "pixel_attention_mask": pam, "spatial_shapes": ss},
        )
    else:
        inputs = _processor(images=image, return_tensors="pt")
        pv = inputs["pixel_values"].float().numpy()
        outputs = _session.run(["logits"], {"pixel_values": pv})

    logits = outputs[0][0].astype(np.float64)
    # Numerically stable sigmoid: 1 / (1 + e^-x) == exp(-log(1 + e^-x)).
    # The naive form overflows in exp() for large negative logits.
    probs = np.exp(-np.logaddexp(0.0, -logits))

    # Build per-tag results; skip Unknown-category tags entirely.
    all_items: List[Dict[str, Any]] = []
    for i, prob in enumerate(probs):
        tag = _idx_to_tag.get(i, f"__unk_{i}__")
        cat = _tag_to_category.get(tag, "Unknown")
        if cat == "Unknown":
            continue
        all_items.append({"tag": tag, "prob": float(prob), "category": cat})

    # Quality / Rating: always show top-1
    quality_top = _pick_top(all_items, "Quality")
    rating_top = _pick_top(all_items, "Rating")

    # Per-category overrides; anything not listed uses the unified threshold.
    if threshold_mode == "Unified":
        thr_map: Dict[str, float] = {}
    else:
        thr_map = {
            **{c: char_threshold for c in CHAR_THRESHOLD_CATS},
            **{c: general_threshold for c in GENERAL_THRESHOLD_CATS},
        }

    filtered = [
        it for it in all_items
        if it["category"] not in ("Quality", "Rating")
        and it["prob"] >= thr_map.get(it["category"], unified_threshold)
    ]
    filtered.sort(key=lambda x: x["prob"], reverse=True)

    # Build plaintext tags for the copy button
    tag_parts: List[str] = []
    if quality_top:
        tag_parts.append(quality_top["tag"])
    if rating_top:
        tag_parts.append(rating_top["tag"])
    tag_parts.extend(it["tag"] for it in filtered)
    tags_text = ", ".join(tag_parts)

    return _build_html(quality_top, rating_top, filtered, tags_text)


def _pick_top(items: List[Dict], category: str) -> Optional[Dict]:
    """Return the highest-probability item of *category*, or ``None`` if none."""
    cat_items = [it for it in items if it["category"] == category]
    return max(cat_items, key=lambda x: x["prob"]) if cat_items else None


# ---------------------------------------------------------------------------
# HTML result builder
# ---------------------------------------------------------------------------

def _bar_html(tag: str, prob: float, color: str) -> str:
    """Render one tag as a label, a proportional bar and a percentage.

    *tag* comes from the vocabulary file and is HTML-escaped before being
    interpolated, so tags containing ``<``, ``>``, ``&`` or quotes cannot
    break or inject markup.
    """
    pct = prob * 100
    safe_tag = _html.escape(tag, quote=True)
    # Compact bar designed for multi-column grid cells (~300px wide each)
    return (
        f'<div style="display:flex;align-items:center;gap:6px;padding:2px 0;min-width:0;">'
        f'<span title="{safe_tag}" style="flex:0 0 45%;min-width:0;overflow:hidden;'
        f'text-overflow:ellipsis;white-space:nowrap;font-size:12px;color:#e5e7eb;">{safe_tag}</span>'
        f'<div style="flex:1 1 0;min-width:0;height:8px;background:#374151;'
        f'border-radius:4px;overflow:hidden;">'
        f'<div style="width:{pct:.1f}%;height:100%;background:{color};border-radius:4px;"></div>'
        f'</div>'
        f'<span style="flex:0 0 44px;text-align:right;font-size:11px;color:#9ca3af;">{pct:.1f}%</span>'
        f'</div>'
    )


def _copy_button_html(tags_text: str, n_tags: int) -> str:
    """Render the copy-to-clipboard button and the tag counter.

    The comma-separated tag list is stored, attribute-escaped, in a
    ``data-tags`` attribute that the button's click handler copies.
    """
    escaped = _html.escape(tags_text, quote=True)
    return (
        f'<div style="display:flex;align-items:center;gap:10px;margin-bottom:10px;">'
        f'<button type="button" data-tags="{escaped}" '
        f'onclick="navigator.clipboard.writeText(this.dataset.tags)" '
        f'style="padding:4px 12px;border:0;border-radius:6px;background:#3b82f6;'
        f'color:#fff;font-size:12px;cursor:pointer;">Copy tags</button>'
        f'<span style="font-size:12px;color:#9ca3af;">{n_tags} tags</span>'
        f'</div>'
    )


def _build_html(
    quality_top: Optional[Dict],
    rating_top: Optional[Dict],
    filtered: List[Dict],
    tags_text: str,
) -> str:
    """Assemble the full results panel.

    Layout: copy button and tag count, then the pinned Quality / Rating
    bars, then one grid per category — ``CATEGORY_ORDER`` first, followed by
    any other category present in *filtered* — so every tag that is counted
    and copied is also displayed.
    """
    parts: List[str] = ['<div style="font-family:sans-serif;color:#e5e7eb;">']

    # Copy button + tag count header
    n_tags = len(filtered) + (1 if quality_top else 0) + (1 if rating_top else 0)
    parts.append(_copy_button_html(tags_text, n_tags))

    # Pinned: Quality + Rating (single row each, no multi-column)
    if quality_top or rating_top:
        parts.append('<div style="margin-bottom:12px;">')
        for name, top in (("Quality", quality_top), ("Rating", rating_top)):
            if not top:
                continue
            c = CATEGORY_COLORS[name]
            parts.append(
                f'<div style="font-size:12px;font-weight:600;color:{c};margin-top:4px;">{name}</div>'
            )
            parts.append(_bar_html(top["tag"], top["prob"], c))
        parts.append('</div>')

    # Grouped categories — multi-column grid
    grouped: Dict[str, List[Dict]] = {}
    for it in filtered:
        grouped.setdefault(it["category"], []).append(it)

    # Known categories in their fixed order, then any category the
    # vocabulary defines beyond CATEGORY_ORDER (rendered in neutral grey).
    extra_cats = [cat for cat in grouped if cat not in CATEGORY_ORDER]
    for cat in [*CATEGORY_ORDER, *extra_cats]:
        items = grouped.get(cat)
        if not items:
            continue
        c = CATEGORY_COLORS.get(cat, "#888")
        parts.append(
            f'<div style="margin-bottom:12px;">'
            f'<span style="display:inline-block;width:8px;height:8px;border-radius:50%;'
            f'background:{c};margin-right:6px;"></span>'
            f'<span style="font-size:12px;font-weight:600;color:{c};">'
            f'{_html.escape(cat)} ({len(items)})</span>'
            # auto-fill: ≥2 columns when the container is wide enough
            f'<div style="display:grid;grid-template-columns:repeat(auto-fill,minmax(300px,1fr));'
            f'column-gap:16px;margin-top:4px;">'
        )
        for it in items:
            parts.append(_bar_html(it["tag"], it["prob"], c))
        parts.append('</div></div>')

    if not filtered and not quality_top and not rating_top:
        parts.append(_placeholder_html("No tags above threshold."))

    parts.append('</div>')
    return "\n".join(parts)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

CSS = """
body, .gradio-container { background:#111827 !important; color:#e5e7eb !important; }
button.primary { background:#3b82f6 !important; }
/* Prevent columns from reflowing when image loads */
.tagger-row { align-items: flex-start !important; }
.tagger-left { min-width: 0; flex-shrink: 0; }
.tagger-right { min-width: 0; flex: 1 1 0; overflow: hidden; }
"""


def create_ui() -> gr.Blocks:
    """Build the Gradio interface and wire up its event handlers.

    Reads ``MODEL_VERSIONS`` and ``DEFAULT_VERSION``, so ``_setup_versions()``
    must have run before this is called.
    """
    version_choices = [(label, key) for key, label in MODEL_VERSIONS.items()]
    default_thr = _default_threshold(DEFAULT_VERSION)

    with gr.Blocks(css=CSS, title="SigLIP2 Tagger", theme=gr.themes.Base()) as demo:
        gr.Markdown("<h1 style='text-align:center;margin:0;'>SigLIP2 Tagger</h1>")

        with gr.Row(elem_classes="tagger-row"):
            # --- Left panel: image + controls ---
            with gr.Column(scale=1, elem_classes="tagger-left"):
                image_input = gr.Image(type="pil", label="Image", height=380)

                with gr.Row():
                    version_dropdown = gr.Dropdown(
                        choices=version_choices,
                        value=DEFAULT_VERSION,
                        label="Model Version",
                        scale=3,
                    )
                    load_btn = gr.Button("Load", scale=1)

                model_status = gr.Textbox(
                    value="",
                    label="Status",
                    interactive=False,
                    lines=3,
                    max_lines=8,
                )

                threshold_mode = gr.Radio(
                    ["Unified", "Per Category"],
                    value="Unified",
                    label="Threshold Mode",
                )

                # Unified mode
                unified_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Threshold", visible=True,
                )
                # Per-category mode
                general_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="General / Meta Threshold", visible=False,
                )
                char_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Character / Copyright Threshold", visible=False,
                )

                predict_btn = gr.Button("Run", variant="primary")

            # --- Right panel: results ---
            with gr.Column(scale=2, elem_classes="tagger-right"):
                html_output = gr.HTML()

        # Load button — step-by-step progress via generator, then reset thresholds.
        def _load_and_reset(version: str):
            """Yield ``(status, unified, general, char)`` updates while loading.

            The three slider updates stay empty until the load succeeds, at
            which point all sliders are reset to the version's default
            threshold.
            """
            no_update = (gr.update(), gr.update(), gr.update())
            label = MODEL_VERSIONS.get(version, version)

            if version == _loaded_version:
                yield (f"Already loaded: {label}", *no_update)
                return

            yield (f"[1/3] Downloading {label} from Hub…", *no_update)
            try:
                onnx_path, vocab_path, meta_path = _download_from_hub(version)
            except Exception as exc:
                yield (f"Download failed:\n{exc}", *no_update)
                return

            yield ("[2/3] Loading processor and vocabulary…", *no_update)
            try:
                # Commits only on success; a failed load keeps the old model.
                _activate_model(onnx_path, vocab_path, meta_path, version)
            except Exception as exc:
                yield (f"Load failed:\n{exc}", *no_update)
                return

            mode = "NaFlex" if _is_naflex else "Standard"
            thr = _default_threshold(version)
            status = f"[3/3] Ready: {label} ({mode}, {len(_idx_to_tag):,} tags)"
            yield (status, gr.update(value=thr), gr.update(value=thr), gr.update(value=thr))

        load_btn.click(
            fn=_load_and_reset,
            inputs=[version_dropdown],
            outputs=[model_status, unified_slider, general_slider, char_slider],
        )

        # Toggle slider visibility on mode change
        def _toggle_mode(mode: str):
            unified = mode == "Unified"
            return (
                gr.update(visible=unified),
                gr.update(visible=not unified),
                gr.update(visible=not unified),
            )

        threshold_mode.change(
            fn=_toggle_mode,
            inputs=[threshold_mode],
            outputs=[unified_slider, general_slider, char_slider],
        )

        predict_btn.click(
            fn=predict,
            inputs=[image_input, threshold_mode, unified_slider, general_slider, char_slider],
            outputs=[html_output],
        )

        demo.queue()

    return demo


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def _init_from_hub(version: str) -> None:
    """Download *version* from the Hub and make it the active model.

    Unlike ``load_version()``, errors are not caught here; they propagate to
    the caller.
    """
    print(f"[Tagger] Downloading model version '{version}' from {MODEL_REPO_ID}...")
    onnx_path, vocab_path, meta_path = _download_from_hub(version)
    _activate_model(onnx_path, vocab_path, meta_path, version)


def _init_local(onnx_path: str, vocab_path: str) -> None:
    """Activate a model from local files.

    Metadata is looked up next to the ONNX file as ``<stem>_metadata.json``.
    Only the file extension is swapped, so a ``.onnx`` substring elsewhere in
    the path (e.g. a directory name) is left alone. The loaded-version
    marker stays ``None`` because a local model has no Hub version.
    """
    meta_path = f"{os.path.splitext(onnx_path)[0]}_metadata.json" if onnx_path else ""
    _activate_model(onnx_path, vocab_path, meta_path, None)


def _setup_versions(requested_version: str = "") -> str:
    """Discover available versions, populate MODEL_VERSIONS, return resolved default.

    Resolution order:
      1. *requested_version*, when it is among the discovered versions (or
         when discovery returned nothing at all);
      2. otherwise the newest discovered version, with a warning if a
         different version had been requested;
      3. otherwise the existing ``DEFAULT_VERSION`` sentinel.

    ``DEFAULT_VERSION`` is set to the returned value so the UI's initial
    dropdown selection and thresholds match the model loaded at startup.
    """
    global MODEL_VERSIONS, DEFAULT_VERSION
    MODEL_VERSIONS = _fetch_available_versions()

    # MODEL_VERSIONS is ordered oldest-to-newest, so the last key is newest.
    newest = next(reversed(MODEL_VERSIONS), None)

    if requested_version and (requested_version in MODEL_VERSIONS or newest is None):
        # An explicit version was requested (CLI --version): honour it.
        DEFAULT_VERSION = requested_version
    elif newest is not None:
        if requested_version:
            print(f"[Tagger] WARNING: requested version '{requested_version}' "
                  f"not found; using '{newest}' instead")
        DEFAULT_VERSION = newest
    return DEFAULT_VERSION


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="SigLIP2 Tagger Gradio App")
    parser.add_argument("--onnx", default="", help="Local ONNX path")
    parser.add_argument("--vocab", default="", help="Local vocabulary JSON path")
    parser.add_argument("--version", default="", help="Model version to load (default: latest)")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    if args.onnx and args.vocab:
        # Local files: still discover versions so the dropdown is populated.
        _setup_versions()
        _init_local(args.onnx, args.vocab)
    else:
        version = _setup_versions(args.version)
        _init_from_hub(version)

    demo = create_ui()
    demo.launch(server_port=args.port, share=args.share)
else:
    # HuggingFace Spaces: discover versions, then auto-load the latest.
    # NOTE(review): this branch loads a model but never calls create_ui();
    # confirm that whatever imports this module builds the UI itself.
    _init_from_hub(_setup_versions())