Arviano committed on
Commit
67f1c25
·
1 Parent(s): a990302

Add GradCAM target layer selection and UI defaults integration

Browse files
Files changed (5) hide show
  1. config/ui_defaults.json +18 -0
  2. requirements.txt +1 -1
  3. src/inference.py +30 -7
  4. src/ui.py +227 -48
  5. src/xdl.py +85 -13
config/ui_defaults.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "selected_case": "Multiclass (4 Classes)",
3
+ "confidence_threshold": 0.6,
4
+ "smoothgrad_samples": 50,
5
+ "smoothgrad_noise": 0.05,
6
+ "gradcam_target_layer": "denseblock3",
7
+ "save_xdl_results": false,
8
+ "save_xdl_dir": "xdl_results",
9
+ "_comment_xdl_target_layer": "UI dropdown reads this default. You can still override with XDL_TARGET_LAYER env var.",
10
+ "_supported_xdl_target_layer": {
11
+ "denseblock3": "Default. Usually less center-biased and more spatially varied.",
12
+ "transition2": "Good alternative; often broad and stable localization.",
13
+ "transition1": "Earlier layer; more detail but can be noisy.",
14
+ "denseblock4": "Late layer; stronger class semantics, can be center-heavy.",
15
+ "transition3": "Late transition; similar tradeoff to denseblock4.",
16
+ "norm5_last": "Original last layer behavior (legacy setting)."
17
+ }
18
+ }
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- # uv export --format requirements-txt
3
  aiofiles==24.1.0 \
4
  --hash=sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c \
5
  --hash=sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5
 
1
  # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --output-file requirements.txt
3
  aiofiles==24.1.0 \
4
  --hash=sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c \
5
  --hash=sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5
src/inference.py CHANGED
@@ -39,6 +39,15 @@ DEFAULT_CASE_NAME = "Multiclass (4 Classes)"
39
  CASE_OPTIONS = list(CASE_CONFIGS.keys())
40
  SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
41
  DEFAULT_SAVE_DIR = "xdl_results"
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  def _detect_device() -> torch.device:
@@ -201,7 +210,8 @@ def _aggregate_classification(classified: List[ClassifiedPrediction], labels: Li
201
 
202
 
203
  def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]:
204
- input_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
 
205
 
206
  with torch.no_grad():
207
  logits = model(input_tensor)[0]
@@ -368,6 +378,7 @@ def batch_predict_with_xdl(
368
  smoothgrad_noise: float,
369
  save_xdl_results: bool,
370
  save_xdl_dir: str,
 
371
  ):
372
  last_output = None
373
  for payload in batch_predict_with_xdl_stream(
@@ -379,6 +390,7 @@ def batch_predict_with_xdl(
379
  smoothgrad_noise=smoothgrad_noise,
380
  save_xdl_results=save_xdl_results,
381
  save_xdl_dir=save_xdl_dir,
 
382
  ):
383
  last_output = payload
384
 
@@ -396,6 +408,7 @@ def batch_predict_with_xdl_stream(
396
  smoothgrad_noise: float,
397
  save_xdl_results: bool,
398
  save_xdl_dir: str,
 
399
  ) -> Iterator[Tuple[str, List[List[str]], List[Tuple[np.ndarray, str]]]]:
400
  case_state = _get_case_state(selected_case)
401
  model: DenseNet121 = case_state["model"]
@@ -408,9 +421,14 @@ def batch_predict_with_xdl_stream(
408
  yield _render_error_html(model_error), [], []
409
  return
410
 
 
 
411
  threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
412
  smoothgrad_samples = int(max(1, smoothgrad_samples))
413
  smoothgrad_noise = float(max(0.0, smoothgrad_noise))
 
 
 
414
 
415
  image_paths, input_error = _resolve_input_images(uploaded_files, folder_path)
416
  if input_error:
@@ -445,7 +463,7 @@ def batch_predict_with_xdl_stream(
445
  final_class, mean_conf = _aggregate_classification(classified, labels)
446
  class_counter = Counter(item.pred_idx for item in classified)
447
  class_stats = ", ".join(f"{labels[idx]}: {count}" for idx, count in class_counter.items())
448
- initial_xdl_status = "Processing overlays..."
449
  else:
450
  final_class = "N/A"
451
  mean_conf = None
@@ -455,7 +473,7 @@ def batch_predict_with_xdl_stream(
455
  summary_initial = _render_summary_html(
456
  case_name=case_name,
457
  model_path=model_path,
458
- device_name=DEVICE.type,
459
  processed=len(image_paths),
460
  classified=len(classified),
461
  threshold=threshold,
@@ -483,13 +501,17 @@ def batch_predict_with_xdl_stream(
483
  xdl_error_count = 0
484
 
485
  if xdl is not None:
486
- target_layer = xdl["_get_target_layer"](model)
 
 
 
 
487
  cam = xdl["GradCAM"](model=model, target_layers=[target_layer])
488
 
489
  for item in classified:
490
  try:
491
  image = Image.open(item.path).convert("RGB")
492
- input_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
493
 
494
  base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
495
  h, w = base_img_uint8.shape[:2]
@@ -506,6 +528,7 @@ def batch_predict_with_xdl_stream(
506
  item.pred_idx,
507
  n_samples=smoothgrad_samples,
508
  noise_level=smoothgrad_noise,
 
509
  )
510
  _, smooth_heatmap = xdl["_process_smoothgrad_map"](
511
  smooth_raw,
@@ -532,7 +555,7 @@ def batch_predict_with_xdl_stream(
532
  elif xdl_error_count:
533
  xdl_status = f"Completed with {xdl_error_count} overlay errors"
534
  else:
535
- xdl_status = f"Completed ({len(gallery_items)} overlays)"
536
 
537
  if save_xdl_results:
538
  if save_error:
@@ -547,7 +570,7 @@ def batch_predict_with_xdl_stream(
547
  summary_final = _render_summary_html(
548
  case_name=case_name,
549
  model_path=model_path,
550
- device_name=DEVICE.type,
551
  processed=len(image_paths),
552
  classified=len(classified),
553
  threshold=threshold,
 
39
  CASE_OPTIONS = list(CASE_CONFIGS.keys())
40
  SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
41
  DEFAULT_SAVE_DIR = "xdl_results"
42
+ GRADCAM_TARGET_LAYER_OPTIONS = (
43
+ "denseblock3",
44
+ "transition2",
45
+ "transition1",
46
+ "denseblock4",
47
+ "transition3",
48
+ "norm5_last",
49
+ )
50
+ DEFAULT_GRADCAM_TARGET_LAYER = "denseblock3"
51
 
52
 
53
  def _detect_device() -> torch.device:
 
210
 
211
 
212
  def _predict_top1(model: DenseNet121, image: Image.Image) -> Tuple[int, float, torch.Tensor]:
213
+ model_device = next(model.parameters()).device
214
+ input_tensor = val_transform(image).unsqueeze(0).to(model_device)
215
 
216
  with torch.no_grad():
217
  logits = model(input_tensor)[0]
 
378
  smoothgrad_noise: float,
379
  save_xdl_results: bool,
380
  save_xdl_dir: str,
381
+ gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
382
  ):
383
  last_output = None
384
  for payload in batch_predict_with_xdl_stream(
 
390
  smoothgrad_noise=smoothgrad_noise,
391
  save_xdl_results=save_xdl_results,
392
  save_xdl_dir=save_xdl_dir,
393
+ gradcam_target_layer=gradcam_target_layer,
394
  ):
395
  last_output = payload
396
 
 
408
  smoothgrad_noise: float,
409
  save_xdl_results: bool,
410
  save_xdl_dir: str,
411
+ gradcam_target_layer: str = DEFAULT_GRADCAM_TARGET_LAYER,
412
  ) -> Iterator[Tuple[str, List[List[str]], List[Tuple[np.ndarray, str]]]]:
413
  case_state = _get_case_state(selected_case)
414
  model: DenseNet121 = case_state["model"]
 
421
  yield _render_error_html(model_error), [], []
422
  return
423
 
424
+ model_device = next(model.parameters()).device
425
+
426
  threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
427
  smoothgrad_samples = int(max(1, smoothgrad_samples))
428
  smoothgrad_noise = float(max(0.0, smoothgrad_noise))
429
+ gradcam_target_layer = str(gradcam_target_layer or DEFAULT_GRADCAM_TARGET_LAYER).strip().lower()
430
+ if gradcam_target_layer not in GRADCAM_TARGET_LAYER_OPTIONS:
431
+ gradcam_target_layer = DEFAULT_GRADCAM_TARGET_LAYER
432
 
433
  image_paths, input_error = _resolve_input_images(uploaded_files, folder_path)
434
  if input_error:
 
463
  final_class, mean_conf = _aggregate_classification(classified, labels)
464
  class_counter = Counter(item.pred_idx for item in classified)
465
  class_stats = ", ".join(f"{labels[idx]}: {count}" for idx, count in class_counter.items())
466
+ initial_xdl_status = f"Processing overlays... (GradCAM layer: {gradcam_target_layer})"
467
  else:
468
  final_class = "N/A"
469
  mean_conf = None
 
473
  summary_initial = _render_summary_html(
474
  case_name=case_name,
475
  model_path=model_path,
476
+ device_name=model_device.type,
477
  processed=len(image_paths),
478
  classified=len(classified),
479
  threshold=threshold,
 
501
  xdl_error_count = 0
502
 
503
  if xdl is not None:
504
+ try:
505
+ target_layer = xdl["_get_target_layer"](model, layer_name=gradcam_target_layer)
506
+ except TypeError:
507
+ # Backward compatibility for older helper signature: _get_target_layer(model)
508
+ target_layer = xdl["_get_target_layer"](model)
509
  cam = xdl["GradCAM"](model=model, target_layers=[target_layer])
510
 
511
  for item in classified:
512
  try:
513
  image = Image.open(item.path).convert("RGB")
514
+ input_tensor = val_transform(image).unsqueeze(0).to(model_device)
515
 
516
  base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
517
  h, w = base_img_uint8.shape[:2]
 
528
  item.pred_idx,
529
  n_samples=smoothgrad_samples,
530
  noise_level=smoothgrad_noise,
531
+ use_amp=(model_device.type == "cuda"),
532
  )
533
  _, smooth_heatmap = xdl["_process_smoothgrad_map"](
534
  smooth_raw,
 
555
  elif xdl_error_count:
556
  xdl_status = f"Completed with {xdl_error_count} overlay errors"
557
  else:
558
+ xdl_status = f"Completed ({len(gallery_items)} overlays, layer: {gradcam_target_layer})"
559
 
560
  if save_xdl_results:
561
  if save_error:
 
570
  summary_final = _render_summary_html(
571
  case_name=case_name,
572
  model_path=model_path,
573
+ device_name=model_device.type,
574
  processed=len(image_paths),
575
  classified=len(classified),
576
  threshold=threshold,
src/ui.py CHANGED
@@ -1,59 +1,237 @@
 
 
 
 
1
  import gradio as gr
2
 
3
- from src.inference import CASE_OPTIONS, DEFAULT_CASE_NAME, DEVICE, batch_predict_with_xdl_stream
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def build_demo() -> gr.Blocks:
7
- with gr.Blocks(title="Medical Batch Classification + XDL") as demo:
8
- gr.Markdown("## Medical Batch Classification + XDL")
9
- gr.Markdown(
10
- f"Detected device: **{DEVICE.type}**. "
11
- "Upload a folder (preferred) or provide a local folder path."
12
- )
13
 
14
- selected_case = gr.Dropdown(
15
- choices=CASE_OPTIONS,
16
- value=DEFAULT_CASE_NAME,
17
- label="Problem Case",
 
 
 
 
 
 
 
18
  )
19
 
20
- with gr.Row():
21
- upload_input = gr.File(
22
- file_count="directory",
23
- file_types=["image"],
24
- type="filepath",
25
- label="Upload Image Folder",
26
- )
27
- folder_path = gr.Textbox(
28
- label="Local Folder Path (Optional)",
29
- placeholder="/absolute/path/to/folder/with/images",
30
- )
31
-
32
- with gr.Row():
33
- threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.01, label="Confidence Threshold")
34
- smoothgrad_samples = gr.Slider(10, 200, value=50, step=10, label="SmoothGrad Samples")
35
- smoothgrad_noise = gr.Slider(0.01, 0.2, value=0.05, step=0.01, label="SmoothGrad Noise Level")
36
-
37
- with gr.Row():
38
- save_xdl_results = gr.Checkbox(label="Save XDL Results Locally", value=False)
39
- save_xdl_dir = gr.Textbox(
40
- label="Save Folder",
41
- value="xdl_results",
42
- placeholder="xdl_results",
43
- )
44
-
45
- run_btn = gr.Button("Run Batch Inference")
46
-
47
- summary_out = gr.HTML(label="Summary")
48
- table_out = gr.Dataframe(
49
- headers=["filename", "status", "predicted_label", "confidence_or_error"],
50
- datatype=["str", "str", "str", "str"],
51
- interactive=False,
52
- label="Per-image Results",
53
- )
54
- gallery_out = gr.Gallery(
55
- label="Compact XDL Results (Original | GradCAM | SmoothGrad)",
56
- columns=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
 
59
  run_btn.click(
@@ -67,6 +245,7 @@ def build_demo() -> gr.Blocks:
67
  smoothgrad_noise,
68
  save_xdl_results,
69
  save_xdl_dir,
 
70
  ],
71
  outputs=[summary_out, table_out, gallery_out],
72
  )
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
  import gradio as gr
6
 
7
+ from src.inference import (
8
+ CASE_OPTIONS,
9
+ DEFAULT_CASE_NAME,
10
+ DEFAULT_GRADCAM_TARGET_LAYER,
11
+ DEVICE,
12
+ GRADCAM_TARGET_LAYER_OPTIONS,
13
+ batch_predict_with_xdl_stream,
14
+ )
15
+
16
+ UI_DEFAULTS_PATH = Path(__file__).resolve().parent.parent / "config" / "ui_defaults.json"
17
+ UI_DEFAULTS_FALLBACK = {
18
+ "selected_case": DEFAULT_CASE_NAME,
19
+ "confidence_threshold": 0.60,
20
+ "smoothgrad_samples": 50,
21
+ "smoothgrad_noise": 0.05,
22
+ "gradcam_target_layer": DEFAULT_GRADCAM_TARGET_LAYER,
23
+ "save_xdl_results": False,
24
+ "save_xdl_dir": "xdl_results",
25
+ }
26
+
27
+ GRADCAM_TARGET_LAYER_DROPDOWN_CHOICES = [
28
+ ("DenseBlock 3 (Default, balanced)", "denseblock3"),
29
+ ("Transition 2 (Broad, stable)", "transition2"),
30
+ ("Transition 1 (Earlier, detailed/noisier)", "transition1"),
31
+ ("DenseBlock 4 (Late, center-heavy)", "denseblock4"),
32
+ ("Transition 3 (Late, center-heavy)", "transition3"),
33
+ ("Norm5 Last (Legacy behavior)", "norm5_last"),
34
+ ]
35
+
36
+ CUSTOM_CSS = """
37
+ .app-shell {
38
+ max-width: 1120px;
39
+ margin: 0 auto;
40
+ }
41
+ .hero {
42
+ border: 1px solid #d1d5db;
43
+ background: linear-gradient(135deg, #f0fdfa 0%, #ecfeff 45%, #f8fafc 100%);
44
+ border-radius: 14px;
45
+ padding: 16px 18px;
46
+ margin-bottom: 12px;
47
+ }
48
+ .hero h1 {
49
+ margin: 0;
50
+ font-size: 24px;
51
+ color: #0f172a;
52
+ }
53
+ .hero p {
54
+ margin: 6px 0 0 0;
55
+ color: #334155;
56
+ font-size: 14px;
57
+ }
58
+ .panel {
59
+ border: 1px solid #e2e8f0;
60
+ border-radius: 12px;
61
+ background: #ffffff;
62
+ padding: 12px;
63
+ }
64
+ """
65
+
66
+
67
+ def _as_float(value: Any, fallback: float) -> float:
68
+ try:
69
+ return float(value)
70
+ except (TypeError, ValueError):
71
+ return float(fallback)
72
+
73
+
74
+ def _as_int(value: Any, fallback: int) -> int:
75
+ try:
76
+ return int(value)
77
+ except (TypeError, ValueError):
78
+ return int(fallback)
79
+
80
+
81
+ def _as_bool(value: Any, fallback: bool) -> bool:
82
+ if isinstance(value, bool):
83
+ return value
84
+ if isinstance(value, str):
85
+ return value.strip().lower() in {"1", "true", "yes", "y", "on"}
86
+ if value is None:
87
+ return fallback
88
+ return bool(value)
89
+
90
+
91
+ def _load_ui_defaults() -> dict[str, Any]:
92
+ defaults = dict(UI_DEFAULTS_FALLBACK)
93
+
94
+ try:
95
+ raw_text = UI_DEFAULTS_PATH.read_text(encoding="utf-8")
96
+ raw = json.loads(raw_text)
97
+ if isinstance(raw, dict):
98
+ for key in defaults:
99
+ if key in raw:
100
+ defaults[key] = raw[key]
101
+ except Exception:
102
+ pass
103
+
104
+ selected_case = str(defaults.get("selected_case", DEFAULT_CASE_NAME))
105
+ defaults["selected_case"] = selected_case if selected_case in CASE_OPTIONS else DEFAULT_CASE_NAME
106
+ defaults["confidence_threshold"] = min(
107
+ 1.0,
108
+ max(0.0, _as_float(defaults.get("confidence_threshold"), UI_DEFAULTS_FALLBACK["confidence_threshold"])),
109
+ )
110
+ defaults["smoothgrad_samples"] = max(
111
+ 1,
112
+ _as_int(defaults.get("smoothgrad_samples"), UI_DEFAULTS_FALLBACK["smoothgrad_samples"]),
113
+ )
114
+ defaults["smoothgrad_noise"] = min(
115
+ 1.0,
116
+ max(0.0, _as_float(defaults.get("smoothgrad_noise"), UI_DEFAULTS_FALLBACK["smoothgrad_noise"])),
117
+ )
118
+ gradcam_target_layer = str(defaults.get("gradcam_target_layer", DEFAULT_GRADCAM_TARGET_LAYER)).strip().lower()
119
+ defaults["gradcam_target_layer"] = (
120
+ gradcam_target_layer if gradcam_target_layer in GRADCAM_TARGET_LAYER_OPTIONS else DEFAULT_GRADCAM_TARGET_LAYER
121
+ )
122
+ defaults["save_xdl_results"] = _as_bool(defaults.get("save_xdl_results"), UI_DEFAULTS_FALLBACK["save_xdl_results"])
123
+ defaults["save_xdl_dir"] = str(defaults.get("save_xdl_dir") or UI_DEFAULTS_FALLBACK["save_xdl_dir"])
124
+ return defaults
125
+
126
+
127
+ def _toggle_save_dir(enabled: bool):
128
+ is_enabled = bool(enabled)
129
+ return gr.update(visible=is_enabled, interactive=is_enabled)
130
 
131
 
132
  def build_demo() -> gr.Blocks:
133
+ ui_defaults = _load_ui_defaults()
 
 
 
 
 
134
 
135
+ with gr.Blocks(title="XDL Colitis Demo") as demo:
136
+ gr.HTML(
137
+ f"""
138
+ <style>{CUSTOM_CSS}</style>
139
+ <div class="app-shell">
140
+ <div class="hero">
141
+ <h1>XDL Colitis Workbench</h1>
142
+ <p>Detected device: <b>{DEVICE.type}</b>. Upload a directory or enter a local folder path, then run batch inference.</p>
143
+ </div>
144
+ </div>
145
+ """
146
  )
147
 
148
+ with gr.Row(elem_classes=["app-shell"]):
149
+ with gr.Column(scale=2, elem_classes=["panel"]):
150
+ gr.Markdown("### 1) Image Input")
151
+ selected_case = gr.Dropdown(
152
+ choices=CASE_OPTIONS,
153
+ value=ui_defaults["selected_case"],
154
+ label="Problem Case",
155
+ info="Choose the model group that matches your diagnosis scenario.",
156
+ )
157
+ upload_input = gr.File(
158
+ file_count="directory",
159
+ file_types=["image"],
160
+ type="filepath",
161
+ label="Upload Image Folder",
162
+ )
163
+ folder_path = gr.Textbox(
164
+ label="Local Folder Path (Optional)",
165
+ placeholder="/absolute/path/to/folder/with/images",
166
+ )
167
+
168
+ with gr.Column(scale=1, elem_classes=["panel"]):
169
+ gr.Markdown("### 2) Inference Settings")
170
+ threshold = gr.Number(
171
+ value=ui_defaults["confidence_threshold"],
172
+ minimum=0.0,
173
+ maximum=1.0,
174
+ step=0.01,
175
+ precision=2,
176
+ label="Confidence Threshold",
177
+ info="Range: 0.00 to 1.00",
178
+ )
179
+ smoothgrad_samples = gr.Number(
180
+ value=ui_defaults["smoothgrad_samples"],
181
+ minimum=1,
182
+ maximum=1000,
183
+ step=1,
184
+ precision=0,
185
+ label="SmoothGrad Samples",
186
+ info="Higher values improve stability but increase runtime.",
187
+ )
188
+ smoothgrad_noise = gr.Number(
189
+ value=ui_defaults["smoothgrad_noise"],
190
+ minimum=0.0,
191
+ maximum=1.0,
192
+ step=0.01,
193
+ precision=2,
194
+ label="SmoothGrad Noise Level",
195
+ info="Typical range: 0.01 to 0.20",
196
+ )
197
+ gradcam_target_layer = gr.Dropdown(
198
+ choices=GRADCAM_TARGET_LAYER_DROPDOWN_CHOICES,
199
+ value=ui_defaults["gradcam_target_layer"],
200
+ label="GradCAM Target Layer",
201
+ info="Try `transition2` or `denseblock3` if CAM looks too centered.",
202
+ )
203
+ save_xdl_results = gr.Checkbox(
204
+ label="Save XDL Results Locally",
205
+ value=ui_defaults["save_xdl_results"],
206
+ )
207
+ save_xdl_dir = gr.Textbox(
208
+ label="Save Folder",
209
+ value=ui_defaults["save_xdl_dir"],
210
+ placeholder="xdl_results",
211
+ visible=bool(ui_defaults["save_xdl_results"]),
212
+ interactive=bool(ui_defaults["save_xdl_results"]),
213
+ )
214
+ run_btn = gr.Button("Run Batch Inference", variant="primary")
215
+
216
+ with gr.Row(elem_classes=["app-shell"]):
217
+ with gr.Column(elem_classes=["panel"]):
218
+ gr.Markdown("### 3) Results")
219
+ summary_out = gr.HTML(label="Summary")
220
+ table_out = gr.Dataframe(
221
+ headers=["filename", "status", "predicted_label", "confidence_or_error"],
222
+ datatype=["str", "str", "str", "str"],
223
+ interactive=False,
224
+ label="Per-image Results",
225
+ )
226
+ gallery_out = gr.Gallery(
227
+ label="Compact XDL Results (Original | GradCAM | SmoothGrad)",
228
+ columns=2,
229
+ )
230
+
231
+ save_xdl_results.change(
232
+ fn=_toggle_save_dir,
233
+ inputs=[save_xdl_results],
234
+ outputs=[save_xdl_dir],
235
  )
236
 
237
  run_btn.click(
 
245
  smoothgrad_noise,
246
  save_xdl_results,
247
  save_xdl_dir,
248
+ gradcam_target_layer,
249
  ],
250
  outputs=[summary_out, table_out, gallery_out],
251
  )
src/xdl.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Tuple
 
2
 
3
  import cv2
4
  import numpy as np
@@ -47,13 +48,50 @@ def _process_smoothgrad_map(
47
  return smoothgrad_map, heatmap
48
 
49
 
50
- def _get_target_layer(model: nn.Module):
51
- """Return DenseNet feature layer used for GradCAM."""
 
 
 
 
 
 
 
 
 
 
 
 
52
  if not isinstance(model, DenseNet121):
53
  raise TypeError(
54
  f"Unsupported model type for this demo: {type(model).__name__}. Expected DenseNet121."
55
  )
56
- return model.densenet_model.features[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def smoothgrad(
@@ -62,27 +100,61 @@ def smoothgrad(
62
  target_class: int,
63
  n_samples: int = 100,
64
  noise_level: float = 0.05,
 
 
65
  ) -> np.ndarray:
66
- """Compute SmoothGrad saliency map for one input tensor of shape (1, C, H, W)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  model.eval()
68
  accumulated_gradients = torch.zeros_like(input_tensor)
69
 
70
  input_range = torch.max(input_tensor) - torch.min(input_tensor)
71
  scaled_noise = noise_level * input_range
72
 
73
- for _ in range(n_samples):
74
- noise = torch.randn_like(input_tensor) * scaled_noise
75
- noisy_input = (input_tensor + noise).clone().detach().requires_grad_(True)
76
-
77
- output = model(noisy_input)
 
 
 
 
 
 
 
78
  if output.ndim == 1:
79
  output = output.unsqueeze(0)
80
 
81
- model.zero_grad()
82
- output[0, target_class].backward()
83
 
84
  if noisy_input.grad is not None:
85
- accumulated_gradients += noisy_input.grad.data
86
 
87
  smooth = accumulated_gradients / n_samples
88
  smooth = torch.abs(smooth)
 
1
+ import os
2
+ from typing import Optional, Tuple
3
 
4
  import cv2
5
  import numpy as np
 
48
  return smoothgrad_map, heatmap
49
 
50
 
51
+ def _get_target_layer(model: nn.Module, layer_name: Optional[str] = None):
52
+ """
53
+ Return DenseNet feature layer used for GradCAM.
54
+
55
+ Default layer is `denseblock3` to reduce center-biased CAMs compared with
56
+ the original `norm5_last`.
57
+
58
+ You can override with:
59
+ - function arg: `layer_name`
60
+ - env var: `XDL_TARGET_LAYER`
61
+
62
+ Supported layer names:
63
+ `denseblock3`, `transition2`, `transition1`, `denseblock4`, `transition3`, `norm5_last`.
64
+ """
65
  if not isinstance(model, DenseNet121):
66
  raise TypeError(
67
  f"Unsupported model type for this demo: {type(model).__name__}. Expected DenseNet121."
68
  )
69
+
70
+ requested = (layer_name or os.getenv("XDL_TARGET_LAYER") or "denseblock3").strip().lower()
71
+ aliases = {
72
+ "default": "denseblock3",
73
+ "last": "norm5_last",
74
+ "norm5": "norm5_last",
75
+ }
76
+ selected = aliases.get(requested, requested)
77
+
78
+ features = model.densenet_model.features
79
+ layer_map = {
80
+ "denseblock3": features.denseblock3,
81
+ "transition2": features.transition2,
82
+ "transition1": features.transition1,
83
+ "denseblock4": features.denseblock4,
84
+ "transition3": features.transition3,
85
+ "norm5_last": features[-1],
86
+ }
87
+
88
+ if selected not in layer_map:
89
+ supported = ", ".join(layer_map.keys())
90
+ raise ValueError(
91
+ f"Unsupported XDL target layer '{selected}'. Supported layers: {supported}"
92
+ )
93
+
94
+ return layer_map[selected]
95
 
96
 
97
  def smoothgrad(
 
100
  target_class: int,
101
  n_samples: int = 100,
102
  noise_level: float = 0.05,
103
+ batch_size: Optional[int] = None,
104
+ use_amp: bool = False,
105
  ) -> np.ndarray:
106
+ """
107
+ Compute SmoothGrad saliency map for one input tensor of shape (1, C, H, W).
108
+
109
+ Notes:
110
+ - This implementation batches noisy samples to reduce per-step overhead.
111
+ - If `input_tensor` and `model` are on CUDA, computation runs on GPU.
112
+ """
113
+ if n_samples <= 0:
114
+ raise ValueError(f"n_samples must be > 0, got {n_samples}")
115
+
116
+ if input_tensor.ndim != 4 or input_tensor.shape[0] != 1:
117
+ raise ValueError(
118
+ f"input_tensor must have shape (1, C, H, W), got {tuple(input_tensor.shape)}"
119
+ )
120
+
121
+ if batch_size is None:
122
+ batch_size = min(16, n_samples)
123
+ if batch_size <= 0:
124
+ raise ValueError(f"batch_size must be > 0, got {batch_size}")
125
+
126
+ model_param = next(model.parameters(), None)
127
+ if model_param is not None and model_param.device != input_tensor.device:
128
+ raise ValueError(
129
+ f"Model device ({model_param.device}) and input device ({input_tensor.device}) must match."
130
+ )
131
+
132
  model.eval()
133
  accumulated_gradients = torch.zeros_like(input_tensor)
134
 
135
  input_range = torch.max(input_tensor) - torch.min(input_tensor)
136
  scaled_noise = noise_level * input_range
137
 
138
+ use_cuda_amp = use_amp and input_tensor.device.type == "cuda"
139
+ for start in range(0, n_samples, batch_size):
140
+ current_batch = min(batch_size, n_samples - start)
141
+ expanded_input = input_tensor.expand(current_batch, -1, -1, -1)
142
+ noise = torch.randn_like(expanded_input) * scaled_noise
143
+ noisy_input = (expanded_input + noise).clone().detach().requires_grad_(True)
144
+
145
+ if use_cuda_amp:
146
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
147
+ output = model(noisy_input)
148
+ else:
149
+ output = model(noisy_input)
150
  if output.ndim == 1:
151
  output = output.unsqueeze(0)
152
 
153
+ model.zero_grad(set_to_none=True)
154
+ output[:, target_class].sum().backward()
155
 
156
  if noisy_input.grad is not None:
157
+ accumulated_gradients += noisy_input.grad.data.sum(dim=0, keepdim=True)
158
 
159
  smooth = accumulated_gradients / n_samples
160
  smooth = torch.abs(smooth)