Pushpak21 commited on
Commit
5586b97
·
verified ·
1 Parent(s): 6c0bc56

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +132 -120
streamlit_app.py CHANGED
@@ -1,140 +1,152 @@
 
1
  # streamlit_app.py
2
- import streamlit as st
 
3
  import numpy as np
 
 
4
  import tensorflow as tf
5
- from tensorflow.keras.models import load_model, Model
6
- import cv2
7
  import pydicom
8
- from PIL import Image
 
9
 
10
- st.set_page_config(page_title="Pneumonia Detection", layout="wide")
 
 
 
 
 
11
 
 
12
 
13
- # ---------------------------
14
- # Load Model
15
- # ---------------------------
16
- @st.cache_resource
17
- def load_my_model():
18
- model = load_model("model/best_model.keras", compile=False)
19
- return model
20
-
21
-
22
- model = load_my_model()
23
-
24
-
25
- # ---------------------------
26
- # Preprocess image
27
- # ---------------------------
28
- def load_image(file):
29
- """Loads PNG/JPG/DICOM and returns a grayscale 224x224 normalized array."""
30
- filename = file.name.lower()
31
-
32
- if filename.endswith(".dcm"):
33
- dcm = pydicom.dcmread(file)
34
- img = dcm.pixel_array.astype(np.float32)
35
- img = cv2.resize(img, (224, 224))
36
- img = img / np.max(img)
37
- return img
38
-
39
- # PNG / JPG / JPEG
40
- img = Image.open(file).convert("L")
41
- img = img.resize((224, 224))
42
- img = np.array(img).astype(np.float32) / 255.0
43
- return img
44
-
45
-
46
- # ---------------------------
47
- # Robust Grad-CAM
48
- # ---------------------------
49
- def grad_cam(model, img_array, layer_name=None, eps=1e-8):
50
- """
51
- img_array: (224,224) normalized grayscale → will be expanded to (1,224,224,1)
52
- """
53
-
54
- # Expand dims for model
55
- x = np.expand_dims(img_array, axis=0) # (1,224,224)
56
- x = np.expand_dims(x, axis=-1) # (1,224,224,1)
57
- x = tf.convert_to_tensor(x, dtype=tf.float32)
58
-
59
- # Auto-detect last conv layer if not provided
60
- if layer_name is None:
61
- for layer in reversed(model.layers):
62
- if hasattr(layer, "output_shape") and len(layer.output_shape) == 4:
63
- layer_name = layer.name
64
- break
65
-
66
- last_conv = model.get_layer(layer_name)
67
- grad_model = Model([model.inputs], [last_conv.output, model.output])
68
 
69
- with tf.GradientTape() as tape:
70
- conv_outputs, predictions = grad_model(x)
71
- class_idx = int(tf.argmax(predictions[0]))
72
- loss = predictions[:, class_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  grads = tape.gradient(loss, conv_outputs)
75
- if grads is None:
76
- raise RuntimeError("Gradients are None. Model may not be connected properly.")
77
-
78
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()
79
-
80
- conv_outputs = conv_outputs[0].numpy() # (H,W,channels)
81
-
82
- for i in range(len(pooled_grads)):
83
- conv_outputs[:, :, i] *= pooled_grads[i]
84
-
85
- cam = np.mean(conv_outputs, axis=-1)
86
  cam = np.maximum(cam, 0)
87
-
88
- cam -= cam.min()
89
- if cam.max() > eps:
90
- cam /= cam.max() + eps
91
-
92
- cam = cv2.resize(cam, (224, 224))
93
- return cam
94
-
95
-
96
- # ---------------------------
97
- # UI
98
- # ---------------------------
99
- st.title("🫁 Pneumonia Detection")
100
- st.write("Upload a chest scan image (DICOM or PNG/JPG).")
101
-
102
-
103
- file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "dcm"])
104
-
105
- if file:
 
106
  try:
107
- img = load_image(file)
 
 
 
108
 
109
- st.subheader("Input Image")
110
- st.image(img, caption="Uploaded Image", use_container_width=True, clamp=True)
111
 
112
- # Model input format (1,224,224,1)
113
- x = np.expand_dims(img, axis=(0, -1))
114
 
115
- pred = model.predict(x)[0][0]
116
- label = "Pneumonia" if pred >= 0.5 else "Normal"
 
 
117
 
118
- st.subheader("Prediction")
119
- st.write(f"**Class:** {label}")
120
- st.write(f"**Probability:** {float(pred):.4f}")
121
 
122
- # Grad-CAM
123
- st.subheader("Grad-CAM Heatmap")
124
  try:
125
- cam = grad_cam(model, img)
126
- heatmap = cv2.applyColorMap(
127
- np.uint8(255 * cam), cv2.COLORMAP_JET
128
- )
129
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
130
-
131
- overlay = 0.4 * heatmap + 0.6 * np.stack([img*255]*3, axis=-1)
132
- overlay = overlay.astype(np.uint8)
133
-
134
- st.image(overlay, caption="Grad-CAM", use_container_width=True)
135
-
136
  except Exception as e:
137
- st.error(f"Grad-CAM failed: {e}")
138
 
139
- except Exception as e:
140
- st.error(f"Error loading image: {e}")
 
1
+ %%writefile pneumonia_app/streamlit_app.py
2
  # streamlit_app.py
3
+ import io
4
+ import os
5
  import numpy as np
6
+ from PIL import Image
7
+ import streamlit as st
8
  import tensorflow as tf
9
+ from tensorflow.keras.models import load_model
10
+ from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
11
  import pydicom
12
+ from pydicom.pixel_data_handlers.util import apply_voi_lut
13
+ import matplotlib.cm as cm
14
 
15
+ # -------- CONFIG --------
16
+ MODEL_FILENAME = "Model2_exact_serialized.keras" # model file expected in app folder
17
+ IMG_SIZE = (224, 224)
18
+ THRESHOLD = 0.62
19
+ ENABLE_GRADCAM = True
20
+ # ------------------------
21
 
22
+ st.set_page_config(page_title="Pneumonia Detection (CheXNet)", layout="centered")
23
 
24
+ st.title("Pneumonia detection (CheXNet)")
25
+ st.write("Upload a chest X-ray (DICOM or PNG/JPG). The app predicts probability of pneumonia.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # ------- utilities -------
28
+ def dicom_to_image_array(dicom_bytes):
29
+ ds = pydicom.dcmread(io.BytesIO(dicom_bytes), force=True)
30
+ try:
31
+ arr = ds.pixel_array
32
+ except Exception as e:
33
+ raise RuntimeError(f"Could not decode DICOM pixel data: {e}")
34
+ if arr.ndim == 3:
35
+ arr = arr[0]
36
+ try:
37
+ arr = apply_voi_lut(arr, ds)
38
+ except Exception:
39
+ pass
40
+ arr = arr.astype(np.float32)
41
+ if getattr(ds, "PhotometricInterpretation", "").upper() == "MONOCHROME1":
42
+ arr = np.max(arr) - arr
43
+ mn, mx = arr.min(), arr.max()
44
+ if mx > mn:
45
+ arr = (arr - mn) / (mx - mn)
46
+ else:
47
+ arr = arr - mn
48
+ arr = (arr * 255.0).clip(0,255).astype(np.uint8)
49
+ return arr
50
+
51
+ def to_rgb_uint8_from_upload(uploaded_file):
52
+ """Return RGB uint8 (H,W,3) array resized to IMG_SIZE."""
53
+ if uploaded_file is None:
54
+ raise RuntimeError("No file")
55
+ raw = uploaded_file.read()
56
+ # try DICOM
57
+ try:
58
+ ds = pydicom.dcmread(io.BytesIO(raw), stop_before_pixels=True, force=True)
59
+ if hasattr(ds, "PixelData") or getattr(ds, "Rows", None):
60
+ arr = dicom_to_image_array(raw)
61
+ if arr.ndim == 2:
62
+ arr = np.stack([arr]*3, axis=-1)
63
+ pil = Image.fromarray(arr).convert("RGB").resize(IMG_SIZE)
64
+ return np.array(pil)
65
+ except Exception:
66
+ pass
67
+ # fallback normal image
68
+ try:
69
+ pil = Image.open(io.BytesIO(raw)).convert("L").resize(IMG_SIZE)
70
+ arr = np.stack([np.array(pil)]*3, axis=-1)
71
+ return arr.astype(np.uint8)
72
+ except Exception as e:
73
+ raise RuntimeError("Unsupported file format. Upload a DICOM or PNG/JPG.") from e
74
 
75
+ # -------- model load (cached) --------
76
+ @st.cache_resource
77
+ def load_predict_model(model_path):
78
+ if not os.path.exists(model_path):
79
+ raise FileNotFoundError(f"Model file not found: {model_path}")
80
+ m = load_model(model_path, compile=False)
81
+ return m
82
+
83
+ # Grad-CAM utilities
84
+ def find_last_conv_layer(m):
85
+ for layer in reversed(m.layers):
86
+ out_shape = getattr(layer, "output_shape", None)
87
+ if out_shape and len(out_shape) == 4 and "conv" in layer.name:
88
+ return layer.name
89
+ return m.layers[-3].name
90
+
91
+ def make_gradcam_image(rgb_uint8, model, last_conv_name=None, alpha=0.4, cmap_name="jet"):
92
+ img = rgb_uint8.astype(np.float32)
93
+ if last_conv_name is None:
94
+ last_conv_name = find_last_conv_layer(model)
95
+ grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(last_conv_name).output, model.output])
96
+ x = densenet_preprocess(np.expand_dims(img.astype(np.float32), axis=0))
97
+ with tf.GradientTape() as tape:
98
+ conv_outputs, preds = grad_model(x)
99
+ loss = preds[:, 0]
100
  grads = tape.gradient(loss, conv_outputs)
101
+ weights = tf.reduce_mean(grads, axis=(1,2))
102
+ cam = tf.reduce_sum(tf.multiply(weights[:, tf.newaxis, tf.newaxis, :], conv_outputs), axis=-1)
103
+ cam = tf.squeeze(cam).numpy()
 
 
 
 
 
 
 
 
104
  cam = np.maximum(cam, 0)
105
+ cam_max = cam.max() if cam.max() != 0 else 1e-8
106
+ cam = cam / cam_max
107
+ cam_img = Image.fromarray(np.uint8(cam * 255)).resize((img.shape[1], img.shape[0]), resample=Image.BILINEAR)
108
+ cam_arr = np.array(cam_img).astype(np.float32)/255.0
109
+ colormap = cm.get_cmap(cmap_name)
110
+ heatmap = colormap(cam_arr)[:, :, :3]
111
+ heat_uint8 = np.uint8(heatmap * 255)
112
+ heat_pil = Image.fromarray(heat_uint8).convert("RGBA").resize((img.shape[1], img.shape[0]))
113
+ base_pil = Image.fromarray(np.uint8(img)).convert("RGBA")
114
+ blended = Image.blend(base_pil, heat_pil, alpha=alpha)
115
+ return blended.convert("RGB")
116
+
117
+ # -------- UI elements --------
118
+ col1, col2 = st.columns([1,1])
119
+ with col1:
120
+ uploaded = st.file_uploader("Upload DICOM or PNG/JPG", type=["dcm","png","jpg","jpeg","tif","tiff"])
121
+ with col2:
122
+ thresh = st.number_input("Decision threshold (probability)", min_value=0.0, max_value=1.0, value=float(THRESHOLD), step=0.01)
123
+
124
+ if uploaded is not None:
125
  try:
126
+ rgb = to_rgb_uint8_from_upload(uploaded)
127
+ except Exception as e:
128
+ st.error(f"Failed to process file: {e}")
129
+ st.stop()
130
 
131
+ st.image(rgb, caption="Input (resized)", use_container_width=True)
 
132
 
133
+ # load model (cached)
134
+ model = load_predict_model(MODEL_FILENAME)
135
 
136
+ # predict
137
+ x_pre = densenet_preprocess(np.expand_dims(rgb.astype(np.float32), axis=0))
138
+ prob = float(model.predict(x_pre, verbose=0).ravel()[0])
139
+ pred = int(prob >= thresh)
140
 
141
+ st.markdown(f"**Pneumonia probability:** `{prob:.4f}`")
142
+ st.markdown(f"**Predicted class (binary):** `{pred}` — **{'Pneumonia' if pred==1 else 'Normal'}**")
 
143
 
144
+ if ENABLE_GRADCAM:
 
145
  try:
146
+ cam = make_gradcam_image(rgb, model)
147
+ st.image(cam, caption="Grad-CAM overlay", use_container_width=True)
 
 
 
 
 
 
 
 
 
148
  except Exception as e:
149
+ st.warning(f"Grad-CAM failed: {e}")
150
 
151
+ else:
152
+ st.info("Upload a DICOM or PNG/JPG image to run inference.")