HugoHE commited on
Commit
cb187fe
·
verified ·
1 Parent(s): 5b4cb04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -47
app.py CHANGED
@@ -105,7 +105,7 @@ def load_model(model_name: str):
105
  # ---------------------------
106
 
107
  def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
108
- """Return a composite XAI heatmap layout comparing vanilla and fine-tuned models."""
109
 
110
  # Convert RGB (Gradio default) ➜ BGR (OpenCV default)
111
  img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
@@ -128,9 +128,6 @@ def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
128
  finetune_has_detections = (finetune_results and len(finetune_results) > 0 and
129
  hasattr(finetune_results[0], "boxes") and len(finetune_results[0].boxes) > 0)
130
 
131
- if not vanilla_has_detections and not finetune_has_detections:
132
- return img_rgb # nothing detected, return original
133
-
134
  # Create heatmap visualizations for both models
135
  vanilla_heatmaps = []
136
  finetune_heatmaps = []
@@ -145,8 +142,15 @@ def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
145
  finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
146
  finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")
147
 
148
- # Create comparison layout
149
- return create_comparison_layout(img_bgr, vanilla_heatmaps, finetune_heatmaps)
 
 
 
 
 
 
 
150
 
151
 
152
  def create_heatmap_snippets(img_bgr, result, names, model_type):
@@ -181,38 +185,10 @@ def create_heatmap_snippets(img_bgr, result, names, model_type):
181
  return snippets
182
 
183
 
184
- def create_comparison_layout(img_bgr, vanilla_heatmaps, finetune_heatmaps):
185
- """Create side-by-side comparison layout for both models."""
186
  pad = 20
187
 
188
- # Create model sections
189
- vanilla_section = create_model_section(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
190
- finetune_section = create_model_section(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))
191
-
192
- # Calculate dimensions
193
- max_height = max(vanilla_section.shape[0], finetune_section.shape[0]) if vanilla_section.size > 0 and finetune_section.size > 0 else max(vanilla_section.shape[0] if vanilla_section.size > 0 else 0, finetune_section.shape[0] if finetune_section.size > 0 else 0)
194
- total_width = vanilla_section.shape[1] + finetune_section.shape[1] + pad if vanilla_section.size > 0 and finetune_section.size > 0 else max(vanilla_section.shape[1] if vanilla_section.size > 0 else 0, finetune_section.shape[1] if finetune_section.size > 0 else 0)
195
-
196
- # Create final canvas
197
- canvas_h = max_height + 2 * pad
198
- canvas_w = total_width + 2 * pad
199
- final = np.full((canvas_h, canvas_w, 3), 255, np.uint8)
200
-
201
- # Place sections side by side
202
- if vanilla_section.size > 0:
203
- y_off = (canvas_h - vanilla_section.shape[0]) // 2
204
- final[y_off : y_off + vanilla_section.shape[0], pad : pad + vanilla_section.shape[1]] = vanilla_section
205
-
206
- if finetune_section.size > 0:
207
- y_off = (canvas_h - finetune_section.shape[0]) // 2
208
- x_start = pad + vanilla_section.shape[1] + pad if vanilla_section.size > 0 else pad
209
- final[y_off : y_off + finetune_section.shape[0], x_start : x_start + finetune_section.shape[1]] = finetune_section
210
-
211
- return cv2.cvtColor(final, cv2.COLOR_BGR2RGB)
212
-
213
-
214
- def create_model_section(heatmaps, title, color):
215
- """Create a section for one model's heatmaps."""
216
  if not heatmaps:
217
  # Create empty section with title
218
  font = cv2.FONT_HERSHEY_SIMPLEX
@@ -231,19 +207,22 @@ def create_model_section(heatmaps, title, color):
231
  section_h = max_h + th + 40
232
  section_w = max(total_w, tw + 20)
233
 
234
- section = np.full((section_h, section_w, 3), 255, np.uint8)
 
 
 
235
 
236
  # Add title
237
- cv2.putText(section, title, (10, th + 20), title_font, 1.0, color, 2, cv2.LINE_AA)
238
 
239
  # Arrange heatmaps
240
- cur_x = 0
241
  for h in heatmaps:
242
- y_off = th + 30 + (max_h - h.shape[0]) // 2
243
- section[y_off : y_off + h.shape[0], cur_x : cur_x + h.shape[1]] = h
244
  cur_x += h.shape[1] + 10
245
 
246
- return section
247
 
248
 
249
  # ---------------------------
@@ -267,19 +246,20 @@ def build_demo():
267
  label="Click to load example"
268
  )
269
 
270
- # Right side - Output visualization
271
  with gr.Column(scale=2):
272
- outputs = gr.Image(type="numpy", label="XAI Heatmap Comparison")
 
273
 
274
  # Connect inputs to the function
275
  def update_heatmap(image, confidence):
276
  if image is None:
277
- return None
278
  return generate_heatmap_layout(image, confidence)
279
 
280
  # Set up the interface
281
- image_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=outputs)
282
- conf_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=outputs)
283
 
284
  return demo
285
 
 
105
  # ---------------------------
106
 
107
  def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
108
+ """Return separate XAI heatmap layouts for vanilla and fine-tuned models."""
109
 
110
  # Convert RGB (Gradio default) ➜ BGR (OpenCV default)
111
  img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
 
128
  finetune_has_detections = (finetune_results and len(finetune_results) > 0 and
129
  hasattr(finetune_results[0], "boxes") and len(finetune_results[0].boxes) > 0)
130
 
 
 
 
131
  # Create heatmap visualizations for both models
132
  vanilla_heatmaps = []
133
  finetune_heatmaps = []
 
142
  finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
143
  finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")
144
 
145
+ # Create separate layouts for each model
146
+ vanilla_layout = create_model_layout(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
147
+ finetune_layout = create_model_layout(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))
148
+
149
+ # Convert BGR to RGB for display
150
+ vanilla_output = cv2.cvtColor(vanilla_layout, cv2.COLOR_BGR2RGB) if vanilla_layout is not None else None
151
+ finetune_output = cv2.cvtColor(finetune_layout, cv2.COLOR_BGR2RGB) if finetune_layout is not None else None
152
+
153
+ return vanilla_output, finetune_output
154
 
155
 
156
  def create_heatmap_snippets(img_bgr, result, names, model_type):
 
185
  return snippets
186
 
187
 
188
+ def create_model_layout(heatmaps, title, color):
189
+ """Create a layout for one model's heatmaps."""
190
  pad = 20
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  if not heatmaps:
193
  # Create empty section with title
194
  font = cv2.FONT_HERSHEY_SIMPLEX
 
207
  section_h = max_h + th + 40
208
  section_w = max(total_w, tw + 20)
209
 
210
+ # Create canvas with padding
211
+ canvas_h = section_h + 2 * pad
212
+ canvas_w = section_w + 2 * pad
213
+ canvas = np.full((canvas_h, canvas_w, 3), 255, np.uint8)
214
 
215
  # Add title
216
+ cv2.putText(canvas, title, (pad + 10, pad + th + 20), title_font, 1.0, color, 2, cv2.LINE_AA)
217
 
218
  # Arrange heatmaps
219
+ cur_x = pad
220
  for h in heatmaps:
221
+ y_off = pad + th + 30 + (max_h - h.shape[0]) // 2
222
+ canvas[y_off : y_off + h.shape[0], cur_x : cur_x + h.shape[1]] = h
223
  cur_x += h.shape[1] + 10
224
 
225
+ return canvas
226
 
227
 
228
  # ---------------------------
 
246
  label="Click to load example"
247
  )
248
 
249
+ # Right side - Output visualizations (separated vertically)
250
  with gr.Column(scale=2):
251
+ vanilla_output = gr.Image(type="numpy", label="Vanilla Model Heatmap")
252
+ finetune_output = gr.Image(type="numpy", label="Fine-tuned Model Heatmap")
253
 
254
  # Connect inputs to the function
255
  def update_heatmap(image, confidence):
256
  if image is None:
257
+ return None, None
258
  return generate_heatmap_layout(image, confidence)
259
 
260
  # Set up the interface
261
+ image_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
262
+ conf_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
263
 
264
  return demo
265