Spaces:
Sleeping
Sleeping
| """ | |
| VisionQuery — Zero-Shot Image Understanding with SigLIP | |
| Built with Taipy GUI | Deployed on Hugging Face Spaces | |
| """ | |
| import os | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| from taipy.gui import Gui, notify, State | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # MODEL (loaded lazily on first inference) | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
# Module-level cache: filled by the first _load_siglip() call and reused afterwards.
_processor = None
_model = None


def _load_siglip():
    """Return the ``(processor, model)`` pair, loading it on the first call.

    The ``transformers`` import and both ``from_pretrained`` calls only run
    while ``_model`` is still ``None``; later calls return the cached objects.
    The model is switched to eval mode right after loading.
    """
    global _processor, _model
    if _model is not None:
        return _processor, _model

    # Deferred import: only executed when the model actually has to be loaded.
    from transformers import AutoProcessor, AutoModel

    checkpoint = "google/siglip-base-patch16-224"
    _processor = AutoProcessor.from_pretrained(checkpoint)
    _model = AutoModel.from_pretrained(checkpoint)
    _model.eval()
    return _processor, _model
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # HELPERS | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
def _empty_chart(msg="Upload an image and click Analyze to see results"):
    """Build a placeholder Plotly figure that shows only ``msg``.

    The figure has hidden axes, transparent backgrounds and a single centred
    annotation carrying the message.

    Args:
        msg: Text displayed in the middle of the figure.

    Returns:
        A ``plotly.graph_objects.Figure`` with no traces.
    """
    transparent = "rgba(0,0,0,0)"
    caption = dict(
        text=msg,
        showarrow=False,
        # Paper coordinates: (0.5, 0.5) is the centre of the figure area.
        x=0.5, y=0.5, xref="paper", yref="paper",
        font=dict(size=14, color="#94a3b8"),
    )
    layout = dict(
        height=300,
        margin=dict(l=10, r=10, t=10, b=10),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
    )

    placeholder = go.Figure()
    placeholder.add_annotation(**caption)
    placeholder.update_layout(**layout)
    return placeholder
def _make_bar_chart(labels, scores):
    """Build a horizontal Plotly bar chart of per-label similarity scores.

    Args:
        labels: Label strings, one per bar. The y-axis is reversed, so the
            first label is drawn at the top.
        scores: Scores in percent, parallel to ``labels``.

    Returns:
        A ``plotly.graph_objects.Figure``. Empty input yields a chart with no
        bars instead of raising.
    """
    n = len(labels)
    # Bar opacity follows the score; the 0.20 floor keeps weak bars visible.
    colors = [f"rgba(99,102,241,{max(0.20, s / 100):.2f})" for s in scores]

    # max() raises ValueError on an empty sequence, and an all-zero result would
    # collapse the axis to [0, 0]; fall back to the full 0-100 scale for both.
    peak = max(scores, default=0)
    x_upper = min(100, peak * 1.35) if peak > 0 else 100

    fig = go.Figure(go.Bar(
        x=scores,
        y=labels,
        orientation="h",
        marker=dict(color=colors, line=dict(width=0)),
        text=[f" {s:.1f}%" for s in scores],
        textposition="outside",
    ))
    fig.update_layout(
        title=dict(
            text="SigLIP Similarity Scores",
            font=dict(size=18, color="#312e81"),
            x=0.02,
        ),
        xaxis=dict(
            title="Score (%)",
            # 35 % headroom past the best score leaves room for the outside text.
            range=[0, x_upper],
            gridcolor="#e2e8f0",
        ),
        yaxis=dict(autorange="reversed", gridcolor="#e2e8f0"),
        # Grow with the number of bars, but never below 320 px.
        height=max(320, n * 52 + 100),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(248,250,252,1)",
        font=dict(color="#1e293b", size=13),
        margin=dict(l=10, r=100, t=60, b=40),
        hoverlabel=dict(bgcolor="#312e81", font_color="white"),
    )
    return fig
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # GLOBAL STATE | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
# Initial values of the variables that the pages and callbacks access via ``state``.
uploaded_image = None   # bound to file_selector
display_image = None    # bound to <image>
text_input = (
    "a cat, a dog, a car, a person walking, "
    "a sunset, a building, a flower, an animal"
)
# Column names must match the chart control (x=Score (%), y=Label) and the
# frames assigned by analyze() and reset(); "Score" alone matches neither.
chart_data = pd.DataFrame({"Label": [], "Score (%)": []})
chart_empty = True
score_df = pd.DataFrame(columns=["Rank", "Label", "Score (%)"])
status_msg = "Upload an image and click **Analyze** to begin."
top_label = ""
top_score = 0.0
has_results = False     # gates the winner card and the score table
is_analyzing = False    # the Analyze button is active only while this is False
model_status = "β³ Model loads on first inference (~15-30 s)"
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # CALLBACKS | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
def on_file_upload(state: State, var_name: str, value):
    """Handle a new selection in the file selector.

    Shows the uploaded file in the preview image and clears every result left
    over from a previous analysis. Does nothing when ``state.uploaded_image``
    is empty.

    Args:
        state: Taipy ``State`` passed to the callback.
        var_name: Name supplied by the control (unused).
        value: Value supplied by the control (unused; ``state.uploaded_image``
            is read instead).
    """
    if not state.uploaded_image:
        return

    state.display_image = state.uploaded_image
    state.status_msg = "β Image ready β click **Analyze** to run SigLIP."

    # Clear stale results through the variables the page is bound to. There is
    # no ``chart_fig`` state variable and no control uses one; the chart reads
    # ``chart_data``, so that is what has to be emptied (same as reset()).
    state.has_results = False
    state.chart_data = pd.DataFrame({"Label": [], "Score (%)": []})
    state.chart_empty = True
    state.score_df = pd.DataFrame(columns=["Rank", "Label", "Score (%)"])
    state.top_label = ""
    state.top_score = 0.0
    notify(state, "success", "Image uploaded successfully!")
def analyze(state: State):
    """Score the uploaded image against the user's labels with SigLIP.

    Reads ``state.display_image`` and the comma-separated ``state.text_input``.
    On success, writes the results ranked by descending score to
    ``chart_data``, ``score_df``, ``top_label`` and ``top_score``, and sets
    ``has_results``. Any exception raised while loading the model or running
    inference is reported through ``status_msg`` and an error notification
    instead of propagating. ``is_analyzing`` is always reset on exit.

    Args:
        state: Taipy ``State`` passed to the callback.
    """
    if not state.display_image:
        notify(state, "warning", "Please upload an image first.")
        return

    # Split on commas and drop blank entries such as those from "a, , b".
    label_list = [part.strip() for part in state.text_input.split(",") if part.strip()]
    if not label_list:
        notify(state, "warning", "Enter at least one comma-separated label.")
        return

    state.is_analyzing = True
    state.status_msg = "π Loading SigLIP model & running inferenceβ¦"
    try:
        proc, mdl = _load_siglip()
        state.model_status = "β google/siglip-base-patch16-224 β ready"

        # Close the file handle deterministically; convert() returns a separate
        # image object that stays usable after the ``with`` block.
        with Image.open(state.display_image) as source:
            img = source.convert("RGB")

        with torch.no_grad():
            # NOTE(review): fixed-length padding is presumably what the SigLIP
            # checkpoint expects; confirm against the Transformers SigLIP docs
            # before changing it.
            inputs = proc(
                text=label_list,
                images=img,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
            )
            logits = mdl(**inputs).logits_per_image  # shape: (1, N)
            # flatten() gives a 1-D tensor for every N, including N == 1, so
            # tolist() always returns a list.
            probs = torch.sigmoid(logits).flatten().tolist()

        pairs = sorted(zip(label_list, probs), key=lambda pair: pair[1], reverse=True)
        labels = [label for label, _ in pairs]
        scores = [round(prob * 100, 2) for _, prob in pairs]

        state.top_label = labels[0]
        state.top_score = scores[0]
        state.chart_data = pd.DataFrame({"Label": labels, "Score (%)": scores})
        state.chart_empty = False
        state.score_df = pd.DataFrame({
            "Rank": list(range(1, len(labels) + 1)),
            "Label": labels,
            "Score (%)": [f"{s:.2f}" for s in scores],  # string, never blank
        })
        state.has_results = True
        state.status_msg = f"β Top match: **{labels[0]}** ({scores[0]:.1f}%)"
        notify(state, "success", "Analysis complete!")
    except Exception as exc:
        # Callback boundary: surface the failure in the UI rather than raising.
        state.status_msg = f"β Error: {exc}"
        notify(state, "error", str(exc))
    finally:
        state.is_analyzing = False
def reset(state: State):
    """Clear the uploaded image and every analysis result.

    ``text_input``, ``is_analyzing`` and ``model_status`` are left untouched.

    Args:
        state: Taipy ``State`` passed to the callback.
    """
    cleared = {
        "uploaded_image": None,
        "display_image": None,
        "chart_data": pd.DataFrame({"Label": [], "Score (%)": []}),
        "chart_empty": True,
        "score_df": pd.DataFrame(columns=["Rank", "Label", "Score (%)"]),
        "top_label": "",
        "top_score": 0.0,
        "has_results": False,
        "status_msg": "Upload a new image and click Analyze.",
    }
    # Dicts keep insertion order, so the variables are assigned in the order listed.
    for name, fresh_value in cleared.items():
        setattr(state, name, fresh_value)
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # PAGE — DEMO | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
# Taipy Markdown for the demo page, laid out in two columns (5:7).
#   Left panel  - file_selector bound to uploaded_image (on_action=on_file_upload),
#                 preview image bound to display_image, multiline input bound to
#                 text_input, the Analyze and Reset buttons (on_action=analyze /
#                 reset; Analyze is active only while is_analyzing is False), and
#                 the status_msg / model_status texts.
#   Right panel - winner card (top_label, top_score) and score table (score_df),
#                 both rendered only when has_results is true, plus a horizontal
#                 bar chart bound to chart_data (x="Score (%)", y="Label").
demo_md = """
<|part|class_name=page-header|
# π VisionQuery
### Zero-Shot Image Classification powered by Google SigLIP + Taipy
|>
<|layout|columns=5 7|gap=2.5rem|class_name=main-layout|
<|part|class_name=panel card|
#### Step 1 β Upload Image
<|{uploaded_image}|file_selector|label=π Choose Imageβ¦|extensions=.jpg,.jpeg,.png,.webp|drop_message=Drop image here|on_action=on_file_upload|class_name=upload-btn|>
<|{display_image}|image|width=100%|class_name=preview-img|>
---
#### Step 2 β Enter Text Labels
*Comma-separated concepts to test against the image:*
<|{text_input}|input|multiline|rows=5|class_name=fullwidth label-input|>
<|π Analyze Image|button|on_action=analyze|active={not is_analyzing}|class_name=plain analyze-btn|>
<| βΊ Reset|button|on_action=reset|class_name=reset-btn|>
---
<|{status_msg}|text|class_name=status-text|>
<|{model_status}|text|class_name=model-tag|>
|>
<|part|class_name=panel card|
#### Results
<|part|render={has_results}|class_name=winner-card|
<|layout|columns=1 1|gap=1rem|
<|part|
π **Best Match**
<|{top_label}|text|class_name=winner-label|>
|>
<|part|
π **Confidence**
<|{top_score:.1f}|text|class_name=winner-score|>%
|>
|>
|>
<|{chart_data}|chart|type=bar|x=Score (%)|y=Label|orientation=h|title=SigLIP Similarity Scores|height=350px|>
<|part|render={has_results}|class_name=score-table|
**Detailed Scores:**
<|{score_df}|table|width=100%|page_size=10|>
|>
|>
|>
"""
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # PAGE — ABOUT | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
# Taipy Markdown for the static About page: a header, a two-column
# problem/solution layout, the tech stack and the SigLIP citation.
# Every block opener (<|part|...| / <|layout|...|) is paired with exactly one
# "|>"; the Tech Stack section is plain Markdown and needs no closer.
about_md = """
<|part|class_name=page-header|
# π§ About VisionQuery
### Problem Β· Solution Β· Technology Stack
|>
<|layout|columns=1 1|gap=2rem|
<|part|class_name=card problem-card|
## π΄ The Problem
Traditional image classification requires:
- **Thousands of labeled images** per category
- **Expensive GPU training** pipelines
- **Re-training** whenever you add a new category
- **Domain expertise** to build & maintain
This makes vision AI **slow, costly, and inflexible** for real-world deployment.
|>
<|part|class_name=card solution-card|
## π’ The Solution
**VisionQuery AI** uses **SigLIP** by Google DeepMind for **zero-shot classification**:
- Describe what you're looking for in **plain English**
- No training data required β ever
- Add **unlimited new categories** instantly
- Works in **100+ languages** (multilingual SigLIP)
|>
|>
---
### π οΈ Tech Stack
**Model Layer**
π€ `google/siglip-base-patch16-224`
PyTorch + Transformers
**GUI Layer**
Taipy β Python-native reactive GUI
Plotly interactive charts
**Deployment**
Hugging Face Spaces (Docker)
---
## π Citation
> Zhai, X. et al. (2023). *Sigmoid Loss for Language Image Pre-Training.*
> Google DeepMind. arXiv:2303.15343
"""
| # ─────────────────────────────────────────────────────────────────────────────── | |
| # RUN | |
| # ─────────────────────────────────────────────────────────────────────────────── | |
# "/" is Taipy's root page, whose content is shown together with every other
# page. Registering the demo there would render it alongside the About page,
# so the root only carries the navigation bar and the demo gets its own route.
pages = {
    "/": "<|navbar|>",
    "Demo": demo_md,
    "About": about_md,
}
gui = Gui(pages=pages, css_file="style.css")

if __name__ == "__main__":
    # Listen on the port given by the PORT environment variable, else 7860.
    port = int(os.environ.get("PORT", 7860))
    gui.run(
        host="0.0.0.0",
        port=port,
        title="VisionQuery AI β SigLIP",
        # NOTE(review): this passes a literal character as the favicon; confirm
        # that Taipy accepts that rather than requiring an image path or URL.
        favicon="π",
        use_reloader=False,
        dark_mode=False,
    )