Spaces:
Sleeping
Sleeping
import math
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from transformers import AutoImageProcessor, AutoModelForImageClassification
# --- Model choice (CPU-friendly) ---
# Alternatives: "microsoft/resnet-50", "facebook/convnext-tiny-224"
MODEL_ID = "google/vit-base-patch16-224"

# Load the processor and the weights once, at import time, so every request
# reuses them; `.eval()` returns the module itself, already in evaluation mode.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()

# Class-index -> human-readable label mapping taken from the model config.
LABELS = model.config.id2label
| def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor: | |
| """Softmax(logits / T) with numerical stability.""" | |
| if temperature <= 0: | |
| temperature = 1.0 | |
| scaled = logits / float(temperature) | |
| scaled = scaled - torch.max(scaled) | |
| exp = torch.exp(scaled) | |
| return exp / torch.sum(exp) | |
| def _entropy(probs: np.ndarray) -> float: | |
| """Shannon entropy in nats: -Σ p log p (ignore zeros).""" | |
| p = probs[probs > 0] | |
| return float(-(p * np.log(p)).sum()) | |
| def _kl(p: np.ndarray, q: np.ndarray) -> float: | |
| """KL divergence KL(p||q), with small epsilon for stability.""" | |
| eps = 1e-12 | |
| p = p + eps | |
| q = q + eps | |
| return float(np.sum(p * np.log(p / q))) | |
def _jsd(p: np.ndarray, q: np.ndarray) -> float:
    """Return the Jensen–Shannon divergence between two distributions, in nats.

    This is the average of ``_kl(p, m)`` and ``_kl(q, m)``, where
    ``m = (p + q) / 2`` is the midpoint mixture. It is symmetric in its
    arguments and, for proper probability vectors, bounded by ``ln 2``.
    """
    mixture = (p + q) / 2.0
    kl_p, kl_q = _kl(p, mixture), _kl(q, mixture)
    return 0.5 * kl_p + 0.5 * kl_q
def _make_bar(labels, probs, title="Top-K predicted classes", xlim=None, xlabel="Probability"):
    """Return a horizontal bar chart of per-label values as a matplotlib Figure.

    The chart is built with the object-oriented ``Figure`` API instead of
    ``pyplot``. pyplot keeps a global reference to every figure it opens, so
    a long-running server that never calls ``plt.close`` leaks one figure per
    request.

    Args:
        labels: Tick labels, one per bar; the first label is drawn at the top.
        probs: Bar lengths, same length as ``labels``. May be negative
            (e.g. probability deltas).
        title: Axes title.
        xlim: Optional ``(left, right)`` x-axis limits. When omitted, the axis
            is ``(0, 1)`` for non-negative data. It becomes ``(-1, 1)`` as soon
            as any value is negative, so negative bars are not clipped away.
        xlabel: X-axis label.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart.
    """
    values = np.asarray(probs, dtype=float)
    if xlim is None:
        xlim = (-1, 1) if values.size and values.min() < 0 else (0, 1)

    fig = Figure(figsize=(6, 3.2))
    ax = fig.subplots()
    y = np.arange(len(labels))
    ax.barh(y, values)  # default colors only (per teaching tool rules)
    ax.set_yticks(y, labels)
    ax.invert_yaxis()  # first label (rank 1) at the top
    ax.set_xlim(*xlim)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    fig.tight_layout()
    return fig
| def _analyze_single(img, top_k=5, temperature=1.0): | |
| """ | |
| Return: (quick_label_dict, bar_plot, table_rows, notes_markdown, full_probs_numpy) | |
| """ | |
| if img is None: | |
| return ({"<no image>": 1.0}, None, [], "Please upload an image.", None) | |
| t0 = time.perf_counter() | |
| inputs = processor(images=img, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits[0] # shape [num_labels] | |
| probs = _softmax_with_temperature(logits, temperature) | |
| k = max(1, int(top_k)) | |
| k = min(k, probs.shape[0]) | |
| top_vals, top_idx = torch.topk(probs, k=k, dim=-1) | |
| top_idx = top_idx.tolist() | |
| top_vals = top_vals.tolist() | |
| labels = [LABELS[i] for i in top_idx] | |
| logits_top = [float(logits[i]) for i in top_idx] | |
| quick = {lab: float(p) for lab, p in zip(labels, top_vals)} | |
| fig = _make_bar(labels, top_vals) | |
| rows = [] | |
| for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1): | |
| rows.append([rank, lab, round(float(p), 6), round(float(lg), 6)]) | |
| probs_np = probs.detach().cpu().numpy() | |
| H = _entropy(probs_np) | |
| top1 = float(top_vals[0]) | |
| top2 = float(top_vals[1]) if len(top_vals) > 1 else 0.0 | |
| margin = top1 - top2 | |
| cum_topk = float(sum(top_vals)) | |
| infer_ms = (time.perf_counter() - t0) * 1000.0 | |
| size = processor.size if hasattr(processor, "size") else {} | |
| target_h = size.get("height", None) | |
| target_w = size.get("width", None) | |
| size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default" | |
| md = ( | |
| f"**Uncertainty** \n" | |
| f"- Entropy (lower→more confident): **{H:.3f} nats** \n" | |
| f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}** \n" | |
| f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n" | |
| f"**Preprocessing & Runtime** \n" | |
| f"- Processor target size: **{size_str}** \n" | |
| f"- Inference time: **{infer_ms:.1f} ms** \n" | |
| ) | |
| return quick, fig, rows, md, probs_np | |
| def _align_topk(labelsA, probsA, labelsB, probsB, K=5): | |
| """Make a unified label set of size up to K using union-of-top labels then rank by max(prob).""" | |
| dA = dict(zip(labelsA, probsA)) | |
| dB = dict(zip(labelsB, probsB)) | |
| union = set(labelsA) | set(labelsB) | |
| # rank by max(prob from A, prob from B) | |
| ranked = sorted(list(union), key=lambda x: max(dA.get(x, 0.0), dB.get(x, 0.0)), reverse=True) | |
| chosen = ranked[:K] | |
| a = [float(dA.get(l, 0.0)) for l in chosen] | |
| b = [float(dB.get(l, 0.0)) for l in chosen] | |
| return chosen, a, b | |
def analyze_single(img, top_k=5, temperature=1.0):
    """Gradio handler for the Single Image tab.

    Runs ``_analyze_single`` and returns its first four outputs (label dict,
    bar chart, table rows, notes markdown). The full probability vector is
    dropped, since only the A/B view needs it.
    """
    *ui_outputs, _full_probs = _analyze_single(img, top_k, temperature)
    return tuple(ui_outputs)
def _aligned_topk_indices(pa: np.ndarray, pb: np.ndarray, k) -> list:
    """Pick up to ``k`` class indices for the A/B delta chart.

    The candidates are the union of each side's top-``k`` class indices,
    ranked by ``max(pa[i], pb[i])``, highest first. Ties fall back to
    ascending class index, so the order is deterministic.
    """
    k = max(1, min(int(k), pa.shape[0]))
    top_a = np.argsort(-pa, kind="stable")[:k]
    top_b = np.argsort(-pb, kind="stable")[:k]
    candidates = np.union1d(top_a, top_b)  # sorted, de-duplicated indices
    score = np.maximum(pa[candidates], pb[candidates])
    return candidates[np.argsort(-score, kind="stable")][:k].tolist()


def analyze_pair(imgA, imgB, top_k=5, temperature=1.0):
    """Gradio handler for the A/B Compare tab.

    Runs ``_analyze_single`` on both images, then compares their full
    probability distributions.

    Args:
        imgA, imgB: Images to compare (PIL images from the UI), or None.
        top_k: Number of classes per side and in the aligned delta chart.
        temperature: Softmax temperature applied to both images.

    Returns:
        A 10-tuple:
        - A's label dict, bar chart, table rows and notes.
        - B's label dict, bar chart, table rows and notes.
        - The aligned Top-K delta chart (None if either image is missing).
        - A markdown string with divergence metrics.
    """
    qa, figa, rowsa, mda, pa = _analyze_single(imgA, top_k, temperature)
    qb, figb, rowsb, mdb, pb = _analyze_single(imgB, top_k, temperature)
    if pa is None or pb is None:
        return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, None, "Upload both images for delta metrics."

    # Renormalize defensively; softmax output should already sum to ~1.
    pa = pa / (pa.sum() + 1e-12)
    pb = pb / (pb.sum() + 1e-12)

    # Align on class *indices* and read both probabilities from the full
    # vectors. A class just outside one side's Top-K then keeps its real
    # probability instead of 0, and classes sharing a label string stay
    # distinct.
    chosen = _aligned_topk_indices(pa, pb, top_k)
    deltas = [float(pa[i] - pb[i]) for i in chosen]
    fig_delta = _make_bar(
        [f"{LABELS[i]} (Δ)" for i in chosen],
        deltas,
        title="Aligned Top-K Δ Probabilities (A − B)",
    )
    # Deltas lie in [-1, 1]; widen the axis so negative bars stay visible.
    ax = fig_delta.axes[0]
    ax.set_xlim(-1, 1)
    ax.set_xlabel("Δ Probability (A − B)")
    fig_delta.tight_layout()

    # Distribution-level differences over the full softmax vectors.
    H_a = _entropy(pa)
    H_b = _entropy(pb)
    jsd = _jsd(pa, pb)
    top1_a = LABELS[int(np.argmax(pa))]
    top1_b = LABELS[int(np.argmax(pb))]

    diff_md = (
        "**A/B Divergence** \n"
        f"- Jensen–Shannon divergence: **{jsd:.4f}** (0=same, higher=more different) \n"
        f"- Entropy A / B: **{H_a:.3f} / {H_b:.3f}** nats \n"
        f"- Top-1 A / B: **{top1_a} / {top1_b}** \n"
        "- Aligned Top-K shown above is ranked by max(prob_A, prob_B). \n"
        "_Tip:_ Try different crops/lighting or adjust **Temperature** to watch distributions change."
    )
    return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, fig_delta, diff_md
def _topk_table():
    """Create an empty results table for [rank, label, probability, logit] rows."""
    return gr.Dataframe(
        headers=["Rank", "Label", "Probability", "Logit"],
        datatype=["number", "str", "number", "number"],
        row_count=5,
    )


def _temperature_slider():
    """Create a softmax-temperature slider (0.25–2.0, default 1.0)."""
    return gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")


def _result_panel(prefix, notes_title):
    """Create the stacked quick-glance / chart / table / notes outputs.

    The components attach to whichever Gradio layout context is open when
    this is called. Returns ``(label, plot, table, notes)``.
    """
    gr.Markdown(f"### {prefix}Quick glance")
    label = gr.Label(num_top_classes=10)
    gr.Markdown(f"### {prefix}Probabilities (Top‑K)")
    plot_out = gr.Plot()
    gr.Markdown(f"### {prefix}Details (Top‑K)")
    table_out = _topk_table()
    gr.Markdown(f"### {notes_title}")
    notes_out = gr.Markdown()
    return label, plot_out, table_out, notes_out


_INTRO_MD = (
    "# 🖼️ Image Classification — ‘Watch the Decision’\n"
    "Visualize probabilities, logits, entropy, and A/B deltas.\n\n"
    "_Notes:_ Predictions reflect the ImageNet‑1k label space; unusual objects or logos may be misclassified. "
    "Do not use for identity or sensitive inferences."
)

with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Tab("Single Image"):
        with gr.Row():
            with gr.Column(scale=1):
                img = gr.Image(type="pil", label="Upload image (JPEG/PNG)", height=340)
                topk = gr.Slider(1, 10, value=5, step=1, label="Top‑K predictions")
                temp = _temperature_slider()
                run = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=1):
                glance, plot, table, notes = _result_panel("", "Metrics & Notes")
        run.click(analyze_single, [img, topk, temp], [glance, plot, table, notes])

    with gr.Tab("A/B Compare"):
        with gr.Row():
            with gr.Column(scale=1):
                imgA = gr.Image(type="pil", label="Image A", height=300)
                imgB = gr.Image(type="pil", label="Image B", height=300)
                topkAB = gr.Slider(1, 10, value=5, step=1, label="Aligned Top‑K")
                tempAB = _temperature_slider()
                runAB = gr.Button("Analyze A/B", variant="primary")
            with gr.Column(scale=1):
                glanceA, plotA, tableA, notesA = _result_panel("A — ", "A — Notes")
            with gr.Column(scale=1):
                glanceB, plotB, tableB, notesB = _result_panel("B — ", "B — Notes")
        with gr.Row():
            gr.Markdown("### Aligned Top‑K Δ (A − B) & Divergence")
        with gr.Row():
            deltaPlot = gr.Plot()
            deltaNotes = gr.Markdown()
        runAB.click(
            analyze_pair,
            [imgA, imgB, topkAB, tempAB],
            [glanceA, plotA, tableA, notesA, glanceB, plotB, tableB, notesB, deltaPlot, deltaNotes],
        )

if __name__ == "__main__":
    # Start the Gradio app when run as a script.
    demo.launch()