VJBharathkumar commited on
Commit
2958274
·
verified ·
1 Parent(s): 3372102

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +279 -78
src/streamlit_app.py CHANGED
@@ -1,66 +1,228 @@
1
  import io
2
  import os
 
 
 
3
  import numpy as np
 
4
  import streamlit as st
5
  import tensorflow as tf
6
  from tensorflow import keras
7
  import pydicom
 
 
 
8
 
9
- # ----------------------------------------------------
10
- # App Configuration
11
- # ----------------------------------------------------
12
  st.set_page_config(
13
  page_title="Pneumonia Detection (Chest X-ray) – Clinical Decision Support",
14
  layout="centered"
15
  )
16
 
17
-
18
  st.title("Pneumonia Detection (Chest X-ray) – Clinical Decision Support")
19
  st.caption(
20
- "Upload one or more Chest X-ray DICOM files (.dcm). "
21
- "Adjust the decision threshold and submit to obtain a probability-based binary prediction. "
22
- "This system is intended for clinical decision support and does not replace professional medical judgment."
23
  )
24
 
25
- # ----------------------------------------------------
26
- # Load Model
27
- # ----------------------------------------------------
28
-
29
-
30
- MODEL_PATH = os.path.join(os.path.dirname(__file__), "..", "model.keras")
31
 
 
 
32
 
33
  @st.cache_resource
34
  def load_model():
 
 
 
35
  try:
36
- return keras.models.load_model(MODEL_PATH)
37
  except Exception:
38
  keras.config.enable_unsafe_deserialization()
39
- return keras.models.load_model(MODEL_PATH, safe_mode=False)
 
40
 
41
  model = load_model()
42
 
43
- input_shape = model.input_shape
 
44
  img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
45
- expected_channels = int(input_shape[-1]) if input_shape and input_shape[-1] else 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # ----------------------------------------------------
48
- # Threshold Slider (DEFAULT = 0.37 for ResNet)
49
- # ----------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  st.subheader("Model Parameters")
51
 
52
  threshold = st.slider(
53
  "Decision Threshold",
54
  min_value=0.01,
55
  max_value=0.99,
56
- value=0.37,
57
  step=0.01,
58
- help="If predicted probability ≥ threshold → Pneumonia. Otherwise → Not Pneumonia."
59
  )
60
 
61
- # ----------------------------------------------------
62
- # File Upload
63
- # ----------------------------------------------------
64
  st.subheader("Upload Chest X-ray DICOM Files")
65
 
66
  uploaded_files = st.file_uploader(
@@ -76,74 +238,113 @@ with col2:
76
  clear = st.button("Clear", use_container_width=True)
77
 
78
  if clear:
79
- st.experimental_rerun()
80
 
81
- # ----------------------------------------------------
82
- # Helper Functions
83
- # ----------------------------------------------------
84
- def read_dicom(file):
85
- data = file.read()
86
- dcm = pydicom.dcmread(io.BytesIO(data))
87
- img = dcm.pixel_array.astype(np.float32)
88
 
89
- img = (img - img.min()) / (img.max() - img.min() + 1e-8)
90
- return img
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- def preprocess(img):
93
- x = tf.convert_to_tensor(img[..., None], dtype=tf.float32)
94
- x = tf.image.resize(x, (img_size, img_size))
95
- x = tf.clip_by_value(x, 0.0, 1.0)
96
- x = x.numpy()
 
 
 
97
 
98
- # If model expects 3 channels (ResNet)
99
- if expected_channels == 3 and x.shape[-1] == 1:
100
- x = np.repeat(x, 3, axis=-1)
 
 
 
 
 
 
101
 
102
- x = np.expand_dims(x, axis=0)
103
- return x.astype(np.float32)
104
 
105
- def get_probability(x):
106
- prediction = model.predict(x, verbose=0)
 
 
 
 
 
 
107
 
108
- if isinstance(prediction, (list, tuple)):
109
- prob = float(np.ravel(prediction[-1])[0])
110
- else:
111
- prob = float(np.ravel(prediction)[0])
 
 
 
112
 
113
- return max(0.0, min(1.0, prob))
 
 
 
114
 
115
- # ----------------------------------------------------
116
- # Inference Section
117
- # ----------------------------------------------------
118
- st.subheader("Prediction Results")
119
 
120
- if submit:
121
- if not uploaded_files:
122
- st.warning("Please upload at least one DICOM file before clicking Submit.")
123
- else:
124
- with st.spinner("Processing uploaded file(s)..."):
125
- for file in uploaded_files:
126
  try:
127
- image_array = read_dicom(file)
128
- x_input = preprocess(image_array)
129
- probability = get_probability(x_input)
 
 
130
 
131
- predicted_label = "Pneumonia" if probability >= threshold else "Not Pneumonia"
 
 
 
 
 
 
132
 
133
- st.write(
134
- f"For the uploaded file '{file.name}', the model estimates a pneumonia probability of "
135
- f"{probability * 100:.2f}%. Based on the selected decision threshold of {threshold:.2f}, "
136
- f"the predicted outcome is '{predicted_label}'."
137
- )
 
 
 
 
 
138
 
139
- except Exception as e:
140
- st.error(
141
- f"For the uploaded file '{file.name}', the system could not generate a prediction. "
142
- f"Reason: {str(e)}."
143
- )
 
 
 
144
 
145
  st.divider()
146
  st.caption(
147
- "Clinical Notice: This application is designed for decision support purposes only. "
148
- "Final diagnosis and treatment decisions must be made by qualified healthcare professionals."
149
  )
 
1
  import io
2
  import os
3
+ import json
4
+ from datetime import datetime
5
+
6
  import numpy as np
7
+ import pandas as pd
8
  import streamlit as st
9
  import tensorflow as tf
10
  from tensorflow import keras
11
  import pydicom
12
+ import matplotlib.pyplot as plt
13
+
14
+ from fpdf import FPDF
15
 
16
+ # -----------------------------
17
+ # Page config
18
+ # -----------------------------
19
  st.set_page_config(
20
  page_title="Pneumonia Detection (Chest X-ray) – Clinical Decision Support",
21
  layout="centered"
22
  )
23
 
 
24
  st.title("Pneumonia Detection (Chest X-ray) – Clinical Decision Support")
25
  st.caption(
26
+ "Upload one or more Chest X-ray DICOM files (.dcm). Adjust the decision threshold and click Submit. "
27
+ "This tool is for decision support only and does not replace clinical judgment."
 
28
  )
29
 
30
+ # -----------------------------
31
+ # Paths / Model Loading
32
+ # -----------------------------
33
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
34
+ MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
 
35
 
36
+ # Optional: store a version tag manually in a json file in repo root if you want
37
+ VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json")
38
 
39
  @st.cache_resource
40
  def load_model():
41
+ if not os.path.exists(MODEL_PATH):
42
+ raise FileNotFoundError(f"model.keras not found at: {MODEL_PATH}")
43
+
44
  try:
45
+ m = keras.models.load_model(MODEL_PATH)
46
  except Exception:
47
  keras.config.enable_unsafe_deserialization()
48
+ m = keras.models.load_model(MODEL_PATH, safe_mode=False)
49
+ return m
50
 
51
  model = load_model()
52
 
53
+ # read model input details
54
+ input_shape = model.input_shape # (None, H, W, C)
55
  img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
56
+ exp_ch = int(input_shape[-1]) if input_shape and input_shape[-1] else 1
57
+
58
+ # -----------------------------
59
+ # Utilities
60
+ # -----------------------------
61
+ def get_model_version():
62
+ if os.path.exists(VERSION_PATH):
63
+ try:
64
+ with open(VERSION_PATH, "r") as f:
65
+ return json.load(f).get("version", "unknown")
66
+ except Exception:
67
+ return "unknown"
68
+ return "v1"
69
+
70
+ MODEL_VERSION = get_model_version()
71
+
72
+ def read_dicom(uploaded_file) -> np.ndarray:
73
+ data = uploaded_file.read()
74
+ dcm = pydicom.dcmread(io.BytesIO(data))
75
+ img = dcm.pixel_array.astype(np.float32)
76
+
77
+ # Normalize to 0..1
78
+ img_min = float(np.min(img))
79
+ img_max = float(np.max(img))
80
+ img = (img - img_min) / (img_max - img_min + 1e-8)
81
+
82
+ return img
83
+
84
+ def preprocess(img_2d: np.ndarray) -> np.ndarray:
85
+ # (H,W) -> (1,H,W,C) float32 0..1
86
+ x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32) # (H,W,1)
87
+ x = tf.image.resize(x, (img_size, img_size))
88
+ x = tf.clip_by_value(x, 0.0, 1.0)
89
+ x = x.numpy()
90
+
91
+ if exp_ch == 3 and x.shape[-1] == 1:
92
+ x = np.repeat(x, 3, axis=-1)
93
+ elif exp_ch == 1 and x.shape[-1] == 3:
94
+ x = x[..., :1]
95
+
96
+ x = np.expand_dims(x, axis=0)
97
+ return x.astype(np.float32)
98
+
99
+ def predict_prob(x: np.ndarray) -> float:
100
+ pred = model.predict(x, verbose=0)
101
+ if isinstance(pred, (list, tuple)):
102
+ prob = float(np.ravel(pred[-1])[0])
103
+ else:
104
+ prob = float(np.ravel(pred)[0])
105
+ return max(0.0, min(1.0, prob))
106
+
107
+ def confidence_bucket(prob: float) -> str:
108
+ # Clinical-friendly interpretation (you can adjust the bands)
109
+ if prob < 0.30:
110
+ return "Low likelihood (< 0.30)"
111
+ elif prob <= 0.60:
112
+ return "Borderline suspicion (0.30 – 0.60)"
113
+ else:
114
+ return "High likelihood (> 0.60)"
115
+
116
+ # -----------------------------
117
+ # Grad-CAM (ResNet-style) helper
118
+ # -----------------------------
119
+ def find_last_conv_layer(m: keras.Model) -> str:
120
+ # picks the last Conv2D layer name
121
+ for layer in reversed(m.layers):
122
+ if isinstance(layer, keras.layers.Conv2D):
123
+ return layer.name
124
+ # If model is nested and last conv is inside base model:
125
+ for layer in reversed(m.layers):
126
+ if isinstance(layer, keras.Model):
127
+ for sub in reversed(layer.layers):
128
+ if isinstance(sub, keras.layers.Conv2D):
129
+ return sub.name
130
+ raise ValueError("Could not find a Conv2D layer for Grad-CAM.")
131
+
132
+ @st.cache_resource
133
+ def get_gradcam_model(m: keras.Model):
134
+ last_conv = find_last_conv_layer(m)
135
+ conv_layer = m.get_layer(last_conv)
136
+ grad_model = keras.Model([m.inputs], [conv_layer.output, m.output])
137
+ return grad_model, last_conv
138
+
139
+ def make_gradcam_heatmap(x_input: np.ndarray) -> np.ndarray:
140
+ grad_model, _ = get_gradcam_model(model)
141
+
142
+ x_tensor = tf.convert_to_tensor(x_input, dtype=tf.float32)
143
+ with tf.GradientTape() as tape:
144
+ conv_out, preds = grad_model(x_tensor)
145
+
146
+ if isinstance(preds, (list, tuple)):
147
+ preds = preds[-1]
148
+
149
+ # binary prob is preds[:,0]
150
+ score = preds[:, 0]
151
 
152
+ grads = tape.gradient(score, conv_out)
153
+ pooled = tf.reduce_mean(grads, axis=(0, 1, 2))
154
+ conv_out = conv_out[0]
155
+
156
+ heatmap = conv_out @ pooled[..., tf.newaxis]
157
+ heatmap = tf.squeeze(heatmap)
158
+
159
+ heatmap = tf.maximum(heatmap, 0)
160
+ denom = tf.reduce_max(heatmap) + 1e-8
161
+ heatmap = heatmap / denom
162
+ return heatmap.numpy()
163
+
164
+ def overlay_heatmap_on_image(img_2d: np.ndarray, heatmap: np.ndarray):
165
+ # Resize heatmap to img_size
166
+ heat = tf.image.resize(heatmap[..., None], (img_size, img_size)).numpy().squeeze()
167
+
168
+ fig = plt.figure(figsize=(5, 5))
169
+ plt.imshow(img_2d, cmap="gray")
170
+ plt.imshow(heat, cmap="jet", alpha=0.35)
171
+ plt.axis("off")
172
+ plt.tight_layout()
173
+ return fig
174
+
175
+ # -----------------------------
176
+ # PDF generator
177
+ # -----------------------------
178
+ def build_pdf_report(df: pd.DataFrame, threshold: float) -> bytes:
179
+ pdf = FPDF()
180
+ pdf.add_page()
181
+ pdf.set_font("Arial", size=12)
182
+
183
+ pdf.multi_cell(0, 8, f"Pneumonia Detection Report")
184
+ pdf.ln(1)
185
+ pdf.set_font("Arial", size=10)
186
+ pdf.multi_cell(0, 6, f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
187
+ pdf.multi_cell(0, 6, f"Model version: {MODEL_VERSION}")
188
+ pdf.multi_cell(0, 6, f"Decision threshold used: {threshold:.2f}")
189
+ pdf.ln(2)
190
+
191
+ # Table header
192
+ pdf.set_font("Arial", "B", 9)
193
+ headers = ["file_name", "probability", "prediction", "confidence_band"]
194
+ col_w = [70, 25, 35, 55]
195
+ for h, w in zip(headers, col_w):
196
+ pdf.cell(w, 7, h, border=1)
197
+ pdf.ln()
198
+
199
+ # Rows
200
+ pdf.set_font("Arial", size=9)
201
+ for _, r in df.iterrows():
202
+ pdf.cell(col_w[0], 7, str(r["file_name"])[:40], border=1)
203
+ pdf.cell(col_w[1], 7, f'{float(r["probability"]):.4f}', border=1)
204
+ pdf.cell(col_w[2], 7, str(r["prediction"])[:18], border=1)
205
+ pdf.cell(col_w[3], 7, str(r["confidence_band"])[:30], border=1)
206
+ pdf.ln()
207
+
208
+ return pdf.output(dest="S").encode("latin-1")
209
+
210
+ # -----------------------------
211
+ # UI
212
+ # -----------------------------
213
  st.subheader("Model Parameters")
214
 
215
  threshold = st.slider(
216
  "Decision Threshold",
217
  min_value=0.01,
218
  max_value=0.99,
219
+ value=0.37, # your ResNet best-thr default
220
  step=0.01,
221
+ help="If predicted probability ≥ threshold → Pneumonia, else → Not Pneumonia."
222
  )
223
 
224
+ show_gradcam = st.checkbox("Show Grad-CAM heatmap (explainability)", value=True)
225
+
 
226
  st.subheader("Upload Chest X-ray DICOM Files")
227
 
228
  uploaded_files = st.file_uploader(
 
238
  clear = st.button("Clear", use_container_width=True)
239
 
240
  if clear:
241
+ st.rerun()
242
 
243
+ st.subheader("Prediction Results")
 
 
 
 
 
 
244
 
245
+ if submit:
246
+ if not uploaded_files:
247
+ st.warning("Please upload at least one DICOM file before submitting.")
248
+ else:
249
+ rows = []
250
+ with st.spinner("Running inference..."):
251
+ for f in uploaded_files:
252
+ try:
253
+ img = read_dicom(f)
254
+ x = preprocess(img)
255
+ prob = predict_prob(x)
256
+ pred_label = "Pneumonia" if prob >= threshold else "Not Pneumonia"
257
+ band = confidence_bucket(prob)
258
 
259
+ rows.append({
260
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
261
+ "model_version": MODEL_VERSION,
262
+ "file_name": f.name,
263
+ "probability": prob,
264
+ "prediction": pred_label,
265
+ "confidence_band": band
266
+ })
267
 
268
+ except Exception as e:
269
+ rows.append({
270
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
271
+ "model_version": MODEL_VERSION,
272
+ "file_name": f.name,
273
+ "probability": np.nan,
274
+ "prediction": "Error",
275
+ "confidence_band": str(e)
276
+ })
277
 
278
+ df = pd.DataFrame(rows)
 
279
 
280
+ # Sentence-style outputs
281
+ for _, r in df.iterrows():
282
+ if r["prediction"] == "Error":
283
+ st.error(
284
+ f"For the uploaded file '{r['file_name']}', the system could not generate a prediction. "
285
+ f"Reason: {r['confidence_band']}."
286
+ )
287
+ continue
288
 
289
+ prob_pct = float(r["probability"]) * 100.0
290
+ st.write(
291
+ f"For the uploaded file '{r['file_name']}', the model estimates a pneumonia probability of "
292
+ f"{prob_pct:.2f}%. This falls under '{r['confidence_band']}'. "
293
+ f"Based on the selected decision threshold of {threshold:.2f}, the predicted outcome is "
294
+ f"'{r['prediction']}'."
295
+ )
296
 
297
+ if show_gradcam:
298
+ try:
299
+ # Use original image for display; heatmap computed from resized input
300
+ heatmap = make_gradcam_heatmap(preprocess(read_dicom(next(ff for ff in uploaded_files if ff.name == r["file_name"]))))
301
 
302
+ # We need original image again (Streamlit upload read pointer consumed; re-read by caching bytes)
303
+ # Workaround: store bytes during first loop is better; for simplicity, skip re-read failure.
304
+ except Exception:
305
+ pass
306
 
307
+ # Show Grad-CAM images in a robust way (re-read bytes by caching)
308
+ if show_gradcam:
309
+ st.markdown("### Grad-CAM Heatmaps")
310
+ for f in uploaded_files:
 
 
311
  try:
312
+ # read again safely (need cached bytes)
313
+ data = f.getvalue()
314
+ dcm = pydicom.dcmread(io.BytesIO(data))
315
+ img = dcm.pixel_array.astype(np.float32)
316
+ img = (img - img.min()) / (img.max() - img.min() + 1e-8)
317
 
318
+ x = preprocess(img)
319
+ heatmap = make_gradcam_heatmap(x)
320
+ fig = overlay_heatmap_on_image(tf.image.resize(img[..., None], (img_size, img_size)).numpy().squeeze(), heatmap)
321
+ st.write(f"Heatmap for: {f.name}")
322
+ st.pyplot(fig)
323
+ except Exception as e:
324
+ st.warning(f"Could not generate Grad-CAM for {f.name}. Reason: {e}")
325
 
326
+ # Downloads
327
+ st.markdown("### Downloads")
328
+ csv_bytes = df.to_csv(index=False).encode("utf-8")
329
+ st.download_button(
330
+ "Download CSV",
331
+ data=csv_bytes,
332
+ file_name="predictions.csv",
333
+ mime="text/csv",
334
+ use_container_width=True
335
+ )
336
 
337
+ pdf_bytes = build_pdf_report(df[df["prediction"] != "Error"], threshold)
338
+ st.download_button(
339
+ "Download PDF Report",
340
+ data=pdf_bytes,
341
+ file_name="pneumonia_report.pdf",
342
+ mime="application/pdf",
343
+ use_container_width=True
344
+ )
345
 
346
  st.divider()
347
  st.caption(
348
+ "Clinical note: This application is designed for decision support only. Final diagnosis and treatment decisions "
349
+ "must be made by qualified healthcare professionals."
350
  )