DrDavis's picture
Update app.py
bd4c84f verified
raw
history blame
10.1 kB
import time, math
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForImageClassification
# --- Model choice (CPU-friendly) ---
# Checkpoint identifier handed to both `from_pretrained` calls below, so the
# image processor and the classifier always come from the same model.
MODEL_ID = "google/vit-base-patch16-224"  # alternatives: "microsoft/resnet-50", "facebook/convnext-tiny-224"

# Load once at startup: these module-level objects are reused by every
# request instead of being rebuilt per call.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# The app only runs inference, so put the model in evaluation mode.
model.eval()
# Class-index -> label lookup used to name predictions: {idx: "label"}
LABELS = model.config.id2label
def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
"""Softmax(logits / T) with numerical stability."""
if temperature <= 0:
temperature = 1.0
scaled = logits / float(temperature)
scaled = scaled - torch.max(scaled)
exp = torch.exp(scaled)
return exp / torch.sum(exp)
def _entropy(probs: np.ndarray) -> float:
"""Shannon entropy in nats: -Σ p log p (ignore zeros)."""
p = probs[probs > 0]
return float(-(p * np.log(p)).sum())
def _kl(p: np.ndarray, q: np.ndarray) -> float:
"""KL divergence KL(p||q), with small epsilon for stability."""
eps = 1e-12
p = p + eps
q = q + eps
return float(np.sum(p * np.log(p / q)))
def _jsd(p: np.ndarray, q: np.ndarray) -> float:
    """Compute the Jensen–Shannon divergence between ``p`` and ``q``.

    This is the average of the two KL divergences from each input to their
    equal-weight mixture, which makes the measure symmetric and bounded.
    """
    mixture = 0.5 * (p + q)
    kl_p_to_mix = _kl(p, mixture)
    kl_q_to_mix = _kl(q, mixture)
    return 0.5 * kl_p_to_mix + 0.5 * kl_q_to_mix
def _make_bar(labels, probs, title="Top-K predicted classes"):
    """Return a matplotlib horizontal bar chart with one bar per label.

    Args:
        labels: Tick labels, one per bar, drawn top to bottom in the given
            order.
        probs: Bar lengths, aligned with ``labels``. Normally probabilities
            in [0, 1]; signed values such as A − B differences are accepted
            too and widen the x-axis to [-1, 1].
        title: Title shown above the axes.

    Returns:
        The ``Figure``. It has already been removed from pyplot's global
        figure registry, so the caller does not need to close it.
    """
    fig, ax = plt.subplots(figsize=(6, 3.2))
    y = np.arange(len(labels))
    ax.barh(y, probs)  # default colors only (per teaching tool rules)
    ax.set_yticks(y, labels)
    ax.invert_yaxis()
    if any(value < 0 for value in probs):
        # A fixed [0, 1] axis would clip every negative bar out of view.
        ax.set_xlim(-1, 1)
        ax.axvline(0, linewidth=0.8)  # zero reference for signed bars
    else:
        ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    ax.set_title(title)
    fig.tight_layout()
    # pyplot keeps every figure it creates alive until it is closed; in a
    # long-running app that grows without bound. The Figure object itself
    # stays usable after being closed.
    plt.close(fig)
    return fig
def _analyze_single(img, top_k=5, temperature=1.0):
"""
Return: (quick_label_dict, bar_plot, table_rows, notes_markdown, full_probs_numpy)
"""
if img is None:
return ({"<no image>": 1.0}, None, [], "Please upload an image.", None)
t0 = time.perf_counter()
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[0] # shape [num_labels]
probs = _softmax_with_temperature(logits, temperature)
k = max(1, int(top_k))
k = min(k, probs.shape[0])
top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
top_idx = top_idx.tolist()
top_vals = top_vals.tolist()
labels = [LABELS[i] for i in top_idx]
logits_top = [float(logits[i]) for i in top_idx]
quick = {lab: float(p) for lab, p in zip(labels, top_vals)}
fig = _make_bar(labels, top_vals)
rows = []
for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1):
rows.append([rank, lab, round(float(p), 6), round(float(lg), 6)])
probs_np = probs.detach().cpu().numpy()
H = _entropy(probs_np)
top1 = float(top_vals[0])
top2 = float(top_vals[1]) if len(top_vals) > 1 else 0.0
margin = top1 - top2
cum_topk = float(sum(top_vals))
infer_ms = (time.perf_counter() - t0) * 1000.0
size = processor.size if hasattr(processor, "size") else {}
target_h = size.get("height", None)
target_w = size.get("width", None)
size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default"
md = (
f"**Uncertainty** \n"
f"- Entropy (lower→more confident): **{H:.3f} nats** \n"
f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}** \n"
f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
f"**Preprocessing & Runtime** \n"
f"- Processor target size: **{size_str}** \n"
f"- Inference time: **{infer_ms:.1f} ms** \n"
)
return quick, fig, rows, md, probs_np
def _align_topk(labelsA, probsA, labelsB, probsB, K=5):
"""Make a unified label set of size up to K using union-of-top labels then rank by max(prob)."""
dA = dict(zip(labelsA, probsA))
dB = dict(zip(labelsB, probsB))
union = set(labelsA) | set(labelsB)
# rank by max(prob from A, prob from B)
ranked = sorted(list(union), key=lambda x: max(dA.get(x, 0.0), dB.get(x, 0.0)), reverse=True)
chosen = ranked[:K]
a = [float(dA.get(l, 0.0)) for l in chosen]
b = [float(dB.get(l, 0.0)) for l in chosen]
return chosen, a, b
def analyze_single(img, top_k=5, temperature=1.0):
    """Run the single-image analysis and return the four values the UI displays.

    Delegates to ``_analyze_single`` and drops its trailing full-probability
    array, leaving (quick label dict, bar plot, table rows, notes markdown).
    """
    results = _analyze_single(img, top_k, temperature)
    return results[:4]
def analyze_pair(imgA, imgB, top_k=5, temperature=1.0):
    """Analyze two images and compare their predicted distributions.

    A/B analysis:
    - per-image quick dict, bar chart, table and notes
    - an aligned Top-K delta bar chart (A − B) and divergence metrics

    Args:
        imgA: First image, or ``None`` if not uploaded.
        imgB: Second image, or ``None`` if not uploaded.
        top_k: Number of classes reported per image and the maximum number
            of aligned classes in the delta chart.
        temperature: Softmax temperature used for both images.

    Returns:
        A 10-tuple: the four per-image outputs for A, the same four for B,
        the delta figure, and a Markdown summary. If either image is
        missing, the delta figure is ``None`` and the summary asks for both
        images.
    """
    # Analyze each side
    qa, figa, rowsa, mda, pa = _analyze_single(imgA, top_k, temperature)
    qb, figb, rowsb, mdb, pb = _analyze_single(imgB, top_k, temperature)

    # If either missing, return as-is
    if pa is None or pb is None:
        return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, None, "Upload both images for delta metrics."

    # Normalize the full softmax vectors to proper probability distributions.
    pa = pa / (pa.sum() + 1e-12)
    pb = pb / (pb.sum() + 1e-12)

    # Candidate classes: union of each side's Top-K class indices.
    k = min(max(int(top_k), 1), len(pa))
    top_a = np.argsort(-pa, kind="stable")[:k].tolist()
    top_b = np.argsort(-pb, kind="stable")[:k].tolist()
    candidates = list(dict.fromkeys([*top_a, *top_b]))
    cand_labels = [LABELS[i] for i in candidates]
    # Look both sides up in the full distributions. Using only each side's
    # Top-K table would treat a class outside the other image's Top-K as
    # probability 0 and overstate the difference.
    chosen, a, b = _align_topk(
        cand_labels,
        [float(pa[i]) for i in candidates],
        cand_labels,
        [float(pb[i]) for i in candidates],
        K=k,
    )

    # Delta bar (A−B)
    deltas = [float(x - y) for x, y in zip(a, b)]
    fig_delta = _make_bar([f"{lbl} (Δ)" for lbl in chosen], deltas, title="Aligned Top-K Δ Probabilities (A − B)")

    # Distribution-level differences (full softmax vectors)
    H_a = _entropy(pa)
    H_b = _entropy(pb)
    jsd = _jsd(pa, pb)

    # Top-1 labels for each side
    top1_a = LABELS[int(np.argmax(pa))]
    top1_b = LABELS[int(np.argmax(pb))]

    diff_md = (
        f"**A/B Divergence** \n"
        f"- Jensen–Shannon divergence: **{jsd:.4f}** (0=same, higher=more different) \n"
        f"- Entropy A / B: **{H_a:.3f} / {H_b:.3f}** nats \n"
        f"- Top-1 A / B: **{top1_a} / {top1_b}** \n"
        f"- Aligned Top-K shown above is ranked by max(prob_A, prob_B). \n"
        f"_Tip:_ Try different crops/lighting or adjust **Temperature** to watch distributions change."
    )
    return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, fig_delta, diff_md
# --- Gradio UI ---
# Two tabs: "Single Image" runs analyze_single, "A/B Compare" runs
# analyze_pair. Components are laid out in the order they are created.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been stripped — confirm it against the intended layout.
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown("# 🖼️ Image Classification — ‘Watch the Decision’\nVisualize probabilities, logits, entropy, and A/B deltas.\n\n"
                "_Notes:_ Predictions reflect the ImageNet‑1k label space; unusual objects or logos may be misclassified. "
                "Do not use for identity or sensitive inferences.")

    with gr.Tab("Single Image"):
        with gr.Row():
            # Inputs: image plus the Top-K and temperature controls.
            with gr.Column(scale=1):
                img = gr.Image(type="pil", label="Upload image (JPEG/PNG)", height=340)
                topk = gr.Slider(1, 10, value=5, step=1, label="Top‑K predictions")
                temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                run = gr.Button("Analyze", variant="primary")
            # Outputs: label summary, bar chart, details table, notes.
            with gr.Column(scale=1):
                gr.Markdown("### Quick glance")
                glance = gr.Label(num_top_classes=10)
                gr.Markdown("### Probabilities (Top‑K)")
                plot = gr.Plot()
                gr.Markdown("### Details (Top‑K)")
                table = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### Metrics & Notes")
                notes = gr.Markdown()
        # Output order matches analyze_single's return order:
        # (quick dict, figure, table rows, notes markdown).
        run.click(analyze_single, [img, topk, temp], [glance, plot, table, notes])

    with gr.Tab("A/B Compare"):
        with gr.Row():
            # Inputs: both images plus controls shared by the two analyses.
            with gr.Column(scale=1):
                imgA = gr.Image(type="pil", label="Image A", height=300)
                imgB = gr.Image(type="pil", label="Image B", height=300)
                topkAB = gr.Slider(1, 10, value=5, step=1, label="Aligned Top‑K")
                tempAB = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                runAB = gr.Button("Analyze A/B", variant="primary")
            # Per-image outputs for A.
            with gr.Column(scale=1):
                gr.Markdown("### A — Quick glance")
                glanceA = gr.Label(num_top_classes=10)
                gr.Markdown("### A — Probabilities (Top‑K)")
                plotA = gr.Plot()
                gr.Markdown("### A — Details (Top‑K)")
                tableA = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### A — Notes")
                notesA = gr.Markdown()
            # Per-image outputs for B.
            with gr.Column(scale=1):
                gr.Markdown("### B — Quick glance")
                glanceB = gr.Label(num_top_classes=10)
                gr.Markdown("### B — Probabilities (Top‑K)")
                plotB = gr.Plot()
                gr.Markdown("### B — Details (Top‑K)")
                tableB = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### B — Notes")
                notesB = gr.Markdown()
        # Cross-image comparison: delta chart and divergence summary.
        with gr.Row():
            gr.Markdown("### Aligned Top‑K Δ (A − B) & Divergence")
        with gr.Row():
            deltaPlot = gr.Plot()
            deltaNotes = gr.Markdown()
        # Output order matches analyze_pair's 10-value return: the four A
        # outputs, the four B outputs, then the delta figure and its notes.
        runAB.click(analyze_pair, [imgA, imgB, topkAB, tempAB], [glanceA, plotA, tableA, notesA, glanceB, plotB, tableB, notesB, deltaPlot, deltaNotes])

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()