PraneshJs committed on
Commit
64245e7
·
verified ·
1 Parent(s): de061d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -108
app.py CHANGED
@@ -1,8 +1,9 @@
1
  # ==========================================================
2
- # YOLOv5n Visualizer — "Inside Object Detection"
3
- # - Uses small YOLOv5n (CPU-friendly)
4
  # - Shows detections + early/mid/late feature maps
5
- # - Gradio 5 compatible (theme supported)
 
6
  # ==========================================================
7
 
8
  import gradio as gr
@@ -10,11 +11,11 @@ import torch
10
  import numpy as np
11
  from PIL import Image
12
 
 
 
13
  # ------------------- GLOBALS -------------------
14
 
15
- MODEL_NAME = "yolov5n" # smallest YOLOv5 model (fast & light)
16
  DEVICE = "cpu"
17
-
18
  MODEL = None
19
  FEATURE_MAPS = {} # {layer_name: tensor(B,C,H,W)}
20
 
@@ -23,34 +24,45 @@ FEATURE_MAPS = {} # {layer_name: tensor(B,C,H,W)}
23
 
24
  def load_model():
25
  """
26
- Load YOLOv5n from torch.hub (ultralytics/yolov5) and
27
- register forward hooks to capture internal feature maps.
28
  """
29
  global MODEL, FEATURE_MAPS
30
  if MODEL is not None:
31
  return MODEL
32
 
33
- # Download and load YOLOv5n from GitHub (only on first run)
34
- # repo 'ultralytics/yolov5' must be reachable during build/first call.
35
- model = torch.hub.load("ultralytics/yolov5", MODEL_NAME, pretrained=True)
36
- model.to(DEVICE)
37
- model.eval()
38
 
39
- FEATURE_MAPS = {}
 
 
 
 
40
 
41
- def make_hook(name):
42
- def hook(module, input, output):
43
- # YOLO can run on GPU or CPU but we store CPU tensors for visualization
44
- with torch.no_grad():
45
- FEATURE_MAPS[name] = output.detach().cpu()
46
- return hook
47
 
48
- # Register hooks on some main layers in the YOLOv5 backbone/head
49
- # We choose Conv / C3 / SPPF etc. so we can show early, mid, late stages.
50
- for idx, m in enumerate(model.model):
51
- cls_name = m.__class__.__name__
52
- if cls_name in ["Conv", "C3", "Bottleneck", "BottleneckCSP", "SPPF"]:
53
- m.register_forward_hook(make_hook(str(idx)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  MODEL = model
56
  return MODEL
@@ -61,52 +73,56 @@ def load_model():
61
  def tensor_to_heatmap(fm, out_size):
62
  """
63
  Convert a feature map tensor (C,H,W) to a grayscale heatmap PIL image.
64
- Steps:
65
  - average over channels
66
  - normalize to 0..1
67
- - upscale to out_size
68
  """
69
  if fm.ndim != 3:
70
  return None
71
 
72
  fm_np = fm.numpy().astype(np.float32) # (C,H,W)
73
- # average over channels -> (H,W)
74
- heat = fm_np.mean(axis=0)
75
 
76
- if np.allclose(heat, 0):
77
  heat = np.zeros_like(heat)
78
  else:
79
- heat = heat - heat.min()
80
  maxv = heat.max()
81
  if maxv > 0:
82
- heat = heat / maxv
83
 
84
- heat_img = (heat * 255).astype("uint8")
85
- pil = Image.fromarray(heat_img, mode="L")
86
  pil = pil.resize(out_size, Image.NEAREST)
87
  return pil
88
 
89
 
90
  def pick_feature_maps():
91
  """
92
- After a forward pass, FEATURE_MAPS has many layers.
93
- We pick up to 3 layers: early, middle, late.
94
- Returns: list of (name, tensor(C,H,W))
95
  """
96
  if not FEATURE_MAPS:
97
  return []
98
 
99
- # keys are layer indices as strings: "0", "1", "4", ...
100
  keys = sorted(FEATURE_MAPS.keys(), key=lambda x: int(x))
101
- fms = [FEATURE_MAPS[k][0] for k in keys] # take batch 0
 
 
 
 
 
 
 
102
 
103
- # pick early, mid, late
104
  idxs = [0, len(fms) // 2, len(fms) - 1]
105
- idxs = sorted(list(set(idxs))) # remove duplicate indices
106
 
107
  chosen = []
108
  for i in idxs:
109
- chosen.append((keys[i], fms[i]))
110
  return chosen
111
 
112
 
@@ -114,101 +130,106 @@ def pick_feature_maps():
114
 
115
  def analyze_yolo(img, conf_thres, iou_thres, simple_mode):
116
  """
117
- Run YOLO on the input image and return:
118
- - detection overlay image
119
- - early feature map heatmap
120
- - mid feature map heatmap
121
- - late feature map heatmap
122
- - explanation markdown
123
  """
124
  if img is None:
125
  return (
126
- None, # det img
127
- None, # early fm
128
- None, # mid fm
129
- None, # late fm
130
  "⚠️ Please upload an image first."
131
  )
132
 
133
  model = load_model()
134
 
135
- # Clear old feature maps
136
  FEATURE_MAPS.clear()
137
 
138
- # In Gradio, `type="pil"` gives a PIL image already
139
  pil = img
140
 
141
  # Configure thresholds
142
- model.conf = float(conf_thres)
143
- model.iou = float(iou_thres)
144
 
145
  with torch.no_grad():
146
- results = model(pil)
 
 
 
 
 
147
 
148
- # YOLOv5 .render() draws boxes and labels on the image
149
- rendered = results.render()[0] # numpy array (H,W,C)
150
- det_img = Image.fromarray(rendered)
151
 
152
- # Collect feature maps from hooks
153
- chosen_fms = pick_feature_maps()
 
 
 
 
154
  W, H = pil.size
155
- heatmaps = [None, None, None] # early, mid, late
156
 
157
- for idx, item in enumerate(chosen_fms):
158
  name, fm = item
159
  hm = tensor_to_heatmap(fm, (W, H))
160
  heatmaps[idx] = hm
161
 
162
- # Build readable explanation
163
  if simple_mode:
164
  explanation = (
165
  "πŸ§’ **Simple explanation of what you see:**\n\n"
166
- "1. YOLO first looks at your image and tries to find basic patterns like edges and corners.\n"
167
- "2. Then it builds more complex shapes (like parts of objects: wheels, faces, etc.).\n"
168
- "3. In the last layers, it focuses on whole objects and decides **what** and **where** they are.\n\n"
169
- "**From top to bottom:**\n"
170
- "- Left: final detections (boxes + labels).\n"
171
- "- Early heatmap: where YOLO sees low-level details.\n"
172
- "- Middle heatmap: where it sees object parts.\n"
173
- "- Late heatmap: where it focuses on full objects.\n"
 
174
  )
175
  else:
176
  explanation = (
177
- "πŸ”¬ **Technical explanation:**\n\n"
178
- "- We run `yolov5n` (small YOLOv5) on CPU.\n"
179
- "- Forward hooks capture intermediate feature maps from several Conv/C3/SPPF blocks.\n"
180
- "- For each selected layer, we take the tensor `(C,H,W)`, average over channels to get a 2D\n"
181
- " activation map `(H,W)`, normalize it, and upscale it to the original image size.\n"
182
- "- Early feature map β‰ˆ low-level features (edges, textures).\n"
183
- "- Middle feature map β‰ˆ mid-level features (parts, shapes).\n"
184
- "- Late feature map β‰ˆ high-level features (object-centric regions used for detection).\n"
 
 
185
  )
186
 
187
- # Append layer shapes info if available
188
- fm_shapes_info = []
189
- for name, fm in chosen_fms:
190
- fm_shapes_info.append(f"Layer {name}: shape {tuple(fm.shape)} (C,H,W)")
191
- if fm_shapes_info:
192
- explanation += "\n**Feature map shapes captured:**\n" + "\n".join(f"- {s}" for s in fm_shapes_info)
193
 
194
  return det_img, heatmaps[0], heatmaps[1], heatmaps[2], explanation
195
 
196
 
197
- # ------------------- GRADIO UI (GRADIO 5) -------------------
198
 
199
- with gr.Blocks(
200
- title="YOLOv5n Visualizer β€” Inside Object Detection",
201
- theme=gr.themes.Soft()
202
- ) as demo:
203
 
204
- gr.Markdown("# 🧠 YOLOv5n Visualizer β€” See Inside Object Detection")
205
  gr.Markdown(
206
- "Upload an image and see YOLO work **step by step**:\n"
207
- "1. Final detections (boxes & labels)\n"
208
- "2. Early feature activations (edges/textures)\n"
209
- "3. Middle feature activations (parts/shapes)\n"
210
- "4. Late feature activations (object focus)\n"
211
- "Use the explanation toggle for simple or technical view."
 
212
  )
213
 
214
  with gr.Row():
@@ -218,29 +239,47 @@ with gr.Blocks(
218
  type="pil"
219
  )
220
  conf_slider = gr.Slider(
221
- 0.1, 0.9, step=0.05, value=0.25,
 
 
 
222
  label="Confidence threshold"
223
  )
224
  iou_slider = gr.Slider(
225
- 0.1, 0.9, step=0.05, value=0.45,
226
- label="IoU threshold (for NMS)"
 
 
 
227
  )
228
  simple_ck = gr.Checkbox(
229
- label="Explain in simple terms (kids/elders)",
230
  value=True
231
  )
232
  run_btn = gr.Button("Run YOLO & Visualize", variant="primary")
233
 
234
  with gr.Column(scale=1):
235
- out_det = gr.Image(label="Step 4 β€” Final detections (YOLOv5n)")
 
 
 
236
  explanation_md = gr.Markdown(label="Explanation")
237
 
238
- gr.Markdown("### πŸ” Steps inside the network (feature maps)")
239
 
240
  with gr.Row():
241
- fm1 = gr.Image(label="Step 1 β€” Early layer activation (edges & textures)", interactive=False)
242
- fm2 = gr.Image(label="Step 2 β€” Middle layer activation (parts & shapes)", interactive=False)
243
- fm3 = gr.Image(label="Step 3 β€” Late layer activation (objects)", interactive=False)
 
 
 
 
 
 
 
 
 
244
 
245
  run_btn.click(
246
  analyze_yolo,
 
1
  # ==========================================================
2
+ # YOLOv8n Visualizer — "Inside Object Detection"
3
+ # - Uses Ultralytics YOLOv8n (small, CPU-friendly)
4
  # - Shows detections + early/mid/late feature maps
5
+ # - Simple vs Technical explanation
6
+ # - Gradio 5 compatible, also OK on 6 (no theme arg)
7
  # ==========================================================
8
 
9
  import gradio as gr
 
11
  import numpy as np
12
  from PIL import Image
13
 
14
+ from ultralytics import YOLO
15
+
16
  # ------------------- GLOBALS -------------------
17
 
 
18
  DEVICE = "cpu"
 
19
  MODEL = None
20
  FEATURE_MAPS = {} # {layer_name: tensor(B,C,H,W)}
21
 
 
24
 
25
  def load_model():
26
  """
27
+ Load YOLOv8n once and register forward hooks
28
+ on backbone/head layers to capture feature maps.
29
  """
30
  global MODEL, FEATURE_MAPS
31
  if MODEL is not None:
32
  return MODEL
33
 
34
+ # This will download yolov8n.pt on first run and cache it
35
+ model = YOLO("yolov8n.pt")
 
 
 
36
 
37
+ # Ensure model on CPU
38
+ if hasattr(model, "to"):
39
+ model.to(DEVICE)
40
+ else:
41
+ model.model.to(DEVICE)
42
 
43
+ model.model.eval()
 
 
 
 
 
44
 
45
+ FEATURE_MAPS = {}
46
+
47
+ # Register hooks on layers in the detection model
48
+ # For YOLOv8, model.model.model is a list of blocks (backbone + head)
49
+ for idx, layer in enumerate(model.model.model):
50
+ def make_hook(name):
51
+ def hook(module, inputs, output):
52
+ # Handle tensors vs lists/tuples
53
+ with torch.no_grad():
54
+ out = output
55
+ if isinstance(out, (list, tuple)):
56
+ # pick first tensor-like element
57
+ out = next(
58
+ (o for o in out if isinstance(o, torch.Tensor)),
59
+ None
60
+ )
61
+ if isinstance(out, torch.Tensor):
62
+ FEATURE_MAPS[name] = out.detach().cpu()
63
+ return hook
64
+
65
+ layer.register_forward_hook(make_hook(str(idx)))
66
 
67
  MODEL = model
68
  return MODEL
 
73
  def tensor_to_heatmap(fm, out_size):
74
  """
75
  Convert a feature map tensor (C,H,W) to a grayscale heatmap PIL image.
 
76
  - average over channels
77
  - normalize to 0..1
78
+ - resize to out_size (W,H)
79
  """
80
  if fm.ndim != 3:
81
  return None
82
 
83
  fm_np = fm.numpy().astype(np.float32) # (C,H,W)
84
+ heat = fm_np.mean(axis=0) # (H,W)
 
85
 
86
+ if not np.any(heat):
87
  heat = np.zeros_like(heat)
88
  else:
89
+ heat -= heat.min()
90
  maxv = heat.max()
91
  if maxv > 0:
92
+ heat /= maxv
93
 
94
+ img = (heat * 255).astype("uint8")
95
+ pil = Image.fromarray(img, mode="L")
96
  pil = pil.resize(out_size, Image.NEAREST)
97
  return pil
98
 
99
 
100
  def pick_feature_maps():
101
  """
102
+ Choose three feature maps: early, middle, late.
103
+ FEATURE_MAPS keys are stringified indices "0", "1", ...
104
+ Returns list[(name, fm_tensor(C,H,W))]
105
  """
106
  if not FEATURE_MAPS:
107
  return []
108
 
109
+ # sort by numeric layer index
110
  keys = sorted(FEATURE_MAPS.keys(), key=lambda x: int(x))
111
+ fms = []
112
+ for k in keys:
113
+ t = FEATURE_MAPS[k]
114
+ if isinstance(t, torch.Tensor) and t.ndim == 4:
115
+ fms.append((k, t[0])) # (name, (C,H,W))
116
+
117
+ if not fms:
118
+ return []
119
 
 
120
  idxs = [0, len(fms) // 2, len(fms) - 1]
121
+ idxs = sorted(set(idxs))
122
 
123
  chosen = []
124
  for i in idxs:
125
+ chosen.append(fms[i])
126
  return chosen
127
 
128
 
 
130
 
131
  def analyze_yolo(img, conf_thres, iou_thres, simple_mode):
132
  """
133
+ Run YOLOv8n on input image and produce:
134
+ - detection image with boxes
135
+ - early/mid/late feature map heatmaps
136
+ - explanation text (simple or technical)
 
 
137
  """
138
  if img is None:
139
  return (
140
+ None, # detection image
141
+ None, # early heatmap
142
+ None, # mid heatmap
143
+ None, # late heatmap
144
  "⚠️ Please upload an image first."
145
  )
146
 
147
  model = load_model()
148
 
149
+ # Clear old feature maps before forward
150
  FEATURE_MAPS.clear()
151
 
152
+ # Gradio gives PIL image (type="pil")
153
  pil = img
154
 
155
  # Configure thresholds
156
+ conf = float(conf_thres)
157
+ iou = float(iou_thres)
158
 
159
  with torch.no_grad():
160
+ results = model(
161
+ pil,
162
+ conf=conf,
163
+ iou=iou,
164
+ verbose=False
165
+ )
166
 
167
+ res = results[0]
 
 
168
 
169
+ # res.plot() returns numpy array (H,W,3), BGR by default, but visually OK
170
+ det_np = res.plot()
171
+ det_img = Image.fromarray(det_np)
172
+
173
+ # Now FEATURE_MAPS should be filled by hooks
174
+ chosen = pick_feature_maps()
175
  W, H = pil.size
176
+ heatmaps = [None, None, None]
177
 
178
+ for idx, item in enumerate(chosen):
179
  name, fm = item
180
  hm = tensor_to_heatmap(fm, (W, H))
181
  heatmaps[idx] = hm
182
 
183
+ # Build explanation
184
  if simple_mode:
185
  explanation = (
186
  "πŸ§’ **Simple explanation of what you see:**\n\n"
187
+ "**Step 0 β€” Input image**: This is your original picture.\n\n"
188
+ "**Step 1 β€” Early layer heatmap**:\n"
189
+ "YOLO looks for very small details like edges, corners, and simple textures.\n\n"
190
+ "**Step 2 β€” Middle layer heatmap**:\n"
191
+ "It starts to see groups of pixels as shapes or parts of objects (like wheels, faces, etc.).\n\n"
192
+ "**Step 3 β€” Late layer heatmap**:\n"
193
+ "It focuses on whole objects and regions where it thinks something important is.\n\n"
194
+ "**Step 4 β€” Final detections**:\n"
195
+ "YOLO draws boxes and labels around what it believes are objects in the image.\n"
196
  )
197
  else:
198
  explanation = (
199
+ "πŸ”¬ **Technical explanation of the visualization:**\n\n"
200
+ "- We use **YOLOv8n** (Ultralytics) running on CPU.\n"
201
+ "- Forward hooks capture intermediate feature maps from backbone/head blocks.\n"
202
+ "- For each selected layer, we take the tensor `(C,H,W)` and average over channels to\n"
203
+ " obtain a 2D activation map `(H,W)`, then normalize it and upsample it to `(W_img,H_img)`.\n"
204
+ "- Early feature map β‰ˆ low-level features (edges, corners, local textures).\n"
205
+ "- Middle feature map β‰ˆ mid-level features (object parts & shapes).\n"
206
+ "- Late feature map β‰ˆ high-level features (object-centric regions that drive detection head).\n"
207
+ "- The detection image is produced by YOLO's standard post-processing (objectness, class\n"
208
+ " scores, and Non-Maximum Suppression on bounding boxes).\n"
209
  )
210
 
211
+ # Add feature map shapes
212
+ if chosen:
213
+ explanation += "\n**Captured feature map shapes (C,H,W):**\n"
214
+ for name, fm in chosen:
215
+ explanation += f"- Layer {name}: {tuple(fm.shape)}\n"
 
216
 
217
  return det_img, heatmaps[0], heatmaps[1], heatmaps[2], explanation
218
 
219
 
220
+ # ------------------- GRADIO UI -------------------
221
 
222
+ with gr.Blocks(title="YOLOv8n Visualizer β€” Inside Object Detection") as demo:
 
 
 
223
 
224
+ gr.Markdown("# 🧠 YOLOv8n Visualizer β€” Inside Object Detection")
225
  gr.Markdown(
226
+ "See what happens **inside** an object detection model.\n\n"
227
+ "**Steps shown:**\n"
228
+ "- **Step 0** β€” Input image\n"
229
+ "- **Step 1** β€” Early layer activation (edges & textures)\n"
230
+ "- **Step 2** β€” Middle layer activation (parts & shapes)\n"
231
+ "- **Step 3** β€” Late layer activation (objects)\n"
232
+ "- **Step 4** β€” Final detections (boxes & labels)\n"
233
  )
234
 
235
  with gr.Row():
 
239
  type="pil"
240
  )
241
  conf_slider = gr.Slider(
242
+ minimum=0.1,
243
+ maximum=0.9,
244
+ step=0.05,
245
+ value=0.25,
246
  label="Confidence threshold"
247
  )
248
  iou_slider = gr.Slider(
249
+ minimum=0.1,
250
+ maximum=0.9,
251
+ step=0.05,
252
+ value=0.45,
253
+ label="IoU threshold (NMS)"
254
  )
255
  simple_ck = gr.Checkbox(
256
+ label="Explain in simple terms (for kids/elders)",
257
  value=True
258
  )
259
  run_btn = gr.Button("Run YOLO & Visualize", variant="primary")
260
 
261
  with gr.Column(scale=1):
262
+ out_det = gr.Image(
263
+ label="Step 4 β€” Final detections (YOLOv8n)",
264
+ interactive=False
265
+ )
266
  explanation_md = gr.Markdown(label="Explanation")
267
 
268
+ gr.Markdown("### πŸ” Steps 1–3: internal feature maps (what the network focuses on)")
269
 
270
  with gr.Row():
271
+ fm1 = gr.Image(
272
+ label="Step 1 β€” Early layer activation (edges & textures)",
273
+ interactive=False
274
+ )
275
+ fm2 = gr.Image(
276
+ label="Step 2 β€” Middle layer activation (parts & shapes)",
277
+ interactive=False
278
+ )
279
+ fm3 = gr.Image(
280
+ label="Step 3 β€” Late layer activation (objects)",
281
+ interactive=False
282
+ )
283
 
284
  run_btn.click(
285
  analyze_yolo,