DrDavis commited on
Commit
bd4c84f
·
verified ·
1 Parent(s): 99aecd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -47
app.py CHANGED
@@ -19,7 +19,6 @@ def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch
19
  if temperature <= 0:
20
  temperature = 1.0
21
  scaled = logits / float(temperature)
22
- # subtract max for numerical stability
23
  scaled = scaled - torch.max(scaled)
24
  exp = torch.exp(scaled)
25
  return exp / torch.sum(exp)
@@ -29,64 +28,60 @@ def _entropy(probs: np.ndarray) -> float:
29
  p = probs[probs > 0]
30
  return float(-(p * np.log(p)).sum())
31
 
32
- def _make_bar(labels, probs):
33
- """Return a matplotlib horizontal bar chart of top-K probabilities."""
 
 
 
 
 
 
 
 
 
 
 
 
34
  fig, ax = plt.subplots(figsize=(6, 3.2))
35
  y = np.arange(len(labels))
36
- ax.barh(y, probs) # do not set colors (keep default)
37
  ax.set_yticks(y, labels)
38
  ax.invert_yaxis()
39
  ax.set_xlim(0, 1)
40
  ax.set_xlabel("Probability")
41
- ax.set_title("Top-K predicted classes")
42
  fig.tight_layout()
43
  return fig
44
 
45
- def analyze(img, top_k=5, temperature=1.0):
46
  """
47
- Run the image through the classifier and expose:
48
- - Top-K probabilities (bar chart + table)
49
- - Pre-softmax logits for those Top-K
50
- - Uncertainty metrics (entropy; top-1 margin; cumulative top-K)
51
- - Preprocessing info and inference time
52
  """
53
  if img is None:
54
- return (
55
- {"<no image>": 1.0}, # quick glance label block
56
- None, # bar plot
57
- [], # table rows
58
- "Please upload an image.", # metrics markdown
59
- )
60
 
61
  t0 = time.perf_counter()
62
  inputs = processor(images=img, return_tensors="pt")
63
  with torch.no_grad():
64
  outputs = model(**inputs)
65
  logits = outputs.logits[0] # shape [num_labels]
66
- # Temperature-scaled softmax
67
  probs = _softmax_with_temperature(logits, temperature)
68
- # Top-K
69
  k = max(1, int(top_k))
70
  k = min(k, probs.shape[0])
71
  top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
72
  top_idx = top_idx.tolist()
73
  top_vals = top_vals.tolist()
74
-
75
  labels = [LABELS[i] for i in top_idx]
76
  logits_top = [float(logits[i]) for i in top_idx]
77
 
78
- # Quick glance dict for gr.Label (keeps students oriented)
79
  quick = {lab: float(p) for lab, p in zip(labels, top_vals)}
80
-
81
- # Bar chart
82
  fig = _make_bar(labels, top_vals)
83
 
84
- # Table rows (Rank, Label, Probability, Logit)
85
  rows = []
86
  for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1):
87
  rows.append([rank, lab, round(float(p), 6), round(float(lg), 6)])
88
 
89
- # Metrics
90
  probs_np = probs.detach().cpu().numpy()
91
  H = _entropy(probs_np)
92
  top1 = float(top_vals[0])
@@ -95,7 +90,6 @@ def analyze(img, top_k=5, temperature=1.0):
95
  cum_topk = float(sum(top_vals))
96
  infer_ms = (time.perf_counter() - t0) * 1000.0
97
 
98
- # Preprocessing info (resize/crop) from processor config if available
99
  size = processor.size if hasattr(processor, "size") else {}
100
  target_h = size.get("height", None)
101
  target_w = size.get("width", None)
@@ -108,32 +102,134 @@ def analyze(img, top_k=5, temperature=1.0):
108
  f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
109
  f"**Preprocessing & Runtime** \n"
110
  f"- Processor target size: **{size_str}** \n"
111
- f"- Inference time: **{infer_ms:.1f} ms** (CPU) \n\n"
112
- f"_Tip:_ Adjust **Temperature** to watch softmax sharpen (T<1) or soften (T>1) the distribution."
113
  )
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  return quick, fig, rows, md
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
118
- gr.Markdown("# 🖼️ Image Classification — ‘Watch the Decision’\nUpload an image, then watch the classifier’s probabilities, logits, and uncertainty metrics update.")
119
-
120
- with gr.Row():
121
- with gr.Column(scale=1):
122
- img = gr.Image(type="pil", label="Upload an image (JPEG/PNG)", height=360)
123
- topk = gr.Slider(1, 10, value=5, step=1, label="Top-K predictions")
124
- temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
125
- run = gr.Button("Analyze", variant="primary")
126
- with gr.Column(scale=1):
127
- gr.Markdown("### Quick glance")
128
- glance = gr.Label(num_top_classes=10)
129
- gr.Markdown("### Probabilities (Top-K)")
130
- plot = gr.Plot()
131
- gr.Markdown("### Details (Top-K)")
132
- table = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
133
- gr.Markdown("### Metrics & Notes")
134
- notes = gr.Markdown()
135
-
136
- run.click(analyze, [img, topk, temp], [glance, plot, table, notes])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  if __name__ == "__main__":
139
  demo.launch()
 
19
  if temperature <= 0:
20
  temperature = 1.0
21
  scaled = logits / float(temperature)
 
22
  scaled = scaled - torch.max(scaled)
23
  exp = torch.exp(scaled)
24
  return exp / torch.sum(exp)
 
28
  p = probs[probs > 0]
29
  return float(-(p * np.log(p)).sum())
30
 
31
+ def _kl(p: np.ndarray, q: np.ndarray) -> float:
32
+ """KL divergence KL(p||q), with small epsilon for stability."""
33
+ eps = 1e-12
34
+ p = p + eps
35
+ q = q + eps
36
+ return float(np.sum(p * np.log(p / q)))
37
+
38
+ def _jsd(p: np.ndarray, q: np.ndarray) -> float:
39
+ """Jensen–Shannon divergence (symmetric, bounded)."""
40
+ m = 0.5 * (p + q)
41
+ return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)
42
+
43
+ def _make_bar(labels, probs, title="Top-K predicted classes"):
44
+ """Return a matplotlib horizontal bar chart of probabilities."""
45
  fig, ax = plt.subplots(figsize=(6, 3.2))
46
  y = np.arange(len(labels))
47
+ ax.barh(y, probs) # default colors only (per teaching tool rules)
48
  ax.set_yticks(y, labels)
49
  ax.invert_yaxis()
50
  ax.set_xlim(0, 1)
51
  ax.set_xlabel("Probability")
52
+ ax.set_title(title)
53
  fig.tight_layout()
54
  return fig
55
 
56
+ def _analyze_single(img, top_k=5, temperature=1.0):
57
  """
58
+ Return: (quick_label_dict, bar_plot, table_rows, notes_markdown, full_probs_numpy)
 
 
 
 
59
  """
60
  if img is None:
61
+ return ({"<no image>": 1.0}, None, [], "Please upload an image.", None)
 
 
 
 
 
62
 
63
  t0 = time.perf_counter()
64
  inputs = processor(images=img, return_tensors="pt")
65
  with torch.no_grad():
66
  outputs = model(**inputs)
67
  logits = outputs.logits[0] # shape [num_labels]
 
68
  probs = _softmax_with_temperature(logits, temperature)
69
+
70
  k = max(1, int(top_k))
71
  k = min(k, probs.shape[0])
72
  top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
73
  top_idx = top_idx.tolist()
74
  top_vals = top_vals.tolist()
 
75
  labels = [LABELS[i] for i in top_idx]
76
  logits_top = [float(logits[i]) for i in top_idx]
77
 
 
78
  quick = {lab: float(p) for lab, p in zip(labels, top_vals)}
 
 
79
  fig = _make_bar(labels, top_vals)
80
 
 
81
  rows = []
82
  for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1):
83
  rows.append([rank, lab, round(float(p), 6), round(float(lg), 6)])
84
 
 
85
  probs_np = probs.detach().cpu().numpy()
86
  H = _entropy(probs_np)
87
  top1 = float(top_vals[0])
 
90
  cum_topk = float(sum(top_vals))
91
  infer_ms = (time.perf_counter() - t0) * 1000.0
92
 
 
93
  size = processor.size if hasattr(processor, "size") else {}
94
  target_h = size.get("height", None)
95
  target_w = size.get("width", None)
 
102
  f"- Cumulative Top-K probability: **{cum_topk:.3f}** \n\n"
103
  f"**Preprocessing & Runtime** \n"
104
  f"- Processor target size: **{size_str}** \n"
105
+ f"- Inference time: **{infer_ms:.1f} ms** \n"
 
106
  )
107
 
108
+ return quick, fig, rows, md, probs_np
109
+
110
+ def _align_topk(labelsA, probsA, labelsB, probsB, K=5):
111
+ """Make a unified label set of size up to K using union-of-top labels then rank by max(prob)."""
112
+ dA = dict(zip(labelsA, probsA))
113
+ dB = dict(zip(labelsB, probsB))
114
+ union = set(labelsA) | set(labelsB)
115
+ # rank by max(prob from A, prob from B)
116
+ ranked = sorted(list(union), key=lambda x: max(dA.get(x, 0.0), dB.get(x, 0.0)), reverse=True)
117
+ chosen = ranked[:K]
118
+ a = [float(dA.get(l, 0.0)) for l in chosen]
119
+ b = [float(dB.get(l, 0.0)) for l in chosen]
120
+ return chosen, a, b
121
+
122
+ def analyze_single(img, top_k=5, temperature=1.0):
123
+ quick, fig, rows, md, _ = _analyze_single(img, top_k, temperature)
124
  return quick, fig, rows, md
125
 
126
+ def analyze_pair(imgA, imgB, top_k=5, temperature=1.0):
127
+ """
128
+ A/B analysis:
129
+ - show per-image quick dict, bar chart, table, notes
130
+ - show aligned Top-K delta bar and divergence metrics
131
+ """
132
+ # Analyze each side
133
+ qa, figa, rowsa, mda, pa = _analyze_single(imgA, top_k, temperature)
134
+ qb, figb, rowsb, mdb, pb = _analyze_single(imgB, top_k, temperature)
135
+
136
+ # If either missing, return as-is
137
+ if pa is None or pb is None:
138
+ return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, None, "Upload both images for delta metrics."
139
+
140
+ # Build aligned top-K over labels
141
+ # We need label sets and probs for both to compute aligned bars
142
+ # Recover top-K labels directly from rows (rank, label, prob, logit)
143
+ labelsA = [r[1] for r in rowsa]
144
+ probsA = [r[2] for r in rowsa]
145
+ labelsB = [r[1] for r in rowsb]
146
+ probsB = [r[2] for r in rowsb]
147
+ chosen, a, b = _align_topk(labelsA, probsA, labelsB, probsB, K=max(int(top_k), 1))
148
+
149
+ # Delta bar (A−B)
150
+ deltas = [float(x - y) for x, y in zip(a, b)]
151
+ fig_delta = _make_bar([f"{lbl} (Δ)" for lbl in chosen], deltas, title="Aligned Top-K Δ Probabilities (A − B)")
152
+
153
+ # Distribution-level differences (full softmax vectors)
154
+ # Ensure same length and normalize to prob distributions
155
+ pa = pa / (pa.sum() + 1e-12)
156
+ pb = pb / (pb.sum() + 1e-12)
157
+ H_a = _entropy(pa)
158
+ H_b = _entropy(pb)
159
+ jsd = _jsd(pa, pb)
160
+
161
+ # Top-1 labels for each side
162
+ top1_a_idx = int(np.argmax(pa))
163
+ top1_b_idx = int(np.argmax(pb))
164
+ top1_a = LABELS[top1_a_idx]
165
+ top1_b = LABELS[top1_b_idx]
166
+
167
+ diff_md = (
168
+ f"**A/B Divergence** \n"
169
+ f"- Jensen–Shannon divergence: **{jsd:.4f}** (0=same, higher=more different) \n"
170
+ f"- Entropy A / B: **{H_a:.3f} / {H_b:.3f}** nats \n"
171
+ f"- Top-1 A / B: **{top1_a} / {top1_b}** \n"
172
+ f"- Aligned Top-K shown above is ranked by max(prob_A, prob_B). \n"
173
+ f"_Tip:_ Try different crops/lighting or adjust **Temperature** to watch distributions change."
174
+ )
175
+
176
+ return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, fig_delta, diff_md
177
+
178
  with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
179
+ gr.Markdown("# 🖼️ Image Classification — ‘Watch the Decision’\nVisualize probabilities, logits, entropy, and A/B deltas.\n\n"
180
+ "_Notes:_ Predictions reflect the ImageNet‑1k label space; unusual objects or logos may be misclassified. "
181
+ "Do not use for identity or sensitive inferences.")
182
+
183
+ with gr.Tab("Single Image"):
184
+ with gr.Row():
185
+ with gr.Column(scale=1):
186
+ img = gr.Image(type="pil", label="Upload image (JPEG/PNG)", height=340)
187
+ topk = gr.Slider(1, 10, value=5, step=1, label="Top‑K predictions")
188
+ temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
189
+ run = gr.Button("Analyze", variant="primary")
190
+ with gr.Column(scale=1):
191
+ gr.Markdown("### Quick glance")
192
+ glance = gr.Label(num_top_classes=10)
193
+ gr.Markdown("### Probabilities (Top‑K)")
194
+ plot = gr.Plot()
195
+ gr.Markdown("### Details (Top‑K)")
196
+ table = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
197
+ gr.Markdown("### Metrics & Notes")
198
+ notes = gr.Markdown()
199
+ run.click(analyze_single, [img, topk, temp], [glance, plot, table, notes])
200
+
201
+ with gr.Tab("A/B Compare"):
202
+ with gr.Row():
203
+ with gr.Column(scale=1):
204
+ imgA = gr.Image(type="pil", label="Image A", height=300)
205
+ imgB = gr.Image(type="pil", label="Image B", height=300)
206
+ topkAB = gr.Slider(1, 10, value=5, step=1, label="Aligned Top‑K")
207
+ tempAB = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
208
+ runAB = gr.Button("Analyze A/B", variant="primary")
209
+ with gr.Column(scale=1):
210
+ gr.Markdown("### A — Quick glance")
211
+ glanceA = gr.Label(num_top_classes=10)
212
+ gr.Markdown("### A — Probabilities (Top‑K)")
213
+ plotA = gr.Plot()
214
+ gr.Markdown("### A — Details (Top‑K)")
215
+ tableA = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
216
+ gr.Markdown("### A — Notes")
217
+ notesA = gr.Markdown()
218
+ with gr.Column(scale=1):
219
+ gr.Markdown("### B — Quick glance")
220
+ glanceB = gr.Label(num_top_classes=10)
221
+ gr.Markdown("### B — Probabilities (Top‑K)")
222
+ plotB = gr.Plot()
223
+ gr.Markdown("### B — Details (Top‑K)")
224
+ tableB = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
225
+ gr.Markdown("### B — Notes")
226
+ notesB = gr.Markdown()
227
+ with gr.Row():
228
+ gr.Markdown("### Aligned Top‑K Δ (A − B) & Divergence")
229
+ with gr.Row():
230
+ deltaPlot = gr.Plot()
231
+ deltaNotes = gr.Markdown()
232
+ runAB.click(analyze_pair, [imgA, imgB, topkAB, tempAB], [glanceA, plotA, tableA, notesA, glanceB, plotB, tableB, notesB, deltaPlot, deltaNotes])
233
 
234
  if __name__ == "__main__":
235
  demo.launch()