Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time, math
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 7 |
+
|
| 8 |
+
# --- Model choice (CPU-friendly) ---
MODEL_ID = "google/vit-base-patch16-224"  # alternatives: "microsoft/resnet-50", "facebook/convnext-tiny-224"

# Load the processor and classifier exactly once, at startup, so every
# request reuses the same in-memory weights.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout etc.

# Mapping from class index to human-readable label, e.g. {0: "tench", ...}.
LABELS = model.config.id2label
|
| 16 |
+
|
| 17 |
+
def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Return softmax(logits / temperature) over the last dimension.

    Args:
        logits: Raw class scores; the caller passes a 1-D tensor of shape
            [num_labels] (assumed 1-D — for higher ranks normalization is
            per-row along the last dim).
        temperature: Softmax temperature. T < 1 sharpens the distribution,
            T > 1 softens it. Non-positive values are treated as 1.0 so a
            bad slider value can never divide by zero.

    Returns:
        A probability tensor of the same shape, summing to 1.
    """
    if temperature <= 0:
        temperature = 1.0
    # torch.softmax is numerically stable (it subtracts the row max
    # internally), so no manual max-subtraction/exp/sum dance is needed.
    return torch.softmax(logits / float(temperature), dim=-1)
|
| 26 |
+
|
| 27 |
+
def _entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats: -Σ p log p (zero-probability entries are skipped)."""
    nonzero = probs[probs > 0]
    return float(-np.sum(nonzero * np.log(nonzero)))
|
| 31 |
+
|
| 32 |
+
def _make_bar(labels, probs):
    """Return a matplotlib horizontal bar chart of top-K probabilities.

    Args:
        labels: Sequence of class-label strings, ordered best-first.
        probs: Matching sequence of probabilities in [0, 1].

    Returns:
        A matplotlib Figure suitable for gr.Plot.
    """
    fig, ax = plt.subplots(figsize=(6, 3.2))
    y = np.arange(len(labels))
    ax.barh(y, probs)  # do not set colors (keep default)
    ax.set_yticks(y, labels)
    ax.invert_yaxis()  # best prediction on top
    ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    ax.set_title("Top-K predicted classes")
    fig.tight_layout()
    # Unregister the figure from pyplot's global state: without this, each
    # request in a long-running Gradio server leaks a figure (and matplotlib
    # eventually warns about >20 open figures). The Figure object itself
    # stays valid and is still rendered by gr.Plot.
    plt.close(fig)
    return fig
|
| 44 |
+
|
| 45 |
+
def analyze(img, top_k=5, temperature=1.0):
    """
    Run the image through the classifier and expose:
      - Top-K probabilities (bar chart + table)
      - Pre-softmax logits for those Top-K
      - Uncertainty metrics (entropy; top-1 margin; cumulative top-K)
      - Preprocessing info and inference time
    """
    # Guard clause: nothing uploaded yet.
    if img is None:
        return {"<no image>": 1.0}, None, [], "Please upload an image."

    start = time.perf_counter()
    batch = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits[0]  # shape [num_labels]

    # Temperature-scaled softmax over the class logits.
    probs = _softmax_with_temperature(logits, temperature)

    # Clamp K to [1, num_labels] before selecting the top entries.
    k = min(max(1, int(top_k)), probs.shape[0])
    top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
    top_idx = top_idx.tolist()
    top_vals = top_vals.tolist()

    labels = [LABELS[i] for i in top_idx]
    logits_top = [float(logits[i]) for i in top_idx]

    # Quick glance dict for gr.Label (keeps students oriented).
    quick = {lab: float(p) for lab, p in zip(labels, top_vals)}

    # Bar chart of the Top-K probabilities.
    fig = _make_bar(labels, top_vals)

    # Table rows: (Rank, Label, Probability, Logit).
    rows = [
        [rank, lab, round(float(p), 6), round(float(lg), 6)]
        for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1)
    ]

    # Uncertainty metrics.
    H = _entropy(probs.detach().cpu().numpy())
    top1 = float(top_vals[0])
    top2 = float(top_vals[1]) if len(top_vals) > 1 else 0.0
    margin = top1 - top2
    cum_topk = float(sum(top_vals))
    infer_ms = (time.perf_counter() - start) * 1000.0

    # Preprocessing info (resize/crop) from processor config if available.
    size = getattr(processor, "size", {})
    target_h = size.get("height", None)
    target_w = size.get("width", None)
    size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default"

    md = (
        f"**Uncertainty** \n"
        f"- Entropy (lower→more confident): **{H:.3f} nats** \n"
        f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}** \n"
        f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
        f"**Preprocessing & Runtime** \n"
        f"- Processor target size: **{size_str}** \n"
        f"- Inference time: **{infer_ms:.1f} ms** (CPU) \n\n"
        f"_Tip:_ Adjust **Temperature** to watch softmax sharpen (T<1) or soften (T>1) the distribution."
    )

    return quick, fig, rows, md
|
| 116 |
+
|
| 117 |
+
# --- UI layout: inputs on the left, model outputs on the right ---
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown(
        "# 🖼️ Image Classification — ‘Watch the Decision’\n"
        "Upload an image, then watch the classifier’s probabilities, logits, and uncertainty metrics update."
    )

    with gr.Row():
        # Left column: image input and the two knobs students can turn.
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Upload an image (JPEG/PNG)", height=360)
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K predictions")
            temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
            run = gr.Button("Analyze", variant="primary")

        # Right column: everything `analyze` returns.
        with gr.Column(scale=1):
            gr.Markdown("### Quick glance")
            glance = gr.Label(num_top_classes=10)
            gr.Markdown("### Probabilities (Top-K)")
            plot = gr.Plot()
            gr.Markdown("### Details (Top-K)")
            table = gr.Dataframe(
                headers=["Rank", "Label", "Probability", "Logit"],
                datatype=["number", "str", "number", "number"],
                row_count=5,
            )
            gr.Markdown("### Metrics & Notes")
            notes = gr.Markdown()

    # Wire the button to the analysis function.
    run.click(analyze, [img, topk, temp], [glance, plot, table, notes])

if __name__ == "__main__":
    demo.launch()
|