Arviano committed on
Commit
b23df0d
·
1 Parent(s): 216b040

Split app into launcher, UI, and inference modules

Browse files
Files changed (3) hide show
  1. app.py +2 -64
  2. src/inference.py +258 -0
  3. src/ui.py +44 -0
app.py CHANGED
@@ -1,69 +1,7 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn.functional as F
4
- from torchvision import transforms
5
- from PIL import Image
6
 
7
- # ---------------------------------------------------------
8
- # 1. CONFIGURATION (Edit these!)
9
- # ---------------------------------------------------------
10
- MODEL_PATH = "models/multi_smoothing1.2_reducelr_epoch116.pth" # Your model file name
11
- # IMPORTANT: These must match the order of folders/classes used during training!
12
- LABELS = ["uc", "infeksi", "crohn", "tb"] # Example for your Colitis demo
13
 
14
- # ---------------------------------------------------------
15
- # 2. LOAD MODEL
16
- # ---------------------------------------------------------
17
- # We load to CPU since Hugging Face basic tier is CPU-only
18
- from src.densenet import DenseNet121
19
- model = DenseNet121(num_classes=len(LABELS))
20
-
21
- try:
22
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
23
- model.eval() # Set to evaluation mode
24
- except Exception as e:
25
- print(f"Error loading model: {e}")
26
- print("Make sure your model file is uploaded and matches MODEL_PATH.")
27
- # Create a dummy model just so the app doesn't crash immediately for testing
28
-
29
- # ---------------------------------------------------------
30
- # 3. PREPROCESSING
31
- # ---------------------------------------------------------
32
- # Standard preprocessing. Adjust "Resize" if your model expects 299x299 (Inception) or others.
33
- from src.preprocessing import preprocess
34
-
35
- target_input_size = tuple([3, 299, 299])
36
-
37
- val_transform = preprocess(target_input_size=target_input_size)
38
-
39
- # ---------------------------------------------------------
40
- # 4. PREDICTION FUNCTION
41
- # ---------------------------------------------------------
42
- def predict(image):
43
- if model is None:
44
- return {"Error": 1.0}
45
-
46
- input_tensor = val_transform(image).unsqueeze(0)
47
-
48
- with torch.no_grad():
49
- output = model(input_tensor)
50
-
51
- probabilities = F.softmax(output[0], dim=0)
52
-
53
- # 4. Map to labels
54
- # Returns a dictionary: {"Class A": 0.9, "Class B": 0.1}
55
- return {LABELS[i]: float(probabilities[i]) for i in range(len(LABELS))}
56
-
57
- # ---------------------------------------------------------
58
- # 5. GRADIO INTERFACE
59
- # ---------------------------------------------------------
60
- demo = gr.Interface(
61
- fn=predict,
62
- inputs=gr.Image(type="pil"),
63
- outputs=gr.Label(num_top_classes=len(LABELS)),
64
- title="Medical Image Classification Demo",
65
- description="Upload a scan to classify it."
66
- )
67
 
68
  if __name__ == "__main__":
69
  demo.launch()
 
from src.ui import build_demo

# Build the Gradio app at import time so that `demo` is available as a
# module-level attribute (NOTE(review): presumably required by the hosting
# platform, which imports `demo` from app.py — confirm against deployment).
demo = build_demo()

if __name__ == "__main__":
    demo.launch()
src/inference.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter, defaultdict
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+
11
+ from src.densenet import DenseNet121
12
+ from src.preprocessing import preprocess
13
+
14
# ---------------------------------------------------------
# Model configuration and one-time startup state
# ---------------------------------------------------------
MODEL_PATH = "models/multi_smoothing1.2_reducelr_epoch116.pth"
# IMPORTANT: label order must match the class/folder order used during training.
LABELS = ["uc", "infeksi", "crohn", "tb"]
NUM_CLASSES = len(LABELS)
# File suffixes accepted by the batch folder scanner (compared lowercase).
SUPPORTED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# The model is instantiated and loaded on CPU at import time. A load failure
# is recorded in `_model_load_error` instead of raised, so the UI can still
# start and surface the problem to the user (see `_model_error_message`).
model = DenseNet121(num_classes=NUM_CLASSES)
_model_load_error: Optional[str] = None

try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
    model.eval()  # inference mode: disables dropout / batch-norm updates
except Exception as exc:
    _model_load_error = str(exc)

# Validation-time transform; (3, 299, 299) is presumably (channels, H, W)
# expected by the network — confirm against src/preprocessing.preprocess.
val_transform = preprocess(target_input_size=(3, 299, 299))
29
+
30
+
31
@dataclass(frozen=True)
class ClassifiedPrediction:
    """One accepted (above-threshold) top-1 prediction for a single image file."""

    # Filesystem path of the classified image.
    path: Path
    # Index into LABELS of the predicted class.
    pred_idx: int
    # Softmax probability of the predicted class, in [0, 1].
    confidence: float
36
+
37
+
38
def _model_error_message() -> Optional[str]:
    """Return a user-facing description of the startup model-load failure, or None."""
    if _model_load_error is not None:
        return f"Model failed to load from `{MODEL_PATH}`: {_model_load_error}"
    return None
42
+
43
+
44
def _load_xdl_modules():
    """Lazy-load the optional explainability (XDL) dependencies.

    Keeping these imports out of module scope lets the app run (classification
    only) when pytorch-grad-cam / OpenCV / the project XDL helpers are absent.

    Returns:
        Dict mapping names to the imported modules/callables so callers can
        use them without hard module-level dependencies.

    Raises:
        RuntimeError: If pytorch-grad-cam, OpenCV, or the project's XDL
            helpers cannot be imported. The original import error is chained
            as ``__cause__`` so the real failure is not lost.
    """
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    except Exception as exc:
        # `from exc` preserves the underlying traceback (PEP 3134).
        raise RuntimeError(f"GradCAM import failed: {exc}") from exc

    try:
        import cv2
    except Exception as exc:
        raise RuntimeError(f"OpenCV import failed: {exc}") from exc

    try:
        from src.xdl import (
            _get_target_layer,
            _preprocess_image,
            _process_smoothgrad_map,
            smoothgrad,
        )
    except Exception as exc:
        raise RuntimeError(f"Failed to load XDL utilities from src/xdl.py: {exc}") from exc

    return {
        "cv2": cv2,
        "GradCAM": GradCAM,
        "show_cam_on_image": show_cam_on_image,
        "ClassifierOutputTarget": ClassifierOutputTarget,
        "_get_target_layer": _get_target_layer,
        "_preprocess_image": _preprocess_image,
        "_process_smoothgrad_map": _process_smoothgrad_map,
        "smoothgrad": smoothgrad,
    }
78
+
79
+
80
def _iter_image_paths(folder: Path, extensions: Optional[Iterable[str]] = None) -> List[Path]:
    """Recursively collect image files under *folder*, sorted by path.

    Args:
        folder: Directory to search (searched recursively via ``rglob``).
        extensions: Optional iterable of allowed file suffixes (e.g. ".jpg"),
            matched case-insensitively. Defaults to the module-level
            ``SUPPORTED_IMAGE_EXTENSIONS``.

    Returns:
        Sorted list of matching file paths (directories are skipped).
    """
    if extensions is None:
        allowed = SUPPORTED_IMAGE_EXTENSIONS
    else:
        allowed = {ext.lower() for ext in extensions}
    return sorted(
        p
        for p in folder.rglob("*")
        if p.is_file() and p.suffix.lower() in allowed
    )
86
+
87
+
88
def _aggregate_classification(
    classified: List[ClassifiedPrediction],
    labels: Optional[List[str]] = None,
) -> Tuple[str, float]:
    """Reduce per-image predictions to one (label, mean confidence) pair.

    Majority vote over predicted class indices; ties are broken by the higher
    mean per-class confidence.

    Args:
        classified: Non-empty list of predictions (items need ``pred_idx``
            and ``confidence`` attributes).
        labels: Class names indexed by ``pred_idx``. Defaults to the
            module-level ``LABELS``.

    Returns:
        ``(label, mean confidence of the winning class)``.
    """
    if labels is None:
        labels = LABELS

    class_counter = Counter(item.pred_idx for item in classified)
    top_count = max(class_counter.values())
    tied_classes = [idx for idx, count in class_counter.items() if count == top_count]

    if len(tied_classes) == 1:
        final_idx = tied_classes[0]
    else:
        # Tie-break: among equally-voted classes, prefer the highest mean confidence.
        class_conf = defaultdict(list)
        for item in classified:
            class_conf[item.pred_idx].append(item.confidence)
        final_idx = max(tied_classes, key=lambda idx: float(np.mean(class_conf[idx])))

    final_conf = float(np.mean([item.confidence for item in classified if item.pred_idx == final_idx]))
    return labels[final_idx], final_conf
103
+
104
+
105
def _build_visual_row(
    original: np.ndarray,
    gradcam_img: np.ndarray,
    smoothgrad_img: np.ndarray,
) -> np.ndarray:
    """Stitch the three panels side by side (along the width axis) into one image."""
    panels = [original, gradcam_img, smoothgrad_img]
    return np.concatenate(panels, axis=1)
111
+
112
+
113
def _predict_top1(image: Image.Image) -> Tuple[int, float, torch.Tensor]:
    """Run one forward pass and return (top-1 index, its probability, input batch).

    The returned batch tensor is reused by the caller for XDL overlays.
    """
    batch = val_transform(image).unsqueeze(0)

    with torch.no_grad():
        probs = F.softmax(model(batch)[0], dim=0)

    top = torch.argmax(probs)
    return int(top.item()), float(probs[top].item()), batch
123
+
124
+
125
def predict_single(image: Image.Image) -> Dict[str, float]:
    """Classify one PIL image and return {label: probability} over all classes.

    If the model failed to load at startup, the load error is surfaced as the
    sole "label" with probability 1.0 so the UI displays it.
    """
    load_error = _model_error_message()
    if load_error:
        return {load_error: 1.0}

    batch = val_transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(batch)
        probabilities = F.softmax(output[0], dim=0)

    return {label: float(prob) for label, prob in zip(LABELS, probabilities)}
137
+
138
+
139
def batch_predict_with_xdl(
    folder_path: str,
    confidence_threshold: float,
    smoothgrad_samples: int,
    smoothgrad_noise: float,
):
    """Classify every image in a folder and build GradCAM/SmoothGrad overlays.

    Args:
        folder_path: Directory scanned recursively for supported image files.
        confidence_threshold: Minimum top-1 probability (clamped to [0, 1]) for
            an image to count toward the aggregate vote.
        smoothgrad_samples: Number of SmoothGrad noise samples (floored at 1).
        smoothgrad_noise: SmoothGrad noise level (floored at 0).

    Returns:
        Tuple of (markdown summary string, per-image result rows for the
        Dataframe, gallery items). Gallery items are (panel image, caption)
        pairs; the list is empty when the optional XDL deps are unavailable.
    """
    model_error = _model_error_message()
    if model_error:
        return model_error, [], []

    if not folder_path:
        return "Provide a folder path.", [], []

    folder = Path(folder_path).expanduser().resolve()
    if not folder.exists() or not folder.is_dir():
        return f"Invalid folder: `{folder}`", [], []

    # Sanitize slider inputs defensively before use.
    threshold = float(np.clip(confidence_threshold, 0.0, 1.0))
    smoothgrad_samples = int(max(1, smoothgrad_samples))
    smoothgrad_noise = float(max(0.0, smoothgrad_noise))

    image_paths = _iter_image_paths(folder)
    if not image_paths:
        return f"No supported image files found in `{folder}`.", [], []

    # XDL deps are optional: on import failure we still classify, and report
    # the status in the summary instead of failing the whole batch.
    xdl = None
    xdl_error = ""
    try:
        xdl = _load_xdl_modules()
    except RuntimeError as exc:
        xdl_error = str(exc)

    cam = None
    if xdl:
        target_layer = xdl["_get_target_layer"](model)
        cam = xdl["GradCAM"](model=model, target_layers=[target_layer])

    classified: List[ClassifiedPrediction] = []
    rows: List[List[str]] = []
    gallery_items = []

    for img_path in image_paths:
        # Unreadable/corrupt files are logged as a row and skipped.
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as exc:
            rows.append([img_path.name, "error", "-", str(exc)])
            continue

        pred_idx, confidence, input_tensor = _predict_top1(image)

        # Low-confidence predictions are reported but excluded from the vote.
        if confidence < threshold:
            rows.append([img_path.name, "below_threshold", LABELS[pred_idx], f"{confidence:.4f}"])
            continue

        prediction = ClassifiedPrediction(path=img_path, pred_idx=pred_idx, confidence=confidence)
        classified.append(prediction)
        rows.append([img_path.name, "classified", LABELS[pred_idx], f"{confidence:.4f}"])

        if xdl and cam is not None:
            # Overlay generation is best-effort per image: a failure is
            # recorded as an "xdl_error" row without dropping the prediction.
            try:
                base_img_float, base_img_uint8 = xdl["_preprocess_image"](input_tensor[0])
                h, w = base_img_uint8.shape[:2]

                # GradCAM heatmap for the predicted class.
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[xdl["ClassifierOutputTarget"](pred_idx)],
                )[0, :]
                gradcam_overlay = xdl["show_cam_on_image"](base_img_float, grayscale_cam, use_rgb=True)

                # SmoothGrad saliency for the same class.
                smooth_raw = xdl["smoothgrad"](
                    model,
                    input_tensor,
                    pred_idx,
                    n_samples=smoothgrad_samples,
                    noise_level=smoothgrad_noise,
                )
                _, smooth_heatmap = xdl["_process_smoothgrad_map"](
                    smooth_raw,
                    img_shape=(h, w),
                    percentile=95,
                    colormap="hot",
                )
                # OpenCV colormaps are BGR; convert before alpha-blending onto RGB.
                smooth_heatmap_rgb = xdl["cv2"].cvtColor(smooth_heatmap, xdl["cv2"].COLOR_BGR2RGB)
                smooth_overlay = xdl["cv2"].addWeighted(base_img_uint8, 0.6, smooth_heatmap_rgb, 0.4, 0)

                panel = _build_visual_row(base_img_uint8, gradcam_overlay, smooth_overlay)
                caption = (
                    f"{img_path.name} | pred={LABELS[pred_idx]} | conf={confidence:.4f} "
                    "| left=Original middle=GradCAM right=SmoothGrad"
                )
                gallery_items.append((panel, caption))
            except Exception as exc:
                rows.append([img_path.name, "xdl_error", LABELS[pred_idx], str(exc)])

    if not classified:
        summary = (
            f"Processed {len(image_paths)} images. "
            f"0 images met threshold `{threshold:.2f}`."
        )
        if xdl_error:
            summary += f"\n\nXDL status: {xdl_error}"
        return summary, rows, []

    # Aggregate accepted predictions into one final class + mean confidence.
    final_class, mean_conf = _aggregate_classification(classified)
    class_counter = Counter(item.pred_idx for item in classified)
    class_stats = ", ".join(f"{LABELS[idx]}: {count}" for idx, count in class_counter.items())

    summary = (
        f"Processed {len(image_paths)} images. "
        f"Classified: {len(classified)} (threshold >= {threshold:.2f}), "
        f"Skipped: {len(image_paths) - len(classified)}.\n"
        f"Final class (from per-image max class vote): **{final_class}**\n"
        f"Mean confidence for final class: **{mean_conf:.4f}**\n"
        f"Classified image distribution: {class_stats}"
    )

    if xdl_error:
        summary += f"\n\nXDL status: {xdl_error}"

    return summary, rows, gallery_items
src/ui.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from src.inference import NUM_CLASSES, batch_predict_with_xdl, predict_single
4
+
5
+
6
def build_demo() -> gr.Blocks:
    """Assemble and return the two-tab Gradio Blocks application."""
    with gr.Blocks(title="Medical Image Classification + XDL") as demo:
        gr.Markdown("## Medical Image Classification Demo")

        # Tab 1: classify a single uploaded image.
        with gr.Tab("Single Image"):
            image_input = gr.Image(type="pil", label="Image")
            label_output = gr.Label(num_top_classes=NUM_CLASSES, label="Prediction")
            predict_button = gr.Button("Predict")
            predict_button.click(fn=predict_single, inputs=image_input, outputs=label_output)

        # Tab 2: batch-classify a folder and show explainability overlays.
        with gr.Tab("Folder Batch + XDL"):
            folder_input = gr.Textbox(
                label="Folder Path",
                placeholder="/absolute/path/to/folder/with/images",
            )
            threshold_slider = gr.Slider(0.0, 1.0, value=0.6, step=0.01, label="Confidence Threshold")
            samples_slider = gr.Slider(10, 200, value=50, step=10, label="SmoothGrad Samples")
            noise_slider = gr.Slider(0.01, 0.2, value=0.05, step=0.01, label="SmoothGrad Noise Level")
            batch_button = gr.Button("Run Batch Inference")

            summary_md = gr.Markdown()
            results_table = gr.Dataframe(
                headers=["filename", "status", "predicted_label", "confidence_or_error"],
                datatype=["str", "str", "str", "str"],
                interactive=False,
                label="Per-image Results",
            )
            overlay_gallery = gr.Gallery(
                label="XDL Overlays (Original | GradCAM | SmoothGrad)",
                columns=1,
            )

            batch_button.click(
                fn=batch_predict_with_xdl,
                inputs=[folder_input, threshold_slider, samples_slider, noise_slider],
                outputs=[summary_md, results_table, overlay_gallery],
            )

    return demo