import time

import gradio as gr
import numpy as np
import torch
from matplotlib.figure import Figure
from transformers import AutoImageProcessor, AutoModelForImageClassification

# --- Model choice (CPU-friendly) ---
MODEL_ID = "google/vit-base-patch16-224"
# alternatives: "microsoft/resnet-50", "facebook/convnext-tiny-224"

# Load once at startup so every request reuses the same weights.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

LABELS = model.config.id2label  # {idx: "label"}


def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax(logits / T) with numerical stability.

    A non-positive temperature (possible only through programmatic misuse,
    the UI slider bottoms out at 0.25) falls back to 1.0 instead of raising.
    """
    if temperature <= 0:
        temperature = 1.0
    scaled = logits / float(temperature)
    scaled = scaled - torch.max(scaled)  # shift so exp() cannot overflow
    exp = torch.exp(scaled)
    return exp / torch.sum(exp)


def _entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats: -Σ p log p (zero entries are ignored)."""
    p = probs[probs > 0]
    return float(-(p * np.log(p)).sum())


def _kl(p: np.ndarray, q: np.ndarray) -> float:
    """KL divergence KL(p||q), with a small epsilon for numerical stability."""
    eps = 1e-12
    p = p + eps
    q = q + eps
    return float(np.sum(p * np.log(p / q)))


def _jsd(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen–Shannon divergence (symmetric, bounded)."""
    m = 0.5 * (p + q)
    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)


def _make_bar(labels, probs, title="Top-K predicted classes"):
    """Return a matplotlib horizontal bar chart of the given values.

    Uses the object-oriented ``Figure`` API rather than ``pyplot``: pyplot
    keeps every figure alive in a process-global registry, which leaks
    memory in a long-running server since nothing here ever closes figures.

    The x-limits adapt to the data: non-negative values (probabilities) get
    the usual [0, 1] range; signed values (the A−B delta chart) get a
    symmetric range so negative bars stay visible — a fixed xlim(0, 1)
    would clip them entirely.
    """
    fig = Figure(figsize=(6, 3.2))
    ax = fig.subplots()
    y = np.arange(len(labels))
    ax.barh(y, probs)  # default colors only (per teaching tool rules)
    ax.set_yticks(y, labels)
    ax.invert_yaxis()
    vals = np.asarray(probs, dtype=float)
    if vals.size == 0 or float(vals.min()) >= 0.0:
        ax.set_xlim(0, 1)
    else:
        lim = 1.1 * max(float(np.abs(vals).max()), 1e-3)
        ax.set_xlim(-lim, lim)
    ax.set_xlabel("Probability")
    ax.set_title(title)
    fig.tight_layout()
    return fig


def _analyze_single(img, top_k=5, temperature=1.0):
    """Classify one image and compute presentation artifacts.

    Args:
        img: PIL image (or None when the user has not uploaded one).
        top_k: number of top predictions to show.
        temperature: softmax temperature for the probability display.

    Returns:
        (quick_label_dict, bar_plot, table_rows, notes_markdown, full_probs_numpy)
        When ``img`` is None, placeholder values are returned and the final
        element is None so callers can detect the missing input.
    """
    if img is None:
        return ({"": 1.0}, None, [], "Please upload an image.", None)

    t0 = time.perf_counter()
    # Normalize to 3-channel RGB: RGBA/grayscale/palette uploads can
    # otherwise fail inside the image processor.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]  # shape [num_labels]
    probs = _softmax_with_temperature(logits, temperature)

    # Clamp K into [1, num_labels] so the slider can never crash topk().
    k = max(1, int(top_k))
    k = min(k, probs.shape[0])
    top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
    top_idx = top_idx.tolist()
    top_vals = top_vals.tolist()
    labels = [LABELS[i] for i in top_idx]
    logits_top = [float(logits[i]) for i in top_idx]

    quick = {lab: float(p) for lab, p in zip(labels, top_vals)}
    fig = _make_bar(labels, top_vals)
    rows = [
        [rank, lab, round(float(p), 6), round(float(lg), 6)]
        for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1)
    ]

    probs_np = probs.detach().cpu().numpy()
    H = _entropy(probs_np)
    top1 = float(top_vals[0])
    top2 = float(top_vals[1]) if len(top_vals) > 1 else 0.0
    margin = top1 - top2
    cum_topk = float(sum(top_vals))
    infer_ms = (time.perf_counter() - t0) * 1000.0

    # Report the processor's resize target when it exposes one.
    size = processor.size if hasattr(processor, "size") else {}
    target_h = size.get("height", None)
    target_w = size.get("width", None)
    size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default"

    md = (
        f"**Uncertainty** \n"
        f"- Entropy (lower→more confident): **{H:.3f} nats** \n"
        f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}** \n"
        f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
        f"**Preprocessing & Runtime** \n"
        f"- Processor target size: **{size_str}** \n"
        f"- Inference time: **{infer_ms:.1f} ms** \n"
    )
    return quick, fig, rows, md, probs_np


def _align_topk(labelsA, probsA, labelsB, probsB, K=5):
    """Make a unified label set of size up to K.

    Takes the union of both label lists, ranks it by max(prob_A, prob_B),
    and keeps the first K. Labels absent from one side contribute 0.0.

    Returns:
        (chosen_labels, probs_from_A, probs_from_B) as aligned lists.
    """
    dA = dict(zip(labelsA, probsA))
    dB = dict(zip(labelsB, probsB))
    union = set(labelsA) | set(labelsB)
    # rank by max(prob from A, prob from B)
    ranked = sorted(union, key=lambda lab: max(dA.get(lab, 0.0), dB.get(lab, 0.0)), reverse=True)
    chosen = ranked[:K]
    a = [float(dA.get(lab, 0.0)) for lab in chosen]
    b = [float(dB.get(lab, 0.0)) for lab in chosen]
    return chosen, a, b


def analyze_single(img, top_k=5, temperature=1.0):
    """Gradio handler for the single-image tab (drops the raw prob vector)."""
    quick, fig, rows, md, _ = _analyze_single(img, top_k, temperature)
    return quick, fig, rows, md


def analyze_pair(imgA, imgB, top_k=5, temperature=1.0):
    """A/B analysis Gradio handler.

    - shows per-image quick dict, bar chart, table, notes
    - shows aligned Top-K delta bar and divergence metrics
    """
    # Analyze each side independently.
    qa, figa, rowsa, mda, pa = _analyze_single(imgA, top_k, temperature)
    qb, figb, rowsb, mdb, pb = _analyze_single(imgB, top_k, temperature)

    # If either image is missing, return the per-side results as-is.
    if pa is None or pb is None:
        return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, None, "Upload both images for delta metrics."

    # Recover the per-side top-K labels/probs directly from the table rows
    # (each row is [rank, label, prob, logit]).
    labelsA = [r[1] for r in rowsa]
    probsA = [r[2] for r in rowsa]
    labelsB = [r[1] for r in rowsb]
    probsB = [r[2] for r in rowsb]
    chosen, a, b = _align_topk(labelsA, probsA, labelsB, probsB, K=max(int(top_k), 1))

    # Delta bar (A−B); _make_bar handles the negative values here.
    deltas = [float(x - y) for x, y in zip(a, b)]
    fig_delta = _make_bar(
        [f"{lbl} (Δ)" for lbl in chosen], deltas, title="Aligned Top-K Δ Probabilities (A − B)"
    )

    # Distribution-level differences over the full softmax vectors;
    # renormalize defensively before computing entropies/divergence.
    pa = pa / (pa.sum() + 1e-12)
    pb = pb / (pb.sum() + 1e-12)
    H_a = _entropy(pa)
    H_b = _entropy(pb)
    jsd = _jsd(pa, pb)

    # Top-1 labels for each side.
    top1_a = LABELS[int(np.argmax(pa))]
    top1_b = LABELS[int(np.argmax(pb))]

    diff_md = (
        f"**A/B Divergence** \n"
        f"- Jensen–Shannon divergence: **{jsd:.4f}** (0=same, higher=more different) \n"
        f"- Entropy A / B: **{H_a:.3f} / {H_b:.3f}** nats \n"
        f"- Top-1 A / B: **{top1_a} / {top1_b}** \n"
        f"- Aligned Top-K shown above is ranked by max(prob_A, prob_B). \n"
        f"_Tip:_ Try different crops/lighting or adjust **Temperature** to watch distributions change."
    )
    return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, fig_delta, diff_md


# --- UI ------------------------------------------------------------------
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown(
        "# 🖼️ Image Classification — ‘Watch the Decision’\nVisualize probabilities, logits, entropy, and A/B deltas.\n\n"
        "_Notes:_ Predictions reflect the ImageNet‑1k label space; unusual objects or logos may be misclassified. "
        "Do not use for identity or sensitive inferences."
    )

    with gr.Tab("Single Image"):
        with gr.Row():
            with gr.Column(scale=1):
                img = gr.Image(type="pil", label="Upload image (JPEG/PNG)", height=340)
                topk = gr.Slider(1, 10, value=5, step=1, label="Top‑K predictions")
                temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                run = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### Quick glance")
                glance = gr.Label(num_top_classes=10)
                gr.Markdown("### Probabilities (Top‑K)")
                plot = gr.Plot()
                gr.Markdown("### Details (Top‑K)")
                table = gr.Dataframe(
                    headers=["Rank", "Label", "Probability", "Logit"],
                    datatype=["number", "str", "number", "number"],
                    row_count=5,
                )
                gr.Markdown("### Metrics & Notes")
                notes = gr.Markdown()
        run.click(analyze_single, [img, topk, temp], [glance, plot, table, notes])

    with gr.Tab("A/B Compare"):
        with gr.Row():
            with gr.Column(scale=1):
                imgA = gr.Image(type="pil", label="Image A", height=300)
                imgB = gr.Image(type="pil", label="Image B", height=300)
                topkAB = gr.Slider(1, 10, value=5, step=1, label="Aligned Top‑K")
                tempAB = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                runAB = gr.Button("Analyze A/B", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### A — Quick glance")
                glanceA = gr.Label(num_top_classes=10)
                gr.Markdown("### A — Probabilities (Top‑K)")
                plotA = gr.Plot()
                gr.Markdown("### A — Details (Top‑K)")
                tableA = gr.Dataframe(
                    headers=["Rank", "Label", "Probability", "Logit"],
                    datatype=["number", "str", "number", "number"],
                    row_count=5,
                )
                gr.Markdown("### A — Notes")
                notesA = gr.Markdown()
            with gr.Column(scale=1):
                gr.Markdown("### B — Quick glance")
                glanceB = gr.Label(num_top_classes=10)
                gr.Markdown("### B — Probabilities (Top‑K)")
                plotB = gr.Plot()
                gr.Markdown("### B — Details (Top‑K)")
                tableB = gr.Dataframe(
                    headers=["Rank", "Label", "Probability", "Logit"],
                    datatype=["number", "str", "number", "number"],
                    row_count=5,
                )
                gr.Markdown("### B — Notes")
                notesB = gr.Markdown()
        with gr.Row():
            gr.Markdown("### Aligned Top‑K Δ (A − B) & Divergence")
        with gr.Row():
            deltaPlot = gr.Plot()
            deltaNotes = gr.Markdown()
        runAB.click(
            analyze_pair,
            [imgA, imgB, topkAB, tempAB],
            [glanceA, plotA, tableA, notesA, glanceB, plotB, tableB, notesB, deltaPlot, deltaNotes],
        )

if __name__ == "__main__":
    demo.launch()