DrDavis committed on
Commit
bca9b06
·
verified ·
1 Parent(s): b8278cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py CHANGED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time, math
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
7
+
8
# --- Model choice (CPU-friendly) ---
# Alternatives: "microsoft/resnet-50", "facebook/convnext-tiny-224"
MODEL_ID = "google/vit-base-patch16-224"

# Load processor and model once at startup so each request only runs inference.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() returns the module itself, so load + eval can be chained.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()
LABELS = model.config.id2label  # mapping {class index: human-readable label}
17
+ def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
18
+ """Softmax(logits / T) with numerical stability."""
19
+ if temperature <= 0:
20
+ temperature = 1.0
21
+ scaled = logits / float(temperature)
22
+ # subtract max for numerical stability
23
+ scaled = scaled - torch.max(scaled)
24
+ exp = torch.exp(scaled)
25
+ return exp / torch.sum(exp)
26
+
27
+ def _entropy(probs: np.ndarray) -> float:
28
+ """Shannon entropy in nats: -Σ p log p (ignore zeros)."""
29
+ p = probs[probs > 0]
30
+ return float(-(p * np.log(p)).sum())
31
+
32
def _make_bar(labels, probs):
    """Build and return a matplotlib figure: horizontal bars of top-K probabilities.

    ``labels`` and ``probs`` are parallel sequences (best class first); the
    x-axis is fixed to [0, 1] so bar lengths are comparable across runs.
    """
    fig, ax = plt.subplots(figsize=(6, 3.2))
    positions = np.arange(len(labels))
    ax.barh(positions, probs)  # keep matplotlib's default colors
    ax.set_yticks(positions, labels)
    ax.invert_yaxis()  # rank 1 ends up at the top
    ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    ax.set_title("Top-K predicted classes")
    fig.tight_layout()
    return fig
44
+
45
def analyze(img, top_k=5, temperature=1.0):
    """Classify *img* and expose the model's decision process.

    Returns a 4-tuple for the Gradio outputs:
      - quick-glance dict {label: probability} for gr.Label
      - matplotlib figure of the Top-K probabilities
      - table rows [rank, label, probability, logit]
      - markdown string with uncertainty metrics, preprocessing info and timing
    """
    # Guard clause: nothing uploaded yet.
    if img is None:
        return (
            {"<no image>": 1.0},  # quick glance label block
            None,  # bar plot
            [],  # table rows
            "Please upload an image.",  # metrics markdown
        )

    start = time.perf_counter()
    batch = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits[0]  # shape [num_labels]

    # Temperature-scaled softmax over the raw logits.
    probs = _softmax_with_temperature(logits, temperature)

    # Clamp K into [1, num_labels] before selecting the top entries.
    k = min(max(1, int(top_k)), probs.shape[0])
    top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
    idx_list = top_idx.tolist()
    prob_list = top_vals.tolist()

    names = [LABELS[i] for i in idx_list]
    raw_logits = [float(logits[i]) for i in idx_list]

    # Quick glance dict for gr.Label (keeps students oriented).
    quick = {name: float(p) for name, p in zip(names, prob_list)}

    # Bar chart of the Top-K probabilities.
    fig = _make_bar(names, prob_list)

    # Table rows: (Rank, Label, Probability, Logit).
    rows = [
        [rank, name, round(float(p), 6), round(float(lg), 6)]
        for rank, (name, p, lg) in enumerate(zip(names, prob_list, raw_logits), start=1)
    ]

    # Uncertainty metrics over the full distribution.
    H = _entropy(probs.detach().cpu().numpy())
    top1 = float(prob_list[0])
    top2 = float(prob_list[1]) if len(prob_list) > 1 else 0.0
    margin = top1 - top2
    cum_topk = float(sum(prob_list))
    infer_ms = (time.perf_counter() - start) * 1000.0

    # Preprocessing target size from the processor config, when available.
    size = getattr(processor, "size", {})
    target_h = size.get("height")
    target_w = size.get("width")
    size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default"

    md = (
        f"**Uncertainty** \n"
        f"- Entropy (lower→more confident): **{H:.3f} nats** \n"
        f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}** \n"
        f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
        f"**Preprocessing & Runtime** \n"
        f"- Processor target size: **{size_str}** \n"
        f"- Inference time: **{infer_ms:.1f} ms** (CPU) \n\n"
        f"_Tip:_ Adjust **Temperature** to watch softmax sharpen (T<1) or soften (T>1) the distribution."
    )

    return quick, fig, rows, md
116
+
117
# --- UI wiring: inputs on the left, live model internals on the right ---
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown("# 🖼️ Image Classification — ‘Watch the Decision’\nUpload an image, then watch the classifier’s probabilities, logits, and uncertainty metrics update.")

    with gr.Row():
        # Left column: image upload plus the two knobs analyze() exposes.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload an image (JPEG/PNG)", height=360)
            topk_slider = gr.Slider(1, 10, value=5, step=1, label="Top-K predictions")
            temp_slider = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
            analyze_btn = gr.Button("Analyze", variant="primary")
        # Right column: the four outputs of analyze(), top to bottom.
        with gr.Column(scale=1):
            gr.Markdown("### Quick glance")
            glance_out = gr.Label(num_top_classes=10)
            gr.Markdown("### Probabilities (Top-K)")
            plot_out = gr.Plot()
            gr.Markdown("### Details (Top-K)")
            table_out = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
            gr.Markdown("### Metrics & Notes")
            notes_out = gr.Markdown()

    # One click runs the full analysis and fans results out to the four views.
    analyze_btn.click(analyze, [image_in, topk_slider, temp_slider], [glance_out, plot_out, table_out, notes_out])

if __name__ == "__main__":
    demo.launch()