PraneshJs commited on
Commit
34a6738
Β·
verified Β·
1 Parent(s): 64245e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -65
app.py CHANGED
@@ -1,9 +1,13 @@
1
  # ==========================================================
2
- # YOLOv8n Visualizer β€” "Inside Object Detection"
3
  # - Uses Ultralytics YOLOv8n (small, CPU-friendly)
4
- # - Shows detections + early/mid/late feature maps
5
- # - Simple vs Technical explanation
6
- # - Gradio 5 compatible, also OK on 6 (no theme arg)
 
 
 
 
7
  # ==========================================================
8
 
9
  import gradio as gr
@@ -31,29 +35,24 @@ def load_model():
31
  if MODEL is not None:
32
  return MODEL
33
 
34
- # This will download yolov8n.pt on first run and cache it
35
  model = YOLO("yolov8n.pt")
36
 
37
- # Ensure model on CPU
38
  if hasattr(model, "to"):
39
  model.to(DEVICE)
40
  else:
41
  model.model.to(DEVICE)
42
-
43
  model.model.eval()
44
 
45
  FEATURE_MAPS = {}
46
 
47
- # Register hooks on layers in the detection model
48
- # For YOLOv8, model.model.model is a list of blocks (backbone + head)
49
  for idx, layer in enumerate(model.model.model):
50
  def make_hook(name):
51
  def hook(module, inputs, output):
52
- # Handle tensors vs lists/tuples
53
  with torch.no_grad():
54
  out = output
55
  if isinstance(out, (list, tuple)):
56
- # pick first tensor-like element
57
  out = next(
58
  (o for o in out if isinstance(o, torch.Tensor)),
59
  None
@@ -73,14 +72,11 @@ def load_model():
73
  def tensor_to_heatmap(fm, out_size):
74
  """
75
  Convert a feature map tensor (C,H,W) to a grayscale heatmap PIL image.
76
- - average over channels
77
- - normalize to 0..1
78
- - resize to out_size (W,H)
79
  """
80
  if fm.ndim != 3:
81
  return None
82
 
83
- fm_np = fm.numpy().astype(np.float32) # (C,H,W)
84
  heat = fm_np.mean(axis=0) # (H,W)
85
 
86
  if not np.any(heat):
@@ -97,6 +93,22 @@ def tensor_to_heatmap(fm, out_size):
97
  return pil
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def pick_feature_maps():
101
  """
102
  Choose three feature maps: early, middle, late.
@@ -106,7 +118,6 @@ def pick_feature_maps():
106
  if not FEATURE_MAPS:
107
  return []
108
 
109
- # sort by numeric layer index
110
  keys = sorted(FEATURE_MAPS.keys(), key=lambda x: int(x))
111
  fms = []
112
  for k in keys:
@@ -126,6 +137,50 @@ def pick_feature_maps():
126
  return chosen
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  # ------------------- MAIN ANALYSIS FUNCTION -------------------
130
 
131
  def analyze_yolo(img, conf_thres, iou_thres, simple_mode):
@@ -133,103 +188,177 @@ def analyze_yolo(img, conf_thres, iou_thres, simple_mode):
133
  Run YOLOv8n on input image and produce:
134
  - detection image with boxes
135
  - early/mid/late feature map heatmaps
136
- - explanation text (simple or technical)
 
 
137
  """
138
  if img is None:
139
  return (
140
- None, # detection image
141
- None, # early heatmap
142
- None, # mid heatmap
143
- None, # late heatmap
144
- "⚠️ Please upload an image first."
 
 
 
 
 
145
  )
146
 
147
  model = load_model()
148
-
149
- # Clear old feature maps before forward
150
  FEATURE_MAPS.clear()
151
 
152
- # Gradio gives PIL image (type="pil")
153
  pil = img
154
-
155
- # Configure thresholds
156
  conf = float(conf_thres)
157
  iou = float(iou_thres)
158
 
159
  with torch.no_grad():
160
- results = model(
161
- pil,
162
- conf=conf,
163
- iou=iou,
164
- verbose=False
165
- )
166
 
167
  res = results[0]
168
-
169
- # res.plot() returns numpy array (H,W,3), BGR by default, but visually OK
170
- det_np = res.plot()
171
  det_img = Image.fromarray(det_np)
172
 
173
- # Now FEATURE_MAPS should be filled by hooks
174
  chosen = pick_feature_maps()
175
  W, H = pil.size
176
  heatmaps = [None, None, None]
 
 
177
 
178
  for idx, item in enumerate(chosen):
179
- name, fm = item
180
  hm = tensor_to_heatmap(fm, (W, H))
181
  heatmaps[idx] = hm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- # Build explanation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  if simple_mode:
185
  explanation = (
186
  "πŸ§’ **Simple explanation of what you see:**\n\n"
187
- "**Step 0 β€” Input image**: This is your original picture.\n\n"
188
- "**Step 1 β€” Early layer heatmap**:\n"
189
- "YOLO looks for very small details like edges, corners, and simple textures.\n\n"
190
- "**Step 2 β€” Middle layer heatmap**:\n"
191
- "It starts to see groups of pixels as shapes or parts of objects (like wheels, faces, etc.).\n\n"
192
- "**Step 3 β€” Late layer heatmap**:\n"
193
- "It focuses on whole objects and regions where it thinks something important is.\n\n"
194
- "**Step 4 β€” Final detections**:\n"
195
- "YOLO draws boxes and labels around what it believes are objects in the image.\n"
196
  )
197
  else:
198
  explanation = (
199
- "πŸ”¬ **Technical explanation of the visualization:**\n\n"
200
- "- We use **YOLOv8n** (Ultralytics) running on CPU.\n"
201
- "- Forward hooks capture intermediate feature maps from backbone/head blocks.\n"
202
- "- For each selected layer, we take the tensor `(C,H,W)` and average over channels to\n"
203
- " obtain a 2D activation map `(H,W)`, then normalize it and upsample it to `(W_img,H_img)`.\n"
204
- "- Early feature map β‰ˆ low-level features (edges, corners, local textures).\n"
205
- "- Middle feature map β‰ˆ mid-level features (object parts & shapes).\n"
206
- "- Late feature map β‰ˆ high-level features (object-centric regions that drive detection head).\n"
207
- "- The detection image is produced by YOLO's standard post-processing (objectness, class\n"
208
- " scores, and Non-Maximum Suppression on bounding boxes).\n"
 
 
209
  )
210
 
211
- # Add feature map shapes
212
  if chosen:
213
  explanation += "\n**Captured feature map shapes (C,H,W):**\n"
214
  for name, fm in chosen:
215
  explanation += f"- Layer {name}: {tuple(fm.shape)}\n"
216
 
217
- return det_img, heatmaps[0], heatmaps[1], heatmaps[2], explanation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  # ------------------- GRADIO UI -------------------
221
 
222
- with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection") as demo:
223
 
224
- gr.Markdown("# 🧠 YOLOv8n Visualizer β€” Inside Object Detection")
225
  gr.Markdown(
226
- "See what happens **inside** an object detection model.\n\n"
227
  "**Steps shown:**\n"
228
  "- **Step 0** β€” Input image\n"
229
  "- **Step 1** β€” Early layer activation (edges & textures)\n"
230
  "- **Step 2** β€” Middle layer activation (parts & shapes)\n"
231
  "- **Step 3** β€” Late layer activation (objects)\n"
232
  "- **Step 4** β€” Final detections (boxes & labels)\n"
 
 
233
  )
234
 
235
  with gr.Row():
@@ -253,7 +382,7 @@ with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection") as demo:
253
  label="IoU threshold (NMS)"
254
  )
255
  simple_ck = gr.Checkbox(
256
- label="Explain in simple terms (for kids/elders)",
257
  value=True
258
  )
259
  run_btn = gr.Button("Run YOLO & Visualize", variant="primary")
@@ -263,6 +392,10 @@ with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection") as demo:
263
  label="Step 4 β€” Final detections (YOLOv8n)",
264
  interactive=False
265
  )
 
 
 
 
266
  explanation_md = gr.Markdown(label="Explanation")
267
 
268
  gr.Markdown("### πŸ” Steps 1–3: internal feature maps (what the network focuses on)")
@@ -281,10 +414,44 @@ with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection") as demo:
281
  interactive=False
282
  )
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  run_btn.click(
285
  analyze_yolo,
286
  inputs=[in_img, conf_slider, iou_slider, simple_ck],
287
- outputs=[out_det, fm1, fm2, fm3, explanation_md]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  )
289
 
290
  demo.launch()
 
1
  # ==========================================================
2
+ # YOLOv8n Visualizer β€” Inside Object Detection (Advanced)
3
  # - Uses Ultralytics YOLOv8n (small, CPU-friendly)
4
+ # - Step 0: Input image
5
+ # - Step 1: Early feature activation (edges/textures)
6
+ # - Step 2: Middle feature activation (parts/shapes)
7
+ # - Step 3: Late feature activation (objects)
8
+ # - Step 4: Final detections (boxes + labels)
9
+ # - Activation-CAM overlay (late layer heatmap on image)
10
+ # - Channel explorer for late layer (view individual channels)
11
  # ==========================================================
12
 
13
  import gradio as gr
 
35
  if MODEL is not None:
36
  return MODEL
37
 
 
38
  model = YOLO("yolov8n.pt")
39
 
40
+ # ensure on CPU
41
  if hasattr(model, "to"):
42
  model.to(DEVICE)
43
  else:
44
  model.model.to(DEVICE)
 
45
  model.model.eval()
46
 
47
  FEATURE_MAPS = {}
48
 
49
+ # model.model.model is the list of modules (backbone + head)
 
50
  for idx, layer in enumerate(model.model.model):
51
  def make_hook(name):
52
  def hook(module, inputs, output):
 
53
  with torch.no_grad():
54
  out = output
55
  if isinstance(out, (list, tuple)):
 
56
  out = next(
57
  (o for o in out if isinstance(o, torch.Tensor)),
58
  None
 
72
  def tensor_to_heatmap(fm, out_size):
73
  """
74
  Convert a feature map tensor (C,H,W) to a grayscale heatmap PIL image.
 
 
 
75
  """
76
  if fm.ndim != 3:
77
  return None
78
 
79
+ fm_np = fm.numpy().astype(np.float32)
80
  heat = fm_np.mean(axis=0) # (H,W)
81
 
82
  if not np.any(heat):
 
93
  return pil
94
 
95
 
96
+ def heat_array_from_fm(fm):
97
+ """
98
+ Same as tensor_to_heatmap but returns 0..1 numpy array (H,W).
99
+ """
100
+ fm_np = fm.numpy().astype(np.float32)
101
+ heat = fm_np.mean(axis=0)
102
+ if not np.any(heat):
103
+ heat = np.zeros_like(heat)
104
+ else:
105
+ heat -= heat.min()
106
+ maxv = heat.max()
107
+ if maxv > 0:
108
+ heat /= maxv
109
+ return heat
110
+
111
+
112
  def pick_feature_maps():
113
  """
114
  Choose three feature maps: early, middle, late.
 
118
  if not FEATURE_MAPS:
119
  return []
120
 
 
121
  keys = sorted(FEATURE_MAPS.keys(), key=lambda x: int(x))
122
  fms = []
123
  for k in keys:
 
137
  return chosen
138
 
139
 
140
+ def make_cam_overlay(base_pil, heat_01):
141
+ """
142
+ Build a simple activation-CAM overlay (heatmap over image).
143
+ heat_01: numpy (H_fm, W_fm) in [0,1], resized to image size.
144
+ """
145
+ base = np.array(base_pil).astype(np.float32) / 255.0 # H,W,3
146
+
147
+ h, w = base.shape[:2]
148
+ heat_resized = Image.fromarray((heat_01 * 255).astype("uint8"), mode="L").resize(
149
+ (w, h), Image.BILINEAR
150
+ )
151
+ heat_resized = np.array(heat_resized).astype(np.float32) / 255.0 # H,W
152
+
153
+ # simple blue→red colormap
154
+ r = heat_resized
155
+ g = np.zeros_like(heat_resized)
156
+ b = 1.0 - heat_resized
157
+ cam = np.stack([r, g, b], axis=-1) # H,W,3
158
+
159
+ alpha = 0.45
160
+ blended = (1 - alpha) * base + alpha * cam
161
+ blended = np.clip(blended * 255.0, 0, 255).astype("uint8")
162
+ return Image.fromarray(blended)
163
+
164
+
165
+ def single_channel_heatmap(channel_2d, out_size):
166
+ """
167
+ Convert 2D channel to grayscale PIL heatmap.
168
+ """
169
+ arr = channel_2d.astype(np.float32)
170
+ if not np.any(arr):
171
+ arr = np.zeros_like(arr)
172
+ else:
173
+ arr -= arr.min()
174
+ maxv = arr.max()
175
+ if maxv > 0:
176
+ arr /= maxv
177
+
178
+ img = (arr * 255).astype("uint8")
179
+ pil = Image.fromarray(img, mode="L")
180
+ pil = pil.resize(out_size, Image.NEAREST)
181
+ return pil
182
+
183
+
184
  # ------------------- MAIN ANALYSIS FUNCTION -------------------
185
 
186
  def analyze_yolo(img, conf_thres, iou_thres, simple_mode):
 
188
  Run YOLOv8n on input image and produce:
189
  - detection image with boxes
190
  - early/mid/late feature map heatmaps
191
+ - activation-CAM overlay
192
+ - channel explorer state
193
+ - explanation markdown
194
  """
195
  if img is None:
196
  return (
197
+ None, # det img
198
+ None, # early
199
+ None, # mid
200
+ None, # late
201
+ None, # cam overlay
202
+ "⚠️ Please upload an image first.",
203
+ "", # channel info
204
+ gr.update(maximum=0, value=0),
205
+ None, # channel heatmap
206
+ {} # state
207
  )
208
 
209
  model = load_model()
 
 
210
  FEATURE_MAPS.clear()
211
 
 
212
  pil = img
 
 
213
  conf = float(conf_thres)
214
  iou = float(iou_thres)
215
 
216
  with torch.no_grad():
217
+ results = model(pil, conf=conf, iou=iou, verbose=False)
 
 
 
 
 
218
 
219
  res = results[0]
220
+ det_np = res.plot() # numpy HWC
 
 
221
  det_img = Image.fromarray(det_np)
222
 
 
223
  chosen = pick_feature_maps()
224
  W, H = pil.size
225
  heatmaps = [None, None, None]
226
+ late_fm_np = None
227
+ late_name = None
228
 
229
  for idx, item in enumerate(chosen):
230
+ name, fm = item # fm: (C,H,W)
231
  hm = tensor_to_heatmap(fm, (W, H))
232
  heatmaps[idx] = hm
233
+ if idx == len(chosen) - 1:
234
+ late_fm_np = fm.numpy().astype(np.float32) # (C,H,W)
235
+ late_name = name
236
+
237
+ # Activation-CAM overlay (using late feature map mean)
238
+ cam_overlay = None
239
+ channel_slider_update = gr.update(maximum=0, value=0)
240
+ channel_info = ""
241
+ channel_heatmap_img = None
242
+ state = {}
243
+
244
+ if late_fm_np is not None:
245
+ C, H_fm, W_fm = late_fm_np.shape
246
+ late_fm_tensor = torch.from_numpy(late_fm_np)
247
+ heat_01 = heat_array_from_fm(late_fm_tensor)
248
+ cam_overlay = make_cam_overlay(pil, heat_01)
249
+
250
+ # Channel explorer: compute mean abs activation per channel
251
+ means = np.mean(np.abs(late_fm_np), axis=(1, 2)) # (C,)
252
+ order = np.argsort(means)[::-1]
253
+ top_k = order[: min(8, C)].tolist()
254
+
255
+ channel_info = (
256
+ f"Late layer **{late_name}** feature map: {C} channels of size {H_fm}Γ—{W_fm}.\n"
257
+ f"Top active channels (by mean |activation|): {top_k}"
258
+ )
259
 
260
+ # default channel = strongest
261
+ default_ch = int(top_k[0]) if top_k else 0
262
+ channel_slider_update = gr.update(maximum=C - 1, value=default_ch)
263
+
264
+ # build heatmap for default channel
265
+ default_ch_map = late_fm_np[default_ch]
266
+ channel_heatmap_img = single_channel_heatmap(default_ch_map, (W, H))
267
+
268
+ # state for slider changes
269
+ state = {
270
+ "late_fm": late_fm_np,
271
+ "W": W,
272
+ "H": H,
273
+ }
274
+
275
+ # Explanation
276
  if simple_mode:
277
  explanation = (
278
  "πŸ§’ **Simple explanation of what you see:**\n\n"
279
+ "- **Step 0 – Input image:** your original picture.\n"
280
+ "- **Step 1 – Early layer heatmap:** the model sees edges and tiny details.\n"
281
+ "- **Step 2 – Middle layer heatmap:** it starts seeing parts of objects and shapes.\n"
282
+ "- **Step 3 – Late layer heatmap:** it focuses on full objects and important regions.\n"
283
+ "- **Activation overlay:** colored map (blue→red) over the image showing *where* the model\n"
284
+ " is looking the most in the final stage.\n"
285
+ "- **Channel explorer:** each channel is like a tiny specialist (e.g., vertical lines,\n"
286
+ " corners, or specific textures). You can slide through channels to see different patterns.\n"
 
287
  )
288
  else:
289
  explanation = (
290
+ "πŸ”¬ **Technical explanation:**\n\n"
291
+ "- We run **YOLOv8n** (Ultralytics) on CPU.\n"
292
+ "- Forward hooks capture internal feature maps from several backbone/head blocks.\n"
293
+ "- For each chosen layer, we take `(C,H,W)` and average over channels to get a 2D activation\n"
294
+ " map `(H,W)`, normalize it, and upsample it to image resolution.\n"
295
+ "- Early β‰ˆ low-level features; Middle β‰ˆ mid-level parts; Late β‰ˆ high-level object-centric\n"
296
+ " features.\n"
297
+ "- The activation overlay is a CAM-style visualization built from the **mean late-layer\n"
298
+ " activation**, colored and blended with the original image (not full gradient-based Grad-CAM,\n"
299
+ " but an activation-based approximation).\n"
300
+ "- In the channel explorer, channels are ranked by mean |activation|, and you can inspect each\n"
301
+ " channel separately as a grayscale map, revealing different spatial patterns.\n"
302
  )
303
 
304
+ # Add feature map shapes if we have them
305
  if chosen:
306
  explanation += "\n**Captured feature map shapes (C,H,W):**\n"
307
  for name, fm in chosen:
308
  explanation += f"- Layer {name}: {tuple(fm.shape)}\n"
309
 
310
+ return (
311
+ det_img,
312
+ heatmaps[0],
313
+ heatmaps[1],
314
+ heatmaps[2],
315
+ cam_overlay,
316
+ explanation,
317
+ channel_info,
318
+ channel_slider_update,
319
+ channel_heatmap_img,
320
+ state,
321
+ )
322
+
323
+
324
+ # ------------------- CHANNEL SLIDER UPDATE -------------------
325
+
326
+ def update_channel(state, ch_idx):
327
+ """
328
+ When slider moves, update the channel heatmap (late layer).
329
+ """
330
+ if not state or "late_fm" not in state:
331
+ return gr.update(value=None)
332
+
333
+ late_fm = state["late_fm"] # (C,H,W)
334
+ W = state["W"]
335
+ H = state["H"]
336
+
337
+ C = late_fm.shape[0]
338
+ idx = int(ch_idx)
339
+ if idx < 0 or idx >= C:
340
+ idx = 0
341
+
342
+ ch_map = late_fm[idx]
343
+ img = single_channel_heatmap(ch_map, (W, H))
344
+ return gr.update(value=img)
345
 
346
 
347
  # ------------------- GRADIO UI -------------------
348
 
349
+ with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection (Advanced)") as demo:
350
 
351
+ gr.Markdown("# 🧠 YOLOv8n Visualizer β€” Inside Object Detection (Advanced)")
352
  gr.Markdown(
353
+ "Explore what happens **inside** an object detection model.\n\n"
354
  "**Steps shown:**\n"
355
  "- **Step 0** β€” Input image\n"
356
  "- **Step 1** β€” Early layer activation (edges & textures)\n"
357
  "- **Step 2** β€” Middle layer activation (parts & shapes)\n"
358
  "- **Step 3** β€” Late layer activation (objects)\n"
359
  "- **Step 4** β€” Final detections (boxes & labels)\n"
360
+ "- **Activation overlay** β€” CAM-style heatmap over the image\n"
361
+ "- **Channel explorer** β€” inspect individual channels in the late layer\n"
362
  )
363
 
364
  with gr.Row():
 
382
  label="IoU threshold (NMS)"
383
  )
384
  simple_ck = gr.Checkbox(
385
+ label="Explain in simple terms (kids/elders)",
386
  value=True
387
  )
388
  run_btn = gr.Button("Run YOLO & Visualize", variant="primary")
 
392
  label="Step 4 β€” Final detections (YOLOv8n)",
393
  interactive=False
394
  )
395
+ cam_img = gr.Image(
396
+ label="Activation overlay (late layer focus)",
397
+ interactive=False
398
+ )
399
  explanation_md = gr.Markdown(label="Explanation")
400
 
401
  gr.Markdown("### πŸ” Steps 1–3: internal feature maps (what the network focuses on)")
 
414
  interactive=False
415
  )
416
 
417
+ gr.Markdown("### πŸ”¬ Channel explorer (late layer)")
418
+
419
+ channel_info_md = gr.Markdown()
420
+ channel_slider = gr.Slider(
421
+ minimum=0,
422
+ maximum=0,
423
+ step=1,
424
+ value=0,
425
+ label="Channel index (late layer)"
426
+ )
427
+ channel_heatmap = gr.Image(
428
+ label="Selected channel heatmap (grayscale)",
429
+ interactive=False
430
+ )
431
+
432
+ state = gr.State()
433
+
434
  run_btn.click(
435
  analyze_yolo,
436
  inputs=[in_img, conf_slider, iou_slider, simple_ck],
437
+ outputs=[
438
+ out_det,
439
+ fm1,
440
+ fm2,
441
+ fm3,
442
+ cam_img,
443
+ explanation_md,
444
+ channel_info_md,
445
+ channel_slider,
446
+ channel_heatmap,
447
+ state,
448
+ ],
449
+ )
450
+
451
+ channel_slider.change(
452
+ update_channel,
453
+ inputs=[state, channel_slider],
454
+ outputs=[channel_heatmap],
455
  )
456
 
457
  demo.launch()