Update app.py
Browse files
app.py
CHANGED
|
@@ -105,7 +105,7 @@ def load_model(model_name: str):
|
|
| 105 |
# ---------------------------
|
| 106 |
|
| 107 |
def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
|
| 108 |
-
"""Return
|
| 109 |
|
| 110 |
# Convert RGB (Gradio default) ➜ BGR (OpenCV default)
|
| 111 |
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
|
@@ -128,9 +128,6 @@ def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
|
|
| 128 |
finetune_has_detections = (finetune_results and len(finetune_results) > 0 and
|
| 129 |
hasattr(finetune_results[0], "boxes") and len(finetune_results[0].boxes) > 0)
|
| 130 |
|
| 131 |
-
if not vanilla_has_detections and not finetune_has_detections:
|
| 132 |
-
return img_rgb # nothing detected, return original
|
| 133 |
-
|
| 134 |
# Create heatmap visualizations for both models
|
| 135 |
vanilla_heatmaps = []
|
| 136 |
finetune_heatmaps = []
|
|
@@ -145,8 +142,15 @@ def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
|
|
| 145 |
finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
|
| 146 |
finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")
|
| 147 |
|
| 148 |
-
# Create
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
def create_heatmap_snippets(img_bgr, result, names, model_type):
|
|
@@ -181,38 +185,10 @@ def create_heatmap_snippets(img_bgr, result, names, model_type):
|
|
| 181 |
return snippets
|
| 182 |
|
| 183 |
|
| 184 |
-
def
|
| 185 |
-
"""Create
|
| 186 |
pad = 20
|
| 187 |
|
| 188 |
-
# Create model sections
|
| 189 |
-
vanilla_section = create_model_section(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
|
| 190 |
-
finetune_section = create_model_section(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))
|
| 191 |
-
|
| 192 |
-
# Calculate dimensions
|
| 193 |
-
max_height = max(vanilla_section.shape[0], finetune_section.shape[0]) if vanilla_section.size > 0 and finetune_section.size > 0 else max(vanilla_section.shape[0] if vanilla_section.size > 0 else 0, finetune_section.shape[0] if finetune_section.size > 0 else 0)
|
| 194 |
-
total_width = vanilla_section.shape[1] + finetune_section.shape[1] + pad if vanilla_section.size > 0 and finetune_section.size > 0 else max(vanilla_section.shape[1] if vanilla_section.size > 0 else 0, finetune_section.shape[1] if finetune_section.size > 0 else 0)
|
| 195 |
-
|
| 196 |
-
# Create final canvas
|
| 197 |
-
canvas_h = max_height + 2 * pad
|
| 198 |
-
canvas_w = total_width + 2 * pad
|
| 199 |
-
final = np.full((canvas_h, canvas_w, 3), 255, np.uint8)
|
| 200 |
-
|
| 201 |
-
# Place sections side by side
|
| 202 |
-
if vanilla_section.size > 0:
|
| 203 |
-
y_off = (canvas_h - vanilla_section.shape[0]) // 2
|
| 204 |
-
final[y_off : y_off + vanilla_section.shape[0], pad : pad + vanilla_section.shape[1]] = vanilla_section
|
| 205 |
-
|
| 206 |
-
if finetune_section.size > 0:
|
| 207 |
-
y_off = (canvas_h - finetune_section.shape[0]) // 2
|
| 208 |
-
x_start = pad + vanilla_section.shape[1] + pad if vanilla_section.size > 0 else pad
|
| 209 |
-
final[y_off : y_off + finetune_section.shape[0], x_start : x_start + finetune_section.shape[1]] = finetune_section
|
| 210 |
-
|
| 211 |
-
return cv2.cvtColor(final, cv2.COLOR_BGR2RGB)
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
def create_model_section(heatmaps, title, color):
|
| 215 |
-
"""Create a section for one model's heatmaps."""
|
| 216 |
if not heatmaps:
|
| 217 |
# Create empty section with title
|
| 218 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
@@ -231,19 +207,22 @@ def create_model_section(heatmaps, title, color):
|
|
| 231 |
section_h = max_h + th + 40
|
| 232 |
section_w = max(total_w, tw + 20)
|
| 233 |
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
# Add title
|
| 237 |
-
cv2.putText(
|
| 238 |
|
| 239 |
# Arrange heatmaps
|
| 240 |
-
cur_x =
|
| 241 |
for h in heatmaps:
|
| 242 |
-
y_off = th + 30 + (max_h - h.shape[0]) // 2
|
| 243 |
-
|
| 244 |
cur_x += h.shape[1] + 10
|
| 245 |
|
| 246 |
-
return
|
| 247 |
|
| 248 |
|
| 249 |
# ---------------------------
|
|
@@ -267,19 +246,20 @@ def build_demo():
|
|
| 267 |
label="Click to load example"
|
| 268 |
)
|
| 269 |
|
| 270 |
-
# Right side - Output
|
| 271 |
with gr.Column(scale=2):
|
| 272 |
-
|
|
|
|
| 273 |
|
| 274 |
# Connect inputs to the function
|
| 275 |
def update_heatmap(image, confidence):
|
| 276 |
if image is None:
|
| 277 |
-
return None
|
| 278 |
return generate_heatmap_layout(image, confidence)
|
| 279 |
|
| 280 |
# Set up the interface
|
| 281 |
-
image_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=
|
| 282 |
-
conf_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=
|
| 283 |
|
| 284 |
return demo
|
| 285 |
|
|
|
|
| 105 |
# ---------------------------
|
| 106 |
|
| 107 |
def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
|
| 108 |
+
"""Return separate XAI heatmap layouts for vanilla and fine-tuned models."""
|
| 109 |
|
| 110 |
# Convert RGB (Gradio default) ➜ BGR (OpenCV default)
|
| 111 |
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
|
|
|
| 128 |
finetune_has_detections = (finetune_results and len(finetune_results) > 0 and
|
| 129 |
hasattr(finetune_results[0], "boxes") and len(finetune_results[0].boxes) > 0)
|
| 130 |
|
|
|
|
|
|
|
|
|
|
| 131 |
# Create heatmap visualizations for both models
|
| 132 |
vanilla_heatmaps = []
|
| 133 |
finetune_heatmaps = []
|
|
|
|
| 142 |
finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
|
| 143 |
finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")
|
| 144 |
|
| 145 |
+
# Create separate layouts for each model
|
| 146 |
+
vanilla_layout = create_model_layout(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
|
| 147 |
+
finetune_layout = create_model_layout(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))
|
| 148 |
+
|
| 149 |
+
# Convert BGR to RGB for display
|
| 150 |
+
vanilla_output = cv2.cvtColor(vanilla_layout, cv2.COLOR_BGR2RGB) if vanilla_layout is not None else None
|
| 151 |
+
finetune_output = cv2.cvtColor(finetune_layout, cv2.COLOR_BGR2RGB) if finetune_layout is not None else None
|
| 152 |
+
|
| 153 |
+
return vanilla_output, finetune_output
|
| 154 |
|
| 155 |
|
| 156 |
def create_heatmap_snippets(img_bgr, result, names, model_type):
|
|
|
|
| 185 |
return snippets
|
| 186 |
|
| 187 |
|
| 188 |
+
def create_model_layout(heatmaps, title, color):
|
| 189 |
+
"""Create a layout for one model's heatmaps."""
|
| 190 |
pad = 20
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
if not heatmaps:
|
| 193 |
# Create empty section with title
|
| 194 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
|
|
| 207 |
section_h = max_h + th + 40
|
| 208 |
section_w = max(total_w, tw + 20)
|
| 209 |
|
| 210 |
+
# Create canvas with padding
|
| 211 |
+
canvas_h = section_h + 2 * pad
|
| 212 |
+
canvas_w = section_w + 2 * pad
|
| 213 |
+
canvas = np.full((canvas_h, canvas_w, 3), 255, np.uint8)
|
| 214 |
|
| 215 |
# Add title
|
| 216 |
+
cv2.putText(canvas, title, (pad + 10, pad + th + 20), title_font, 1.0, color, 2, cv2.LINE_AA)
|
| 217 |
|
| 218 |
# Arrange heatmaps
|
| 219 |
+
cur_x = pad
|
| 220 |
for h in heatmaps:
|
| 221 |
+
y_off = pad + th + 30 + (max_h - h.shape[0]) // 2
|
| 222 |
+
canvas[y_off : y_off + h.shape[0], cur_x : cur_x + h.shape[1]] = h
|
| 223 |
cur_x += h.shape[1] + 10
|
| 224 |
|
| 225 |
+
return canvas
|
| 226 |
|
| 227 |
|
| 228 |
# ---------------------------
|
|
|
|
| 246 |
label="Click to load example"
|
| 247 |
)
|
| 248 |
|
| 249 |
+
# Right side - Output visualizations (separated vertically)
|
| 250 |
with gr.Column(scale=2):
|
| 251 |
+
vanilla_output = gr.Image(type="numpy", label="Vanilla Model Heatmap")
|
| 252 |
+
finetune_output = gr.Image(type="numpy", label="Fine-tuned Model Heatmap")
|
| 253 |
|
| 254 |
# Connect inputs to the function
|
| 255 |
def update_heatmap(image, confidence):
|
| 256 |
if image is None:
|
| 257 |
+
return None, None
|
| 258 |
return generate_heatmap_layout(image, confidence)
|
| 259 |
|
| 260 |
# Set up the interface
|
| 261 |
+
image_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
|
| 262 |
+
conf_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
|
| 263 |
|
| 264 |
return demo
|
| 265 |
|