Spaces:
Running on Zero
Running on Zero
| """ | |
| SigLIP2 Tagger — Gradio ONNX Inference App | |
| Runs ONNX-exported SigLIP2 tagger models with category-grouped results. | |
| Designed for HuggingFace Spaces deployment. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| from PIL import Image | |
# ---------------------------------------------------------------------------
# Category colours (SushiUI)
# ---------------------------------------------------------------------------
# Hex colour per tag category; used for section headers and probability bars.
CATEGORY_COLORS: Dict[str, str] = dict(
    Quality="#eab308",
    Rating="#f97316",
    Character="#3b82f6",
    Copyright="#a855f7",
    General="#22c55e",
    Meta="#9ca3af",
    Model="#06b6d4",
)

# Display order of the threshold-filtered categories. Quality and Rating are
# deliberately absent: they are pinned and always show their top-1 tag.
CATEGORY_ORDER = [
    "Character",
    "Copyright",
    "General",
    "Meta",
    "Model",
]

# In "Per Category" mode these categories are filtered by the
# character threshold slider...
CHAR_THRESHOLD_CATS = {"Copyright", "Character"}
# ...and these by the general threshold slider.
GENERAL_THRESHOLD_CATS = {"Model", "Meta", "General"}
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Hub repository holding one sub-folder per exported model version;
# overridable through the MODEL_REPO_ID environment variable.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "celstk/cl-SigLIP2-lora-onnx")
# Handed to snapshot_download as cache_dir; None when HF_HOME is unset.
CACHE_DIR = os.environ.get("HF_HOME")
# Filled at runtime by _setup_versions() from _fetch_available_versions():
# HF folder name (e.g. "v1_03") -> display label (e.g. "v1.03").
MODEL_VERSIONS: Dict[str, str] = {}
# Placeholder; _setup_versions() replaces it with the newest discovered version.
DEFAULT_VERSION: str = "v1_03"
| # Per-version threshold overrides. Unknown / future versions default to 0.5. | |
| # v1_02 used CS-ASL (bimodal distribution) — needs a much higher cut. | |
| _VERSION_DEFAULT_THRESHOLD: Dict[str, float] = { | |
| "v1_00": 0.6, | |
| "v1_01": 0.6, | |
| "v1_02": 0.9, | |
| } | |
| def _default_threshold(version: str) -> float: | |
| """Return the recommended inference threshold for *version*. | |
| Explicit entries exist for versions whose loss function produces an | |
| unusually skewed probability distribution. Everything else — including | |
| all future versions — falls back to 0.5, which is a safe middle ground | |
| that can be nudged up or down by the user. | |
| """ | |
| return _VERSION_DEFAULT_THRESHOLD.get(version, 0.5) | |
| # --------------------------------------------------------------------------- | |
| # Dynamic version discovery | |
| # --------------------------------------------------------------------------- | |
| import re as _re | |
| def _fetch_available_versions() -> Dict[str, str]: | |
| """Query the HF repo for version subdirectories (e.g. v1_03, v2_00). | |
| Returns an ordered dict keyed by folder name → display label, sorted | |
| oldest-to-newest by version tuple. Falls back to the last-known list | |
| if the API call fails so the app still loads. | |
| """ | |
| _FALLBACK: Dict[str, str] = { | |
| "v1_00": "v1.00", | |
| "v1_01": "v1.01", | |
| "v1_02": "v1.02", | |
| "v1_03": "v1.03", | |
| } | |
| token = os.environ.get("HF_TOKEN") | |
| try: | |
| from huggingface_hub import list_repo_files | |
| dirs: set = set() | |
| pat = _re.compile(r'^(v\d+_\d+)/') | |
| for f in list_repo_files(repo_id=MODEL_REPO_ID, token=token): | |
| m = pat.match(f) | |
| if m: | |
| dirs.add(m.group(1)) | |
| if not dirs: | |
| return _FALLBACK | |
| def _vkey(v: str): | |
| return tuple(int(x) for x in _re.findall(r'\d+', v)) | |
| sorted_dirs = sorted(dirs, key=_vkey) | |
| result = {v: v.replace("_", ".", 1) for v in sorted_dirs} | |
| print(f"[Tagger] Discovered versions: {list(result.keys())}") | |
| return result | |
| except Exception as exc: | |
| print(f"[Tagger] WARNING: version discovery failed ({exc}); using fallback list") | |
| return _FALLBACK | |
# Mutable model state, (re)assigned by the loading helpers.
_loaded_version: Optional[str] = None  # HF folder name of the active version
_onnx_path: Optional[str] = None  # .onnx file backing the active model
_session = None  # onnxruntime.InferenceSession, created lazily by _ensure_session()
_processor = None  # transformers AutoProcessor for the vision encoder
_is_naflex: bool = True  # NaFlex inputs (attention mask + spatial shapes) vs plain pixel_values
_idx_to_tag: Dict[int, str] = {}  # logit index -> tag
_tag_to_category: Dict[str, str] = {}  # tag -> category name
def _download_from_hub(version: str) -> Tuple[str, str, str]:
    """Fetch one version folder from MODEL_REPO_ID and locate its artifacts.

    Downloads (or reuses the cached copy of) ``{version}/*`` via
    ``snapshot_download``. Returns ``(onnx_path, vocab_path, meta_path)``;
    ``meta_path`` is ``""`` when the folder holds no ``*_metadata.json``.

    File selection is deterministic. Glob results are sorted, because glob
    order is filesystem-dependent. Companion files named after the chosen
    ONNX file (``<onnx_base>_vocabulary.json`` /
    ``<onnx_base>_metadata.json``) take precedence over any other match.

    Raises:
        FileNotFoundError: if the folder has no ``.onnx`` file or no
            ``*vocabulary.json`` file.
    """
    import glob
    from huggingface_hub import snapshot_download

    token = os.environ.get("HF_TOKEN")
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        allow_patterns=[f"{version}/*"],
        cache_dir=CACHE_DIR,
        token=token,
    )
    version_dir = os.path.join(local_dir, version)

    # "*.onnx" only matches names ending in ".onnx", so external-weight
    # files such as "model.onnx.data" are never picked here.
    onnx_files = sorted(glob.glob(os.path.join(version_dir, "*.onnx")))
    if not onnx_files:
        raise FileNotFoundError(f"No .onnx file found in {version_dir}")
    onnx_path = onnx_files[0]
    if len(onnx_files) > 1:
        print(f"[Tagger] WARNING: multiple .onnx files in {version_dir}; "
              f"using '{os.path.basename(onnx_path)}'")
    onnx_base = os.path.splitext(os.path.basename(onnx_path))[0]

    # Prefer the per-ONNX vocabulary snapshot ``<onnx_base>_vocabulary.json``
    # (frozen at export time, immune to overwrites). Fall back to any
    # ``*vocabulary.json`` in the same directory with a warning.
    per_ckpt_vocab = os.path.join(version_dir, f"{onnx_base}_vocabulary.json")
    if os.path.isfile(per_ckpt_vocab):
        vocab_path = per_ckpt_vocab
    else:
        vocab_files = sorted(glob.glob(os.path.join(version_dir, "*vocabulary.json")))
        if not vocab_files:
            raise FileNotFoundError(f"No *vocabulary.json found in {version_dir}")
        vocab_path = vocab_files[0]
        print(f"[Spaces] WARNING: per-ONNX vocabulary "
              f"'{onnx_base}_vocabulary.json' not found; falling back to "
              f"'{os.path.basename(vocab_path)}'. Tag→idx alignment with "
              f"'{onnx_base}.onnx' cannot be verified.")

    # Same preference for metadata: the file matching the ONNX export first
    # (the naming _init_local uses), otherwise any *_metadata.json.
    per_ckpt_meta = os.path.join(version_dir, f"{onnx_base}_metadata.json")
    if os.path.isfile(per_ckpt_meta):
        meta_path = per_ckpt_meta
    else:
        meta_files = sorted(glob.glob(os.path.join(version_dir, "*_metadata.json")))
        meta_path = meta_files[0] if meta_files else ""
    return onnx_path, vocab_path, meta_path
| def _vocab_get(vocab: dict, key: str) -> dict: | |
| """Extract *key* from a vocabulary dict with prefix-fallback. | |
| Some exported vocab files store keys under a version prefix, e.g. | |
| ``"v1_03/idx_to_tag"`` instead of plain ``"idx_to_tag"``. The prefix | |
| may change between versions, so we try the bare key first then fall back | |
| to any key whose suffix matches ``/{key}`` or equals ``key``. | |
| """ | |
| if key in vocab: | |
| return vocab[key] | |
| suffix = f"/{key}" | |
| for k, v in vocab.items(): | |
| if isinstance(k, str) and k.endswith(suffix): | |
| print(f"[Tagger] vocab: resolved '{key}' via prefixed key '{k}'") | |
| return v | |
| return {} | |
def _load_processor_and_vocab(vocab_path: str, meta_path: str = "") -> None:
    """Load the image processor and tag vocabulary, then publish them globally.

    The processor repo and NaFlex flag come from the optional metadata JSON
    (``vision_encoder_repo`` / ``is_naflex``). Without metadata, the
    SigLIP2 NaFlex defaults apply. The processor is loaded from the local HF
    cache first, then downloaded if that fails.

    Everything is loaded into locals first. ``_processor``, ``_is_naflex``,
    ``_idx_to_tag`` and ``_tag_to_category`` are assigned together only
    after every step has succeeded, so a failure leaves the previously
    loaded state untouched.

    Raises:
        ValueError: if the vocabulary has no ``idx_to_tag`` mapping.
        Exceptions from file reads / JSON parsing / processor download
        propagate unchanged.
    """
    global _processor, _idx_to_tag, _tag_to_category, _is_naflex
    from transformers import AutoProcessor

    DEFAULT_REPO = "google/siglip2-so400m-patch16-naflex"
    processor_repo = DEFAULT_REPO
    is_naflex = True
    if meta_path and os.path.isfile(meta_path):
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        processor_repo = meta.get("vision_encoder_repo", DEFAULT_REPO)
        is_naflex = bool(meta.get("is_naflex", True))
    print(f"[Tagger] Loading processor from {processor_repo} (mode: {'NaFlex' if is_naflex else 'standard'})...")
    try:
        processor = AutoProcessor.from_pretrained(processor_repo, local_files_only=True)
    except Exception:
        # Not in the local cache (or unusable) — fetch it from the Hub.
        processor = AutoProcessor.from_pretrained(processor_repo)

    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    raw_idx = _vocab_get(vocab, "idx_to_tag")
    if not raw_idx:
        raise ValueError(f"'idx_to_tag' not found in {vocab_path} (top-level keys: {list(vocab.keys())[:8]})")
    # JSON object keys are strings; logits are indexed by int.
    idx_to_tag = {int(k): v for k, v in raw_idx.items()}
    tag_to_category = _vocab_get(vocab, "tag_to_category")

    # Publish atomically: processor, input mode and vocabulary always match.
    _processor = processor
    _is_naflex = is_naflex
    _idx_to_tag = idx_to_tag
    _tag_to_category = tag_to_category
    print(f"[Tagger] Vocabulary loaded | {len(_idx_to_tag)} tags")
def _ensure_session() -> None:
    """Create the ONNX Runtime session for ``_onnx_path`` on first use.

    Does nothing if a session already exists. The load paths reset
    ``_session`` to None to force a rebuild. Providers are requested in
    priority order (CUDA, then CPU), and the first provider the new session
    reports is logged.
    """
    global _session
    if _session is None:
        import onnxruntime as ort

        session_options = ort.SessionOptions()
        session_options.log_severity_level = 2
        _session = ort.InferenceSession(
            _onnx_path,
            sess_options=session_options,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        active_provider = _session.get_providers()[0]
        print(f"[Tagger] ONNX session ready | provider={active_provider}")
def load_version(version: str) -> str:
    """Download (if needed) and activate the selected model version.

    Accepts versions that are not (yet) listed in MODEL_VERSIONS; the raw
    folder name is then used as the label. Never raises: download and load
    errors are reported in the returned status string, and the previously
    active ONNX model stays in service.

    Returns:
        A status string shown in the UI.
    """
    global _onnx_path, _session, _loaded_version
    label = MODEL_VERSIONS.get(version, version)
    if version == _loaded_version:
        return f"Already loaded: {label}"

    print(f"[Tagger] Switching to version '{label}'...")
    try:
        onnx_path, vocab_path, meta_path = _download_from_hub(version)
    except Exception as e:
        return f"Download failed: {e}"

    try:
        _load_processor_and_vocab(vocab_path, meta_path)
    except Exception as e:
        return f"Load failed: {e}"

    # Switch the ONNX model only after its processor and vocabulary loaded,
    # so a failed load never pairs the new network with the old tag list.
    # Resetting the session makes _ensure_session() rebuild it for the new file.
    _session = None
    _onnx_path = onnx_path
    _loaded_version = version
    mode = "NaFlex" if _is_naflex else "Standard"
    return f"Loaded: {label} ({mode}, {len(_idx_to_tag)} tags)"
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def predict(
    image: Image.Image | None,
    threshold_mode: str,
    unified_threshold: float,
    general_threshold: float,
    char_threshold: float,
    max_num_patches: int = 256,
) -> str:
    """Run inference on *image* and return the results panel as HTML.

    Quality and Rating always contribute their top-1 tag. All other
    categories keep tags whose probability is >= the applicable threshold:
    *unified_threshold* in "Unified" mode, otherwise *char_threshold* for
    Character/Copyright and *general_threshold* for General/Meta/Model.
    Tags whose category is unknown to the vocabulary are dropped.

    Args:
        image: uploaded PIL image, or None when nothing is uploaded.
        threshold_mode: "Unified" or the per-category mode label.
        max_num_patches: NaFlex patch budget handed to the processor.

    Returns:
        An HTML string. Short placeholder/error paragraphs are returned when
        no image is given or no model is loaded.
    """
    if image is None:
        return "<p style='color:#888'>Upload an image to start.</p>"
    if _onnx_path is None or _processor is None:
        return "<p style='color:#f87171'>Model not loaded.</p>"

    logits = _run_model(image, max_num_patches)
    # exp(-x) overflows to inf for very negative logits; 1 / (1 + inf) is the
    # correct 0.0, so silence the spurious RuntimeWarning rather than the math.
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-logits))
    all_items = _collect_tag_items(probs)

    # Quality / Rating: always show top-1
    quality_top = _pick_top(all_items, "Quality")
    rating_top = _pick_top(all_items, "Rating")

    thr_map = _threshold_map(threshold_mode, unified_threshold, general_threshold, char_threshold)
    filtered = [
        it for it in all_items
        if it["category"] not in ("Quality", "Rating")
        and it["prob"] >= thr_map.get(it["category"], unified_threshold)
    ]
    filtered.sort(key=lambda x: x["prob"], reverse=True)

    # Plaintext tags for the copy button: pinned tags first, then by probability.
    tag_parts: List[str] = [top["tag"] for top in (quality_top, rating_top) if top]
    tag_parts.extend(it["tag"] for it in filtered)
    tags_text = ", ".join(tag_parts)
    return _build_html(quality_top, rating_top, filtered, tags_text)


def _run_model(image: Image.Image, max_num_patches: int) -> np.ndarray:
    """Preprocess *image*, run the ONNX session, and return float64 logits.

    NaFlex models get ``pixel_values``, ``pixel_attention_mask`` and
    ``spatial_shapes`` (int64) computed with *max_num_patches*. Standard
    models get ``pixel_values`` only. Returns row 0 of the session's
    ``logits`` output (the single image in the batch).
    """
    _ensure_session()
    rgb = image.convert("RGB")
    if _is_naflex:
        inputs = _processor(images=rgb, return_tensors="pt", max_num_patches=max_num_patches)
        feeds = {
            "pixel_values": inputs["pixel_values"].float().numpy(),
            "pixel_attention_mask": inputs["pixel_attention_mask"].float().numpy(),
            "spatial_shapes": inputs["spatial_shapes"].numpy().astype(np.int64),
        }
    else:
        inputs = _processor(images=rgb, return_tensors="pt")
        feeds = {"pixel_values": inputs["pixel_values"].float().numpy()}
    outputs = _session.run(["logits"], feeds)
    return outputs[0][0].astype(np.float64)


def _collect_tag_items(probs: np.ndarray) -> List[Dict[str, Any]]:
    """Pair each probability with its tag and category; skip Unknown-category tags."""
    items: List[Dict[str, Any]] = []
    for i, prob in enumerate(probs):
        # Indices missing from the vocabulary get a placeholder tag, which
        # then maps to "Unknown" and is skipped.
        tag = _idx_to_tag.get(i, f"__unk_{i}__")
        cat = _tag_to_category.get(tag, "Unknown")
        if cat == "Unknown":
            continue
        items.append({"tag": tag, "prob": float(prob), "category": cat})
    return items


def _threshold_map(
    threshold_mode: str,
    unified_threshold: float,
    general_threshold: float,
    char_threshold: float,
) -> Dict[str, float]:
    """Map each filtered category to its threshold for the selected mode."""
    if threshold_mode == "Unified":
        return {c: unified_threshold for c in CATEGORY_ORDER}
    return {
        **{c: char_threshold for c in CHAR_THRESHOLD_CATS},
        **{c: general_threshold for c in GENERAL_THRESHOLD_CATS},
    }
| def _pick_top(items: List[Dict], category: str) -> Optional[Dict]: | |
| cat_items = [it for it in items if it["category"] == category] | |
| return max(cat_items, key=lambda x: x["prob"]) if cat_items else None | |
| # --------------------------------------------------------------------------- | |
| # HTML result builder | |
| # --------------------------------------------------------------------------- | |
| def _bar_html(tag: str, prob: float, color: str) -> str: | |
| pct = prob * 100 | |
| # Compact bar designed for multi-column grid cells (~300px wide each) | |
| return ( | |
| f'<div style="display:flex;align-items:center;gap:5px;margin:2px 0;min-width:0">' | |
| f'<span style="width:130px;flex-shrink:0;font-size:12px;color:#e5e7eb;' | |
| f'overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="{tag}">{tag}</span>' | |
| f'<div style="flex:1;height:12px;background:#374151;border-radius:2px;overflow:hidden;min-width:0">' | |
| f'<div style="width:{pct:.1f}%;height:100%;background:{color};border-radius:2px"></div>' | |
| f'</div>' | |
| f'<span style="width:40px;flex-shrink:0;text-align:right;font-size:11px;' | |
| f'color:#9ca3af;font-family:monospace">{pct:.1f}%</span>' | |
| f'</div>' | |
| ) | |
| def _copy_button_html(tags_text: str, n_tags: int) -> str: | |
| import html as _html | |
| escaped = _html.escape(tags_text, quote=True) | |
| return ( | |
| f'<div style="display:flex;align-items:center;gap:10px;margin-bottom:10px;' | |
| f'padding-bottom:8px;border-bottom:1px solid #374151">' | |
| f'<button ' | |
| f' data-tags="{escaped}" ' | |
| f' onclick="navigator.clipboard.writeText(this.dataset.tags)' | |
| f'.then(()=>{{this.textContent=\'Copied!\';setTimeout(()=>this.textContent=\'Copy Tags\',1500)}})' | |
| f'.catch(()=>{{}})" ' | |
| f' style="background:#3b82f6;color:#fff;border:none;padding:5px 14px;' | |
| f'border-radius:4px;cursor:pointer;font-size:13px;font-weight:600">' | |
| f'Copy Tags' | |
| f'</button>' | |
| f'<span style="font-size:12px;color:#6b7280">{n_tags} tags</span>' | |
| f'</div>' | |
| ) | |
def _build_html(
    quality_top: Optional[Dict],
    rating_top: Optional[Dict],
    filtered: List[Dict],
    tags_text: str,
) -> str:
    """Assemble the results panel HTML.

    Layout, top to bottom:
    - the copy-button header with the total tag count;
    - the pinned Quality / Rating top-1 rows;
    - one collapsible ``<details>`` section per category, with its tags in a
      responsive multi-column grid.

    Known categories come first in ``CATEGORY_ORDER``. Any other category
    present in *filtered* follows, in first-appearance order. Every tag that
    is counted and copied is therefore also displayed.

    Args:
        quality_top: Quality top-1 item dict (``tag``/``prob``/``category``),
            or None.
        rating_top: Rating top-1 item dict, same shape, or None.
        filtered: thresholded non-pinned items, listed in display order.
        tags_text: comma-separated tag string placed on the copy button.
    """
    import html as _html

    parts: List[str] = ['<div style="font-family:sans-serif">']
    # Copy button + tag count header
    n_tags = len(filtered) + (1 if quality_top else 0) + (1 if rating_top else 0)
    parts.append(_copy_button_html(tags_text, n_tags))

    # Pinned: Quality + Rating (single row each, no multi-column)
    if quality_top or rating_top:
        parts.append('<div style="margin-bottom:12px;padding-bottom:8px;border-bottom:1px solid #374151">')
        if quality_top:
            c = CATEGORY_COLORS["Quality"]
            parts.append(f'<div style="font-size:11px;color:{c};font-weight:600;margin-bottom:2px">Quality</div>')
            parts.append(_bar_html(quality_top["tag"], quality_top["prob"], c))
        if rating_top:
            c = CATEGORY_COLORS["Rating"]
            parts.append(f'<div style="font-size:11px;color:{c};font-weight:600;margin-top:6px;margin-bottom:2px">Rating</div>')
            parts.append(_bar_html(rating_top["tag"], rating_top["prob"], c))
        parts.append('</div>')

    # Grouped categories — multi-column grid
    grouped: Dict[str, List[Dict]] = {}
    for it in filtered:
        grouped.setdefault(it["category"], []).append(it)
    extra_cats = [cat for cat in grouped if cat not in CATEGORY_ORDER]
    for cat in [*CATEGORY_ORDER, *extra_cats]:
        items = grouped.get(cat)
        if not items:
            continue
        c = CATEGORY_COLORS.get(cat, "#888")
        parts.append(
            f'<details open style="margin-bottom:10px">'
            f'<summary style="cursor:pointer;font-size:12px;font-weight:600;'
            f'color:{c};margin-bottom:4px;user-select:none">'
            f'{_html.escape(cat)} ({len(items)})</summary>'
            # auto-fill: ≥2 columns when the container is wide enough
            f'<div style="display:grid;grid-template-columns:repeat(auto-fill,minmax(280px,1fr));gap:0 12px">'
        )
        for it in items:
            parts.append(_bar_html(it["tag"], it["prob"], c))
        parts.append('</div></details>')

    if not filtered and not quality_top and not rating_top:
        parts.append('<p style="color:#888;font-size:13px">No tags above threshold.</p>')
    parts.append('</div>')
    return "\n".join(parts)
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Page stylesheet passed to gr.Blocks(css=...) in create_ui(): a dark
# background/text theme for the whole app, the primary button colour, and
# flex rules for the two-column layout. The .tagger-* classes are attached
# through elem_classes on the row and its left/right columns.
CSS = """
body, .gradio-container { background:#111827 !important; color:#e5e7eb !important; }
button.primary { background:#3b82f6 !important; }
/* Prevent columns from reflowing when image loads */
.tagger-row { align-items: flex-start !important; }
.tagger-left { min-width: 0; flex-shrink: 0; }
.tagger-right { min-width: 0; flex: 1 1 0; overflow: hidden; }
"""
def create_ui() -> gr.Blocks:
    """Build the Gradio Blocks interface.

    Reads MODEL_VERSIONS / DEFAULT_VERSION (populated by _setup_versions())
    and the currently loaded version. These seed the version dropdown and
    the default threshold sliders. Returns the queued ``gr.Blocks`` app; it
    is not launched.
    """
    version_choices = [(label, key) for key, label in MODEL_VERSIONS.items()]
    # Reflect the model actually loaded at startup (e.g. via --version) so
    # the dropdown and the slider defaults match it; else the newest version.
    initial_version = _loaded_version or DEFAULT_VERSION
    default_thr = _default_threshold(initial_version)

    with gr.Blocks(css=CSS, title="SigLIP2 Tagger", theme=gr.themes.Base()) as demo:
        gr.Markdown("<h2 style='color:#e5e7eb;margin:8px 0'>SigLIP2 Tagger</h2>")
        with gr.Row(elem_classes="tagger-row"):
            # --- Left panel: image + controls ---
            with gr.Column(scale=1, elem_classes="tagger-left"):
                image_input = gr.Image(type="pil", label="Image", height=380)
                with gr.Row():
                    version_dropdown = gr.Dropdown(
                        choices=version_choices,
                        value=initial_version,
                        label="Model Version",
                        scale=3,
                    )
                    load_btn = gr.Button("Load", scale=1)
                model_status = gr.Textbox(
                    value="",
                    label="Status",
                    interactive=False,
                    lines=3,
                    max_lines=8,
                )
                threshold_mode = gr.Radio(
                    ["Unified", "Per Category"],
                    value="Unified",
                    label="Threshold Mode",
                )
                # Unified mode
                unified_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Threshold",
                    visible=True,
                )
                # Per-category mode
                general_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="General / Meta Threshold",
                    visible=False,
                )
                char_slider = gr.Slider(
                    0.01, 0.99, value=default_thr, step=0.01,
                    label="Character / Copyright Threshold",
                    visible=False,
                )
                predict_btn = gr.Button("Run", variant="primary")
            # --- Right panel: results ---
            with gr.Column(scale=2, elem_classes="tagger-right"):
                html_output = gr.HTML()

        # Load button — step-by-step progress via generator, then reset thresholds.
        def _load_and_reset(version: str):
            """Stream load progress, then reset the sliders on success.

            Yields 4-tuples of (status text, unified, general, char slider
            updates). The sliders change only after a successful load, when
            all three are set to the version's recommended threshold.
            """
            global _onnx_path, _session, _loaded_version
            no_change = (gr.update(), gr.update(), gr.update())
            label = MODEL_VERSIONS.get(version, version)
            if version == _loaded_version:
                yield (f"Already loaded: {label}", *no_change)
                return

            yield (f"[1/3] Downloading {label} from Hub…", *no_change)
            try:
                onnx_path, vocab_path, meta_path = _download_from_hub(version)
            except Exception as exc:
                yield (f"Download failed:\n{exc}", *no_change)
                return

            yield ("[2/3] Loading processor and vocabulary…", *no_change)
            try:
                _load_processor_and_vocab(vocab_path, meta_path)
            except Exception as exc:
                yield (f"Load failed:\n{exc}", *no_change)
                return

            # Switch the ONNX model only after its vocabulary is in place, so
            # a failed load never pairs the new network with the old tag list.
            _session = None
            _onnx_path = onnx_path
            _loaded_version = version
            mode = "NaFlex" if _is_naflex else "Standard"
            thr = _default_threshold(version)
            status = f"[3/3] Ready: {label} ({mode}, {len(_idx_to_tag):,} tags)"
            yield (status, gr.update(value=thr), gr.update(value=thr), gr.update(value=thr))

        load_btn.click(
            fn=_load_and_reset,
            inputs=[version_dropdown],
            outputs=[model_status, unified_slider, general_slider, char_slider],
        )

        # Toggle slider visibility on mode change
        def _toggle_mode(mode: str):
            unified = mode == "Unified"
            return gr.update(visible=unified), gr.update(visible=not unified), gr.update(visible=not unified)

        threshold_mode.change(
            fn=_toggle_mode,
            inputs=[threshold_mode],
            outputs=[unified_slider, general_slider, char_slider],
        )
        predict_btn.click(
            fn=predict,
            inputs=[image_input, threshold_mode, unified_slider, general_slider, char_slider],
            outputs=[html_output],
        )
    demo.queue()
    return demo
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
def _init_from_hub(version: str) -> None:
    """Download *version* from MODEL_REPO_ID and make it the active model.

    Errors from downloading or loading propagate to the caller.
    ``_onnx_path`` / ``_loaded_version`` are only set once the processor and
    vocabulary loaded successfully.
    """
    global _onnx_path, _loaded_version
    print(f"[Tagger] Downloading model version '{version}' from {MODEL_REPO_ID}...")
    onnx_path, vocab_path, meta_path = _download_from_hub(version)
    _load_processor_and_vocab(vocab_path, meta_path)
    # Commit the ONNX path only together with its matching vocabulary.
    _onnx_path = onnx_path
    _loaded_version = version
def _init_local(onnx_path: str, vocab_path: str) -> None:
    """Activate a local ONNX export and vocabulary JSON.

    Metadata is looked up next to the model as ``<onnx stem>_metadata.json``.
    If that file does not exist, _load_processor_and_vocab skips it.
    ``os.path.splitext`` derives the name from the file extension only.
    ``str.replace`` would also rewrite ".onnx" inside directory names, and
    for a path without ".onnx" it would yield the model file itself as the
    "metadata" path.

    ``_loaded_version`` is not set, so the UI's Load button will still
    download whichever Hub version is selected.
    """
    global _onnx_path
    meta_path = f"{os.path.splitext(onnx_path)[0]}_metadata.json" if onnx_path else ""
    _load_processor_and_vocab(vocab_path, meta_path)
    # Commit the ONNX path only after its vocabulary loaded.
    _onnx_path = onnx_path
def _setup_versions(requested_version: str = "") -> str:
    """Discover available versions, populate MODEL_VERSIONS, return the version to load.

    DEFAULT_VERSION becomes the newest discovered version. If discovery
    produced nothing, the existing DEFAULT_VERSION is kept.
    *requested_version* (e.g. from CLI ``--version``) wins when it is among
    the discovered versions. Otherwise a warning is printed and
    DEFAULT_VERSION is returned.
    """
    global MODEL_VERSIONS, DEFAULT_VERSION
    MODEL_VERSIONS = _fetch_available_versions()
    if MODEL_VERSIONS:
        # MODEL_VERSIONS is ordered oldest -> newest; take the last key.
        DEFAULT_VERSION = next(reversed(MODEL_VERSIONS))

    if requested_version:
        if requested_version in MODEL_VERSIONS:
            return requested_version
        print(f"[Tagger] WARNING: requested version '{requested_version}' not found; "
              f"using '{DEFAULT_VERSION}'")
    return DEFAULT_VERSION
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="SigLIP2 Tagger Gradio App")
    cli.add_argument("--onnx", default="", help="Local ONNX path")
    cli.add_argument("--vocab", default="", help="Local vocabulary JSON path")
    cli.add_argument("--version", default="", help="Model version to load (default: latest)")
    cli.add_argument("--port", type=int, default=7860)
    cli.add_argument("--share", action="store_true")
    opts = cli.parse_args()

    if not (opts.onnx and opts.vocab):
        # Hub mode: load the requested version, or the newest one.
        _init_from_hub(_setup_versions(opts.version))
    else:
        # Local files: versions are still discovered so the dropdown is populated.
        _setup_versions()
        _init_local(opts.onnx, opts.vocab)

    demo = create_ui()
    demo.launch(server_port=opts.port, share=opts.share)
else:
    # Imported rather than run as a script (the HuggingFace Spaces path):
    # discover versions, then auto-load the newest one.
    _init_from_hub(_setup_versions())