VJBharathkumar commited on
Commit
4582cbb
·
verified ·
1 Parent(s): c237f7a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +74 -29
src/streamlit_app.py CHANGED
@@ -9,10 +9,9 @@ import streamlit as st
9
  import tensorflow as tf
10
  from tensorflow import keras
11
  import pydicom
12
- import matplotlib.pyplot as plt
13
-
14
  from fpdf import FPDF
15
 
 
16
  # -----------------------------
17
  # Page config
18
  # -----------------------------
@@ -27,6 +26,7 @@ st.caption(
27
  "This tool is for decision support only and does not replace clinical judgment."
28
  )
29
 
 
30
  # -----------------------------
31
  # Paths / Model Loading
32
  # -----------------------------
@@ -34,6 +34,7 @@ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
34
  MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
35
  VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json") # optional
36
 
 
37
  @st.cache_resource
38
  def load_model():
39
  if not os.path.exists(MODEL_PATH):
@@ -47,6 +48,7 @@ def load_model():
47
  m = keras.models.load_model(MODEL_PATH, safe_mode=False)
48
  return m
49
 
 
50
  model = load_model()
51
 
52
  # model input details
@@ -54,6 +56,7 @@ input_shape = model.input_shape # (None, H, W, C)
54
  img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
55
  exp_ch = int(input_shape[-1]) if input_shape and input_shape[-1] else 1
56
 
 
57
  def get_model_version():
58
  if os.path.exists(VERSION_PATH):
59
  try:
@@ -63,8 +66,33 @@ def get_model_version():
63
  return "ResNet50_v1"
64
  return "ResNet50_v1"
65
 
 
66
  MODEL_VERSION = get_model_version()
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # -----------------------------
69
  # Confidence interpretation
70
  # -----------------------------
@@ -76,6 +104,7 @@ def interpret_confidence(prob: float) -> str:
76
  else:
77
  return "High likelihood (>60%)"
78
 
 
79
  # -----------------------------
80
  # DICOM + preprocessing
81
  # -----------------------------
@@ -89,6 +118,7 @@ def dicom_bytes_to_img(data: bytes) -> np.ndarray:
89
 
90
  return img
91
 
 
92
  def preprocess(img_2d: np.ndarray) -> np.ndarray:
93
  # (H,W) -> (1,img_size,img_size,C) float32 0..1
94
  x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32) # (H,W,1)
@@ -104,6 +134,7 @@ def preprocess(img_2d: np.ndarray) -> np.ndarray:
104
  x = np.expand_dims(x, axis=0) # (1,img_size,img_size,C)
105
  return x.astype(np.float32)
106
 
 
107
  def predict_prob(x: np.ndarray) -> float:
108
  pred = model.predict(x, verbose=0)
109
  if isinstance(pred, (list, tuple)):
@@ -113,6 +144,40 @@ def predict_prob(x: np.ndarray) -> float:
113
  return max(0.0, min(1.0, prob))
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # -----------------------------
118
  # UI
@@ -128,8 +193,6 @@ threshold = st.slider(
128
  help="If predicted probability is greater than or equal to the threshold, output is Pneumonia. Otherwise Not Pneumonia."
129
  )
130
 
131
- show_gradcam = st.checkbox("Show Grad-CAM heatmap (explainability)", value=True)
132
-
133
  st.subheader("Upload Chest X-ray DICOM Files")
134
  uploaded_files = st.file_uploader(
135
  "Select one or multiple DICOM files (.dcm)",
@@ -152,7 +215,7 @@ if submit:
152
  if not uploaded_files:
153
  st.warning("Please upload at least one DICOM file before submitting.")
154
  else:
155
- # cache bytes once (so we can read multiple times safely)
156
  file_bytes = {f.name: f.getvalue() for f in uploaded_files}
157
 
158
  rows = []
@@ -206,23 +269,6 @@ if submit:
206
  f"'{r['prediction']}'."
207
  )
208
 
209
- # Grad-CAM section
210
- if show_gradcam:
211
- st.markdown("### Grad-CAM Heatmaps")
212
- for name, data in file_bytes.items():
213
- try:
214
- img = dicom_bytes_to_img(data)
215
- x = preprocess(img)
216
- heatmap = make_gradcam_heatmap(x)
217
- fig = overlay_heatmap(img, heatmap)
218
- st.write(f"Heatmap for: {name}")
219
- st.pyplot(fig)
220
- except Exception as e:
221
- st.warning(
222
- f"Could not generate Grad-CAM for {name}. "
223
- f"Reason: {safe_text(str(e), max_len=160)}"
224
- )
225
-
226
  # Downloads
227
  st.markdown("### Downloads")
228
 
@@ -239,15 +285,14 @@ if submit:
239
  if len(df_ok) > 0:
240
  pdf_bytes = build_pdf_report(df_ok, threshold, MODEL_VERSION)
241
  st.download_button(
242
- "Download PDF Report",
243
- data=pdf_bytes,
244
- file_name="pneumonia_report.pdf",
245
- mime="application/pdf",
246
- use_container_width=True
247
- )
248
  else:
249
  st.info("PDF report is available only when at least one file is successfully processed.")
250
-
251
 
252
  st.divider()
253
  st.caption(
 
9
  import tensorflow as tf
10
  from tensorflow import keras
11
  import pydicom
 
 
12
  from fpdf import FPDF
13
 
14
+
15
  # -----------------------------
16
  # Page config
17
  # -----------------------------
 
26
  "This tool is for decision support only and does not replace clinical judgment."
27
  )
28
 
29
+
30
  # -----------------------------
31
  # Paths / Model Loading
32
  # -----------------------------
 
34
  MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
35
  VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json") # optional
36
 
37
+
38
  @st.cache_resource
39
  def load_model():
40
  if not os.path.exists(MODEL_PATH):
 
48
  m = keras.models.load_model(MODEL_PATH, safe_mode=False)
49
  return m
50
 
51
+
52
  model = load_model()
53
 
54
  # model input details
 
56
  img_size = int(input_shape[1]) if input_shape and input_shape[1] else 256
57
  exp_ch = int(input_shape[-1]) if input_shape and input_shape[-1] else 1
58
 
59
+
60
  def get_model_version():
61
  if os.path.exists(VERSION_PATH):
62
  try:
 
66
  return "ResNet50_v1"
67
  return "ResNet50_v1"
68
 
69
+
70
  MODEL_VERSION = get_model_version()
71
 
72
+
73
+ # -----------------------------
74
+ # Text safety (PDF + error messages)
75
+ # -----------------------------
76
+ def safe_text(s: str, max_len: int = 200) -> str:
77
+ if s is None:
78
+ return ""
79
+ s = str(s)
80
+
81
+ # replace common unicode characters that can break FPDF
82
+ s = s.replace("–", "-").replace("—", "-").replace("’", "'").replace("“", '"').replace("”", '"')
83
+
84
+ # add break opportunities for long tokens (UUIDs / filenames)
85
+ s = s.replace("-", "- ").replace("_", "_ ").replace("/", "/ ")
86
+
87
+ # keep latin-1 safe for default FPDF fonts
88
+ s = s.encode("latin-1", "replace").decode("latin-1")
89
+
90
+ # trim long strings
91
+ if len(s) > max_len:
92
+ s = s[:max_len] + "..."
93
+ return s
94
+
95
+
96
  # -----------------------------
97
  # Confidence interpretation
98
  # -----------------------------
 
104
  else:
105
  return "High likelihood (>60%)"
106
 
107
+
108
  # -----------------------------
109
  # DICOM + preprocessing
110
  # -----------------------------
 
118
 
119
  return img
120
 
121
+
122
  def preprocess(img_2d: np.ndarray) -> np.ndarray:
123
  # (H,W) -> (1,img_size,img_size,C) float32 0..1
124
  x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32) # (H,W,1)
 
134
  x = np.expand_dims(x, axis=0) # (1,img_size,img_size,C)
135
  return x.astype(np.float32)
136
 
137
+
138
  def predict_prob(x: np.ndarray) -> float:
139
  pred = model.predict(x, verbose=0)
140
  if isinstance(pred, (list, tuple)):
 
144
  return max(0.0, min(1.0, prob))
145
 
146
 
147
+ # -----------------------------
148
+ # PDF report
149
+ # -----------------------------
150
+ def build_pdf_report(df_ok: pd.DataFrame, threshold: float, model_version: str) -> bytes:
151
+ pdf = FPDF()
152
+ pdf.set_auto_page_break(auto=True, margin=12)
153
+ pdf.add_page()
154
+
155
+ pdf.set_font("Helvetica", size=12)
156
+ w = pdf.w - pdf.l_margin - pdf.r_margin # effective width
157
+
158
+ pdf.cell(0, 8, safe_text("Pneumonia Detection Report"), ln=True)
159
+ pdf.set_font("Helvetica", size=10)
160
+ pdf.cell(0, 6, safe_text(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"), ln=True)
161
+ pdf.cell(0, 6, safe_text(f"Model Version: {model_version}"), ln=True)
162
+ pdf.cell(0, 6, safe_text(f"Decision Threshold: {threshold:.2f}"), ln=True)
163
+ pdf.ln(4)
164
+
165
+ for _, row in df_ok.iterrows():
166
+ lines = [
167
+ f"File: {row['file_name']}",
168
+ f"Probability: {float(row['probability']) * 100:.2f}%",
169
+ f"Confidence: {row['confidence_level']}",
170
+ f"Prediction: {row['prediction']}",
171
+ ]
172
+ for line in lines:
173
+ pdf.multi_cell(w, 6, safe_text(line))
174
+ pdf.ln(2)
175
+
176
+ out = pdf.output(dest="S")
177
+ if isinstance(out, str):
178
+ out = out.encode("latin-1", "ignore")
179
+ return out
180
+
181
 
182
  # -----------------------------
183
  # UI
 
193
  help="If predicted probability is greater than or equal to the threshold, output is Pneumonia. Otherwise Not Pneumonia."
194
  )
195
 
 
 
196
  st.subheader("Upload Chest X-ray DICOM Files")
197
  uploaded_files = st.file_uploader(
198
  "Select one or multiple DICOM files (.dcm)",
 
215
  if not uploaded_files:
216
  st.warning("Please upload at least one DICOM file before submitting.")
217
  else:
218
+ # cache bytes once (so we can read safely)
219
  file_bytes = {f.name: f.getvalue() for f in uploaded_files}
220
 
221
  rows = []
 
269
  f"'{r['prediction']}'."
270
  )
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  # Downloads
273
  st.markdown("### Downloads")
274
 
 
285
  if len(df_ok) > 0:
286
  pdf_bytes = build_pdf_report(df_ok, threshold, MODEL_VERSION)
287
  st.download_button(
288
+ "Download PDF Report",
289
+ data=pdf_bytes,
290
+ file_name="pneumonia_report.pdf",
291
+ mime="application/pdf",
292
+ use_container_width=True
293
+ )
294
  else:
295
  st.info("PDF report is available only when at least one file is successfully processed.")
 
296
 
297
  st.divider()
298
  st.caption(