VJBharathkumar commited on
Commit
f1c43c2
·
verified ·
1 Parent(s): 1dc8d47

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +167 -207
src/streamlit_app.py CHANGED
@@ -13,25 +13,23 @@ import matplotlib.pyplot as plt
13
 
14
  from fpdf import FPDF
15
 
16
-
17
- # ============================================================
18
  # Page config
19
- # ============================================================
20
  st.set_page_config(
21
- page_title="Pneumonia Detection (Chest X-ray) Clinical Decision Support",
22
- layout="centered",
23
  )
24
 
25
- st.title("Pneumonia Detection (Chest X-ray) Clinical Decision Support")
26
  st.caption(
27
  "Upload one or more Chest X-ray DICOM files (.dcm). Adjust the decision threshold and click Submit. "
28
  "This tool is for decision support only and does not replace clinical judgment."
29
  )
30
 
31
-
32
- # ============================================================
33
  # Paths / Model Loading
34
- # ============================================================
35
  REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
36
  MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
37
  VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json") # optional
@@ -44,240 +42,213 @@ def load_model():
44
  try:
45
  m = keras.models.load_model(MODEL_PATH)
46
  except Exception:
47
- # if Lambda layers / unsafe deserialization exists
48
  keras.config.enable_unsafe_deserialization()
49
  m = keras.models.load_model(MODEL_PATH, safe_mode=False)
50
  return m
51
 
 
 
 
 
 
 
 
52
  def get_model_version():
53
  if os.path.exists(VERSION_PATH):
54
  try:
55
  with open(VERSION_PATH, "r") as f:
56
- return json.load(f).get("version", "unknown")
57
  except Exception:
58
- return "unknown"
59
- return "v1"
60
 
61
  MODEL_VERSION = get_model_version()
62
- model = load_model()
63
-
64
- # model input details: (None, H, W, C)
65
- input_shape = model.input_shape
66
- img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
67
- exp_ch = int(input_shape[-1]) if input_shape and input_shape[-1] else 1
68
 
69
-
70
- # ============================================================
71
- # Helpers
72
- # ============================================================
73
  def interpret_confidence(prob: float) -> str:
74
  if prob < 0.30:
75
  return "Low likelihood (<30%)"
76
  elif prob <= 0.60:
77
- return "Borderline suspicion (3060%)"
78
  else:
79
  return "High likelihood (>60%)"
80
 
81
- def read_dicom_bytes(file_bytes: bytes) -> np.ndarray:
82
- dcm = pydicom.dcmread(io.BytesIO(file_bytes))
 
 
 
83
  img = dcm.pixel_array.astype(np.float32)
84
 
85
- # Normalize to 0..1
86
  img_min = float(np.min(img))
87
  img_max = float(np.max(img))
88
- img = (img - img_min) / (img_max - img_min + 1e-8)
89
 
90
  return img
91
 
92
  def preprocess(img_2d: np.ndarray) -> np.ndarray:
93
- """
94
- (H,W) -> (1,img_size,img_size,C) float32 in 0..1
95
- """
96
  x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32) # (H,W,1)
97
  x = tf.image.resize(x, (img_size, img_size))
98
  x = tf.clip_by_value(x, 0.0, 1.0)
99
  x = x.numpy() # (img_size,img_size,1)
100
 
101
  if exp_ch == 3 and x.shape[-1] == 1:
102
- x = np.repeat(x, 3, axis=-1)
103
  elif exp_ch == 1 and x.shape[-1] == 3:
104
- x = x[..., :1]
105
 
106
- x = np.expand_dims(x, axis=0) # (1,img_size,img_size,C)
107
  return x.astype(np.float32)
108
 
109
  def predict_prob(x: np.ndarray) -> float:
110
- """
111
- Supports single-head and multi-head models.
112
- Uses last output as probability when outputs are list/tuple.
113
- """
114
  pred = model.predict(x, verbose=0)
115
  if isinstance(pred, (list, tuple)):
116
  prob = float(np.ravel(pred[-1])[0])
117
  else:
118
  prob = float(np.ravel(pred)[0])
119
-
120
  return max(0.0, min(1.0, prob))
121
 
122
-
123
- # ============================================================
124
- # Grad-CAM (robust layer selection)
125
- # ============================================================
126
- def _find_last_conv2d_layer_name(m: keras.Model) -> str:
127
- # Prefer backbone conv layers if a common backbone layer exists
128
- backbone_names = ["resnet50", "ResNet50", "backbone"]
129
- for nm in backbone_names:
130
- try:
131
- bb = m.get_layer(nm)
132
- # walk backwards in backbone
133
- for layer in reversed(bb.layers):
134
- if isinstance(layer, tf.keras.layers.Conv2D):
135
- return f"{nm}/{layer.name}"
136
- except Exception:
137
- pass
138
-
139
- # Fallback: scan the whole model
140
- for layer in reversed(m.layers):
141
- if isinstance(layer, tf.keras.layers.Conv2D):
142
- return layer.name
143
-
144
- raise ValueError("No Conv2D layer found for Grad-CAM.")
145
-
146
  @st.cache_resource
147
- def get_gradcam_model_and_layername():
148
- last_conv_name = _find_last_conv2d_layer_name(model)
 
 
149
 
150
- # If name includes "backbone/layer", resolve it
151
- if "/" in last_conv_name:
152
- parent, child = last_conv_name.split("/", 1)
153
- conv_layer = model.get_layer(parent).get_layer(child)
154
- else:
155
- conv_layer = model.get_layer(last_conv_name)
 
 
156
 
157
- grad_model = keras.Model(
158
- inputs=model.inputs,
159
- outputs=[conv_layer.output, model.output],
160
- )
161
- return grad_model, last_conv_name
162
 
163
- def make_gradcam_heatmap(img_array: np.ndarray) -> np.ndarray:
164
- """
165
- img_array: (1,H,W,C)
166
- returns heatmap: (Hc, Wc) normalized 0..1
167
- """
168
- grad_model, _ = get_gradcam_model_and_layername()
169
-
170
- with tf.GradientTape() as tape:
171
- conv_outputs, preds = grad_model(img_array)
172
 
173
- # multi-head -> take last output for probability
174
- if isinstance(preds, (list, tuple)):
175
- preds = preds[-1]
176
 
177
- # preds shape could be (1,1) or (1,)
178
- loss = preds[:, 0] if preds.ndim == 2 else preds
179
 
180
- grads = tape.gradient(loss, conv_outputs)
 
 
181
 
182
- # safety in case grads is None
183
  if grads is None:
184
- raise ValueError("Gradients are None. Grad-CAM cannot be computed for this model output.")
185
 
186
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)) # (channels,)
187
- conv_outputs = conv_outputs[0] # (Hc,Wc,channels)
188
 
189
- heatmap = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1) # (Hc,Wc)
190
  heatmap = tf.maximum(heatmap, 0)
191
 
192
- denom = tf.reduce_max(heatmap)
193
- heatmap = heatmap / (denom + 1e-8)
194
-
195
  return heatmap.numpy()
196
 
197
- def overlay_heatmap_on_image(img_2d_resized: np.ndarray, heatmap: np.ndarray):
198
- """
199
- img_2d_resized: (img_size,img_size) in 0..1
200
- heatmap: (Hc,Wc) in 0..1
201
- """
202
  heat = tf.image.resize(heatmap[..., None], (img_size, img_size)).numpy().squeeze()
203
 
 
 
 
204
  fig = plt.figure(figsize=(5, 5))
205
- plt.imshow(img_2d_resized, cmap="gray")
206
  plt.imshow(heat, cmap="jet", alpha=0.35)
207
  plt.axis("off")
208
  plt.tight_layout()
209
  return fig
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- # ============================================================
213
- # PDF generator (stable)
214
- # ============================================================
215
- class PDF(FPDF):
216
- pass
217
 
218
- def build_pdf_report(df: pd.DataFrame, threshold: float, model_version: str) -> bytes:
219
- pdf = PDF()
220
  pdf.set_auto_page_break(auto=True, margin=12)
221
  pdf.add_page()
222
- pdf.set_margins(12, 12, 12)
223
-
224
- pdf.set_font("Arial", size=12)
225
- pdf.cell(0, 8, "Pneumonia Detection Report", ln=True)
226
-
227
- pdf.set_font("Arial", size=10)
228
- pdf.cell(0, 7, f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True)
229
- pdf.cell(0, 7, f"Model Version: {model_version}", ln=True)
230
- pdf.cell(0, 7, f"Decision Threshold: {threshold:.2f}", ln=True)
231
- pdf.ln(3)
232
-
233
- pdf.set_font("Arial", size=10)
234
-
235
- # Write each row safely
236
- for _, row in df.iterrows():
237
- pdf.set_x(pdf.l_margin)
238
- prob = row.get("probability", np.nan)
239
- prob_pct = "NA" if pd.isna(prob) else f"{float(prob)*100:.2f}%"
240
-
241
- lines = [
242
- f"File: {row.get('file_name','')}",
243
- f"Probability: {prob_pct}",
244
- f"Confidence Level: {row.get('confidence_level','')}",
245
- f"Prediction: {row.get('prediction','')}",
246
- f"Timestamp: {row.get('timestamp','')}",
247
- ]
248
-
249
- for line in lines:
250
- # Use multi_cell for wrapping and reset x each time
251
- pdf.set_x(pdf.l_margin)
252
- pdf.multi_cell(0, 6, line)
253
 
254
- pdf.ln(2)
 
 
 
 
 
255
 
256
- return pdf.output(dest="S").encode("latin-1")
 
 
 
 
 
257
 
 
 
258
 
259
- # ============================================================
 
 
 
 
 
 
 
 
260
  # UI
261
- # ============================================================
262
  st.subheader("Model Parameters")
263
 
264
  threshold = st.slider(
265
  "Decision Threshold",
266
  min_value=0.01,
267
  max_value=0.99,
268
- value=0.37, # ResNet default (your best thr)
269
  step=0.01,
270
- help="If predicted probability threshold Pneumonia, else Not Pneumonia.",
271
  )
272
 
273
  show_gradcam = st.checkbox("Show Grad-CAM heatmap (explainability)", value=True)
274
 
275
  st.subheader("Upload Chest X-ray DICOM Files")
276
-
277
  uploaded_files = st.file_uploader(
278
  "Select one or multiple DICOM files (.dcm)",
279
  type=["dcm"],
280
- accept_multiple_files=True,
281
  )
282
 
283
  col1, col2 = st.columns(2)
@@ -291,54 +262,44 @@ if clear:
291
 
292
  st.subheader("Prediction Results")
293
 
294
-
295
- # ============================================================
296
- # Inference
297
- # ============================================================
298
  if submit:
299
  if not uploaded_files:
300
  st.warning("Please upload at least one DICOM file before submitting.")
301
  else:
302
- rows = []
303
- file_cache = [] # (filename, bytes, img_2d_norm)
304
 
 
305
  with st.spinner("Running inference..."):
306
- for f in uploaded_files:
307
  ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
308
  try:
309
- b = f.getvalue()
310
- img = read_dicom_bytes(b) # (H,W) 0..1
311
  x = preprocess(img)
312
  prob = predict_prob(x)
313
 
314
  pred_label = "Pneumonia" if prob >= threshold else "Not Pneumonia"
315
  conf_level = interpret_confidence(prob)
316
 
317
- rows.append(
318
- {
319
- "timestamp": ts,
320
- "model_version": MODEL_VERSION,
321
- "file_name": f.name,
322
- "probability": prob,
323
- "confidence_level": conf_level,
324
- "prediction": pred_label,
325
- }
326
- )
327
-
328
- file_cache.append((f.name, b, img))
329
-
330
  except Exception as e:
331
- rows.append(
332
- {
333
- "timestamp": ts,
334
- "model_version": MODEL_VERSION,
335
- "file_name": f.name,
336
- "probability": np.nan,
337
- "confidence_level": "NA",
338
- "prediction": "Error",
339
- "error": str(e),
340
- }
341
- )
342
 
343
  df = pd.DataFrame(rows)
344
 
@@ -347,34 +308,34 @@ if submit:
347
  if r["prediction"] == "Error":
348
  st.error(
349
  f"For the uploaded file '{r['file_name']}', the system could not generate a prediction. "
350
- f"Reason: {r.get('error','Unknown error')}."
351
- )
352
- else:
353
- prob_pct = float(r["probability"]) * 100.0
354
- st.write(
355
- f"For the uploaded file '{r['file_name']}', the model estimates a pneumonia probability of "
356
- f"{prob_pct:.2f}%. This falls under '{r['confidence_level']}'. "
357
- f"Based on the selected decision threshold of {threshold:.2f}, the predicted outcome is "
358
- f"'{r['prediction']}'."
359
  )
 
 
 
 
 
 
 
 
 
360
 
361
  # Grad-CAM section
362
  if show_gradcam:
363
  st.markdown("### Grad-CAM Heatmaps")
364
-
365
- # Resize original image for display
366
- for (fname, _, img_2d) in file_cache:
367
  try:
368
- img_resized = tf.image.resize(img_2d[..., None], (img_size, img_size)).numpy().squeeze()
369
- x = preprocess(img_2d)
370
  heatmap = make_gradcam_heatmap(x)
371
- fig = overlay_heatmap_on_image(img_resized, heatmap)
372
- st.write(f"Heatmap for: {fname}")
373
  st.pyplot(fig)
374
  except Exception as e:
375
- # This will now show the *actual* layer list + reason,
376
- # instead of failing with a wrong hard-coded layer name.
377
- st.warning(f"Could not generate Grad-CAM for {fname}. Reason: {e}")
 
378
 
379
  # Downloads
380
  st.markdown("### Downloads")
@@ -385,7 +346,7 @@ if submit:
385
  data=csv_bytes,
386
  file_name="predictions.csv",
387
  mime="text/csv",
388
- use_container_width=True,
389
  )
390
 
391
  df_ok = df[df["prediction"] != "Error"].copy()
@@ -396,11 +357,10 @@ if submit:
396
  data=pdf_bytes,
397
  file_name="pneumonia_report.pdf",
398
  mime="application/pdf",
399
- use_container_width=True,
400
  )
401
  else:
402
- st.info("PDF report is available once at least one file is successfully processed.")
403
-
404
 
405
  st.divider()
406
  st.caption(
 
13
 
14
  from fpdf import FPDF
15
 
16
+ # -----------------------------
 
17
  # Page config
18
+ # -----------------------------
19
  st.set_page_config(
20
+ page_title="Pneumonia Detection (Chest X-ray) - Clinical Decision Support",
21
+ layout="centered"
22
  )
23
 
24
+ st.title("Pneumonia Detection (Chest X-ray) - Clinical Decision Support")
25
  st.caption(
26
  "Upload one or more Chest X-ray DICOM files (.dcm). Adjust the decision threshold and click Submit. "
27
  "This tool is for decision support only and does not replace clinical judgment."
28
  )
29
 
30
+ # -----------------------------
 
31
  # Paths / Model Loading
32
+ # -----------------------------
33
  REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
34
  MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
35
  VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json") # optional
 
42
  try:
43
  m = keras.models.load_model(MODEL_PATH)
44
  except Exception:
45
+ # If you trained it, it's safe to allow deserialization
46
  keras.config.enable_unsafe_deserialization()
47
  m = keras.models.load_model(MODEL_PATH, safe_mode=False)
48
  return m
49
 
50
+ model = load_model()
51
+
52
+ # model input details
53
+ input_shape = model.input_shape # (None, H, W, C)
54
+ img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
55
+ exp_ch = int(input_shape[-1]) if input_shape and input_shape[-1] else 1
56
+
57
  def get_model_version():
58
  if os.path.exists(VERSION_PATH):
59
  try:
60
  with open(VERSION_PATH, "r") as f:
61
+ return json.load(f).get("version", "ResNet50_v1")
62
  except Exception:
63
+ return "ResNet50_v1"
64
+ return "ResNet50_v1"
65
 
66
  MODEL_VERSION = get_model_version()
 
 
 
 
 
 
67
 
68
+ # -----------------------------
69
+ # Confidence interpretation
70
+ # -----------------------------
 
71
  def interpret_confidence(prob: float) -> str:
72
  if prob < 0.30:
73
  return "Low likelihood (<30%)"
74
  elif prob <= 0.60:
75
+ return "Borderline suspicion (30-60%)"
76
  else:
77
  return "High likelihood (>60%)"
78
 
79
+ # -----------------------------
80
+ # DICOM + preprocessing
81
+ # -----------------------------
82
+ def dicom_bytes_to_img(data: bytes) -> np.ndarray:
83
+ dcm = pydicom.dcmread(io.BytesIO(data))
84
  img = dcm.pixel_array.astype(np.float32)
85
 
 
86
  img_min = float(np.min(img))
87
  img_max = float(np.max(img))
88
+ img = (img - img_min) / (img_max - img_min + 1e-8) # 0..1
89
 
90
  return img
91
 
92
  def preprocess(img_2d: np.ndarray) -> np.ndarray:
93
+ # (H,W) -> (1,img_size,img_size,C) float32 0..1
 
 
94
  x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32) # (H,W,1)
95
  x = tf.image.resize(x, (img_size, img_size))
96
  x = tf.clip_by_value(x, 0.0, 1.0)
97
  x = x.numpy() # (img_size,img_size,1)
98
 
99
  if exp_ch == 3 and x.shape[-1] == 1:
100
+ x = np.repeat(x, 3, axis=-1) # (img_size,img_size,3)
101
  elif exp_ch == 1 and x.shape[-1] == 3:
102
+ x = x[..., :1] # (img_size,img_size,1)
103
 
104
+ x = np.expand_dims(x, axis=0) # (1,img_size,img_size,C)
105
  return x.astype(np.float32)
106
 
107
  def predict_prob(x: np.ndarray) -> float:
 
 
 
 
108
  pred = model.predict(x, verbose=0)
109
  if isinstance(pred, (list, tuple)):
110
  prob = float(np.ravel(pred[-1])[0])
111
  else:
112
  prob = float(np.ravel(pred)[0])
 
113
  return max(0.0, min(1.0, prob))
114
 
115
+ # -----------------------------
116
+ # Grad-CAM (robust for nested "resnet50" submodel)
117
+ # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  @st.cache_resource
119
+ def build_gradcam_tools():
120
+ # If your backbone layer name is different, change here:
121
+ backbone_name = "resnet50"
122
+ backbone = model.get_layer(backbone_name)
123
 
124
+ # pick the last Conv2D inside the backbone
125
+ last_conv = None
126
+ for lyr in reversed(backbone.layers):
127
+ if isinstance(lyr, tf.keras.layers.Conv2D):
128
+ last_conv = lyr.name
129
+ break
130
+ if last_conv is None:
131
+ raise ValueError("Could not find a Conv2D layer inside the ResNet backbone.")
132
 
133
+ # outputs: last conv feature map + probability head
134
+ conv_out = backbone.get_layer(last_conv).output
135
+ prob_out = model.outputs[-1] if isinstance(model.outputs, (list, tuple)) else model.output
 
 
136
 
137
+ grad_model = keras.Model(inputs=model.inputs, outputs=[conv_out, prob_out])
138
+ return grad_model, last_conv
 
 
 
 
 
 
 
139
 
140
+ def make_gradcam_heatmap(img_batch: np.ndarray) -> np.ndarray:
141
+ grad_model, _ = build_gradcam_tools()
 
142
 
143
+ x = tf.convert_to_tensor(img_batch, dtype=tf.float32)
 
144
 
145
+ with tf.GradientTape() as tape:
146
+ conv_out, preds = grad_model(x, training=False)
147
+ loss = preds[:, 0] # binary sigmoid prob
148
 
149
+ grads = tape.gradient(loss, conv_out)
150
  if grads is None:
151
+ raise ValueError("Gradients are None. Grad-CAM cannot be computed for this model setup.")
152
 
153
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
154
+ conv_out = conv_out[0] # (H,W,channels)
155
 
156
+ heatmap = tf.reduce_sum(conv_out * pooled_grads, axis=-1)
157
  heatmap = tf.maximum(heatmap, 0)
158
 
159
+ denom = tf.reduce_max(heatmap) + 1e-8
160
+ heatmap = heatmap / denom
 
161
  return heatmap.numpy()
162
 
163
+ def overlay_heatmap(img_2d: np.ndarray, heatmap: np.ndarray):
164
+ # resize heatmap to img_size
 
 
 
165
  heat = tf.image.resize(heatmap[..., None], (img_size, img_size)).numpy().squeeze()
166
 
167
+ # ensure base image displayed at img_size
168
+ base = tf.image.resize(img_2d[..., None], (img_size, img_size)).numpy().squeeze()
169
+
170
  fig = plt.figure(figsize=(5, 5))
171
+ plt.imshow(base, cmap="gray")
172
  plt.imshow(heat, cmap="jet", alpha=0.35)
173
  plt.axis("off")
174
  plt.tight_layout()
175
  return fig
176
 
177
+ # -----------------------------
178
+ # PDF generation (fix unicode issues)
179
+ # -----------------------------
180
+ def safe_text(s: str, max_len: int = 180) -> str:
181
+ """
182
+ Convert to something FPDF core fonts can render.
183
+ Also trims very long strings to avoid layout failures.
184
+ """
185
+ if s is None:
186
+ return ""
187
+ s = str(s)
188
+
189
+ # Replace common unicode dashes/quotes with ascii
190
+ s = s.replace("–", "-").replace("—", "-").replace("’", "'").replace("“", '"').replace("”", '"')
191
+
192
+ # Remove / replace any remaining unsupported chars (latin-1 fallback)
193
+ s = s.encode("latin-1", "replace").decode("latin-1")
194
 
195
+ # Avoid extremely long unbroken lines
196
+ if len(s) > max_len:
197
+ s = s[:max_len] + "..."
198
+ return s
 
199
 
200
+ def build_pdf_report(df_ok: pd.DataFrame, threshold: float, model_version: str) -> bytes:
201
+ pdf = FPDF()
202
  pdf.set_auto_page_break(auto=True, margin=12)
203
  pdf.add_page()
204
+ pdf.set_font("Helvetica", size=12)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ pdf.cell(0, 8, safe_text("Pneumonia Detection Report"), ln=True)
207
+ pdf.set_font("Helvetica", size=10)
208
+ pdf.cell(0, 6, safe_text(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"), ln=True)
209
+ pdf.cell(0, 6, safe_text(f"Model Version: {model_version}"), ln=True)
210
+ pdf.cell(0, 6, safe_text(f"Decision Threshold: {threshold:.2f}"), ln=True)
211
+ pdf.ln(4)
212
 
213
+ pdf.set_font("Helvetica", size=10)
214
+ for _, row in df_ok.iterrows():
215
+ line1 = f"File: {row['file_name']}"
216
+ line2 = f"Probability: {row['probability']*100:.2f}%"
217
+ line3 = f"Confidence: {row['confidence_level']}"
218
+ line4 = f"Prediction: {row['prediction']}"
219
 
220
+ for line in [line1, line2, line3, line4]:
221
+ pdf.multi_cell(0, 6, safe_text(line))
222
 
223
+ pdf.ln(2)
224
+
225
+ out = pdf.output(dest="S")
226
+ # fpdf may return str in some versions
227
+ if isinstance(out, str):
228
+ out = out.encode("latin-1", "ignore")
229
+ return out
230
+
231
+ # -----------------------------
232
  # UI
233
+ # -----------------------------
234
  st.subheader("Model Parameters")
235
 
236
  threshold = st.slider(
237
  "Decision Threshold",
238
  min_value=0.01,
239
  max_value=0.99,
240
+ value=0.37, # your ResNet best threshold default
241
  step=0.01,
242
+ help="If predicted probability is greater than or equal to the threshold, output is Pneumonia. Otherwise Not Pneumonia."
243
  )
244
 
245
  show_gradcam = st.checkbox("Show Grad-CAM heatmap (explainability)", value=True)
246
 
247
  st.subheader("Upload Chest X-ray DICOM Files")
 
248
  uploaded_files = st.file_uploader(
249
  "Select one or multiple DICOM files (.dcm)",
250
  type=["dcm"],
251
+ accept_multiple_files=True
252
  )
253
 
254
  col1, col2 = st.columns(2)
 
262
 
263
  st.subheader("Prediction Results")
264
 
 
 
 
 
265
  if submit:
266
  if not uploaded_files:
267
  st.warning("Please upload at least one DICOM file before submitting.")
268
  else:
269
+ # cache bytes once (so we can read multiple times safely)
270
+ file_bytes = {f.name: f.getvalue() for f in uploaded_files}
271
 
272
+ rows = []
273
  with st.spinner("Running inference..."):
274
+ for name, data in file_bytes.items():
275
  ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
276
  try:
277
+ img = dicom_bytes_to_img(data)
 
278
  x = preprocess(img)
279
  prob = predict_prob(x)
280
 
281
  pred_label = "Pneumonia" if prob >= threshold else "Not Pneumonia"
282
  conf_level = interpret_confidence(prob)
283
 
284
+ rows.append({
285
+ "timestamp": ts,
286
+ "model_version": MODEL_VERSION,
287
+ "file_name": name,
288
+ "probability": prob,
289
+ "prediction": pred_label,
290
+ "confidence_level": conf_level,
291
+ "error": ""
292
+ })
 
 
 
 
293
  except Exception as e:
294
+ rows.append({
295
+ "timestamp": ts,
296
+ "model_version": MODEL_VERSION,
297
+ "file_name": name,
298
+ "probability": np.nan,
299
+ "prediction": "Error",
300
+ "confidence_level": "",
301
+ "error": safe_text(str(e), max_len=140)
302
+ })
 
 
303
 
304
  df = pd.DataFrame(rows)
305
 
 
308
  if r["prediction"] == "Error":
309
  st.error(
310
  f"For the uploaded file '{r['file_name']}', the system could not generate a prediction. "
311
+ f"Reason: {r['error']}."
 
 
 
 
 
 
 
 
312
  )
313
+ continue
314
+
315
+ prob_pct = float(r["probability"]) * 100.0
316
+ st.write(
317
+ f"For the uploaded file '{r['file_name']}', the model estimates a pneumonia probability of "
318
+ f"{prob_pct:.2f}%. This falls under '{r['confidence_level']}'. "
319
+ f"Based on the selected decision threshold of {threshold:.2f}, the predicted outcome is "
320
+ f"'{r['prediction']}'."
321
+ )
322
 
323
  # Grad-CAM section
324
  if show_gradcam:
325
  st.markdown("### Grad-CAM Heatmaps")
326
+ for name, data in file_bytes.items():
 
 
327
  try:
328
+ img = dicom_bytes_to_img(data)
329
+ x = preprocess(img)
330
  heatmap = make_gradcam_heatmap(x)
331
+ fig = overlay_heatmap(img, heatmap)
332
+ st.write(f"Heatmap for: {name}")
333
  st.pyplot(fig)
334
  except Exception as e:
335
+ st.warning(
336
+ f"Could not generate Grad-CAM for {name}. "
337
+ f"Reason: {safe_text(str(e), max_len=160)}"
338
+ )
339
 
340
  # Downloads
341
  st.markdown("### Downloads")
 
346
  data=csv_bytes,
347
  file_name="predictions.csv",
348
  mime="text/csv",
349
+ use_container_width=True
350
  )
351
 
352
  df_ok = df[df["prediction"] != "Error"].copy()
 
357
  data=pdf_bytes,
358
  file_name="pneumonia_report.pdf",
359
  mime="application/pdf",
360
+ use_container_width=True
361
  )
362
  else:
363
+ st.info("PDF report is available only when at least one file is successfully processed.")
 
364
 
365
  st.divider()
366
  st.caption(