Pushpak21 commited on
Commit
6c0bc56
·
verified ·
1 Parent(s): 7e27315

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +120 -131
streamlit_app.py CHANGED
@@ -1,151 +1,140 @@
1
  # streamlit_app.py
2
- import io
3
- import os
4
- import numpy as np
5
- from PIL import Image
6
  import streamlit as st
 
7
  import tensorflow as tf
8
- from tensorflow.keras.models import load_model
9
- from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
10
  import pydicom
11
- from pydicom.pixel_data_handlers.util import apply_voi_lut
12
- import matplotlib.cm as cm
13
-
14
- # -------- CONFIG --------
15
- MODEL_FILENAME = "Model2_exact_serialized.keras" # model file expected in app folder
16
- IMG_SIZE = (224, 224)
17
- THRESHOLD = 0.62
18
- ENABLE_GRADCAM = True
19
- # ------------------------
20
 
21
- st.set_page_config(page_title="Pneumonia Detection (CheXNet)", layout="centered")
22
 
23
- st.title("Pneumonia detection (CheXNet)")
24
- st.write("Upload a chest X-ray (DICOM or PNG/JPG). The app predicts probability of pneumonia.")
25
 
26
- # ------- utilities -------
27
- def dicom_to_image_array(dicom_bytes):
28
- ds = pydicom.dcmread(io.BytesIO(dicom_bytes), force=True)
29
- try:
30
- arr = ds.pixel_array
31
- except Exception as e:
32
- raise RuntimeError(f"Could not decode DICOM pixel data: {e}")
33
- if arr.ndim == 3:
34
- arr = arr[0]
35
- try:
36
- arr = apply_voi_lut(arr, ds)
37
- except Exception:
38
- pass
39
- arr = arr.astype(np.float32)
40
- if getattr(ds, "PhotometricInterpretation", "").upper() == "MONOCHROME1":
41
- arr = np.max(arr) - arr
42
- mn, mx = arr.min(), arr.max()
43
- if mx > mn:
44
- arr = (arr - mn) / (mx - mn)
45
- else:
46
- arr = arr - mn
47
- arr = (arr * 255.0).clip(0,255).astype(np.uint8)
48
- return arr
49
-
50
- def to_rgb_uint8_from_upload(uploaded_file):
51
- """Return RGB uint8 (H,W,3) array resized to IMG_SIZE."""
52
- if uploaded_file is None:
53
- raise RuntimeError("No file")
54
- raw = uploaded_file.read()
55
- # try DICOM
56
- try:
57
- ds = pydicom.dcmread(io.BytesIO(raw), stop_before_pixels=True, force=True)
58
- if hasattr(ds, "PixelData") or getattr(ds, "Rows", None):
59
- arr = dicom_to_image_array(raw)
60
- if arr.ndim == 2:
61
- arr = np.stack([arr]*3, axis=-1)
62
- pil = Image.fromarray(arr).convert("RGB").resize(IMG_SIZE)
63
- return np.array(pil)
64
- except Exception:
65
- pass
66
- # fallback normal image
67
- try:
68
- pil = Image.open(io.BytesIO(raw)).convert("L").resize(IMG_SIZE)
69
- arr = np.stack([np.array(pil)]*3, axis=-1)
70
- return arr.astype(np.uint8)
71
- except Exception as e:
72
- raise RuntimeError("Unsupported file format. Upload a DICOM or PNG/JPG.") from e
73
-
74
- # -------- model load (cached) --------
75
  @st.cache_resource
76
- def load_predict_model(model_path):
77
- if not os.path.exists(model_path):
78
- raise FileNotFoundError(f"Model file not found: {model_path}")
79
- m = load_model(model_path, compile=False)
80
- return m
81
-
82
- # Grad-CAM utilities
83
- def find_last_conv_layer(m):
84
- for layer in reversed(m.layers):
85
- out_shape = getattr(layer, "output_shape", None)
86
- if out_shape and len(out_shape) == 4 and "conv" in layer.name:
87
- return layer.name
88
- return m.layers[-3].name
89
-
90
- def make_gradcam_image(rgb_uint8, model, last_conv_name=None, alpha=0.4, cmap_name="jet"):
91
- img = rgb_uint8.astype(np.float32)
92
- if last_conv_name is None:
93
- last_conv_name = find_last_conv_layer(model)
94
- grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(last_conv_name).output, model.output])
95
- x = densenet_preprocess(np.expand_dims(img.astype(np.float32), axis=0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with tf.GradientTape() as tape:
97
- conv_outputs, preds = grad_model(x)
98
- loss = preds[:, 0]
 
 
99
  grads = tape.gradient(loss, conv_outputs)
100
- weights = tf.reduce_mean(grads, axis=(1,2))
101
- cam = tf.reduce_sum(tf.multiply(weights[:, tf.newaxis, tf.newaxis, :], conv_outputs), axis=-1)
102
- cam = tf.squeeze(cam).numpy()
 
 
 
 
 
 
 
 
103
  cam = np.maximum(cam, 0)
104
- cam_max = cam.max() if cam.max() != 0 else 1e-8
105
- cam = cam / cam_max
106
- cam_img = Image.fromarray(np.uint8(cam * 255)).resize((img.shape[1], img.shape[0]), resample=Image.BILINEAR)
107
- cam_arr = np.array(cam_img).astype(np.float32)/255.0
108
- colormap = cm.get_cmap(cmap_name)
109
- heatmap = colormap(cam_arr)[:, :, :3]
110
- heat_uint8 = np.uint8(heatmap * 255)
111
- heat_pil = Image.fromarray(heat_uint8).convert("RGBA").resize((img.shape[1], img.shape[0]))
112
- base_pil = Image.fromarray(np.uint8(img)).convert("RGBA")
113
- blended = Image.blend(base_pil, heat_pil, alpha=alpha)
114
- return blended.convert("RGB")
115
-
116
- # -------- UI elements --------
117
- col1, col2 = st.columns([1,1])
118
- with col1:
119
- uploaded = st.file_uploader("Upload DICOM or PNG/JPG", type=["dcm","png","jpg","jpeg","tif","tiff"])
120
- with col2:
121
- thresh = st.number_input("Decision threshold (probability)", min_value=0.0, max_value=1.0, value=float(THRESHOLD), step=0.01)
122
-
123
- if uploaded is not None:
124
  try:
125
- rgb = to_rgb_uint8_from_upload(uploaded)
126
- except Exception as e:
127
- st.error(f"Failed to process file: {e}")
128
- st.stop()
129
 
130
- st.image(rgb, caption="Input (resized)", use_container_width=True)
 
131
 
132
- # load model (cached)
133
- model = load_predict_model(MODEL_FILENAME)
134
 
135
- # predict
136
- x_pre = densenet_preprocess(np.expand_dims(rgb.astype(np.float32), axis=0))
137
- prob = float(model.predict(x_pre, verbose=0).ravel()[0])
138
- pred = int(prob >= thresh)
139
 
140
- st.markdown(f"**Pneumonia probability:** `{prob:.4f}`")
141
- st.markdown(f"**Predicted class (binary):** `{pred}` — **{'Pneumonia' if pred==1 else 'Normal'}**")
 
142
 
143
- if ENABLE_GRADCAM:
 
144
  try:
145
- cam = make_gradcam_image(rgb, model)
146
- st.image(cam, caption="Grad-CAM overlay", use_container_width=True)
 
 
 
 
 
 
 
 
 
147
  except Exception as e:
148
- st.warning(f"Grad-CAM failed: {e}")
149
 
150
- else:
151
- st.info("Upload a DICOM or PNG/JPG image to run inference.")
 
1
  # streamlit_app.py
 
 
 
 
2
  import streamlit as st
3
+ import numpy as np
4
  import tensorflow as tf
5
+ from tensorflow.keras.models import load_model, Model
6
+ import cv2
7
  import pydicom
8
+ from PIL import Image
 
 
 
 
 
 
 
 
9
 
10
+ st.set_page_config(page_title="Pneumonia Detection", layout="wide")
11
 
 
 
12
 
13
+ # ---------------------------
14
+ # Load Model
15
+ # ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @st.cache_resource
17
+ def load_my_model():
18
+ model = load_model("model/best_model.keras", compile=False)
19
+ return model
20
+
21
+
22
+ model = load_my_model()
23
+
24
+
25
+ # ---------------------------
26
+ # Preprocess image
27
+ # ---------------------------
28
+ def load_image(file):
29
+ """Loads PNG/JPG/DICOM and returns a grayscale 224x224 normalized array."""
30
+ filename = file.name.lower()
31
+
32
+ if filename.endswith(".dcm"):
33
+ dcm = pydicom.dcmread(file)
34
+ img = dcm.pixel_array.astype(np.float32)
35
+ img = cv2.resize(img, (224, 224))
36
+ img = img / np.max(img)
37
+ return img
38
+
39
+ # PNG / JPG / JPEG
40
+ img = Image.open(file).convert("L")
41
+ img = img.resize((224, 224))
42
+ img = np.array(img).astype(np.float32) / 255.0
43
+ return img
44
+
45
+
46
+ # ---------------------------
47
+ # Robust Grad-CAM
48
+ # ---------------------------
49
+ def grad_cam(model, img_array, layer_name=None, eps=1e-8):
50
+ """
51
+ img_array: (224,224) normalized grayscale → will be expanded to (1,224,224,1)
52
+ """
53
+
54
+ # Expand dims for model
55
+ x = np.expand_dims(img_array, axis=0) # (1,224,224)
56
+ x = np.expand_dims(x, axis=-1) # (1,224,224,1)
57
+ x = tf.convert_to_tensor(x, dtype=tf.float32)
58
+
59
+ # Auto-detect last conv layer if not provided
60
+ if layer_name is None:
61
+ for layer in reversed(model.layers):
62
+ if hasattr(layer, "output_shape") and len(layer.output_shape) == 4:
63
+ layer_name = layer.name
64
+ break
65
+
66
+ last_conv = model.get_layer(layer_name)
67
+ grad_model = Model([model.inputs], [last_conv.output, model.output])
68
+
69
  with tf.GradientTape() as tape:
70
+ conv_outputs, predictions = grad_model(x)
71
+ class_idx = int(tf.argmax(predictions[0]))
72
+ loss = predictions[:, class_idx]
73
+
74
  grads = tape.gradient(loss, conv_outputs)
75
+ if grads is None:
76
+ raise RuntimeError("Gradients are None. Model may not be connected properly.")
77
+
78
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()
79
+
80
+ conv_outputs = conv_outputs[0].numpy() # (H,W,channels)
81
+
82
+ for i in range(len(pooled_grads)):
83
+ conv_outputs[:, :, i] *= pooled_grads[i]
84
+
85
+ cam = np.mean(conv_outputs, axis=-1)
86
  cam = np.maximum(cam, 0)
87
+
88
+ cam -= cam.min()
89
+ if cam.max() > eps:
90
+ cam /= cam.max() + eps
91
+
92
+ cam = cv2.resize(cam, (224, 224))
93
+ return cam
94
+
95
+
96
+ # ---------------------------
97
+ # UI
98
+ # ---------------------------
99
+ st.title("🫁 Pneumonia Detection")
100
+ st.write("Upload a chest scan image (DICOM or PNG/JPG).")
101
+
102
+
103
+ file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "dcm"])
104
+
105
+ if file:
 
106
  try:
107
+ img = load_image(file)
 
 
 
108
 
109
+ st.subheader("Input Image")
110
+ st.image(img, caption="Uploaded Image", use_container_width=True, clamp=True)
111
 
112
+ # Model input format (1,224,224,1)
113
+ x = np.expand_dims(img, axis=(0, -1))
114
 
115
+ pred = model.predict(x)[0][0]
116
+ label = "Pneumonia" if pred >= 0.5 else "Normal"
 
 
117
 
118
+ st.subheader("Prediction")
119
+ st.write(f"**Class:** {label}")
120
+ st.write(f"**Probability:** {float(pred):.4f}")
121
 
122
+ # Grad-CAM
123
+ st.subheader("Grad-CAM Heatmap")
124
  try:
125
+ cam = grad_cam(model, img)
126
+ heatmap = cv2.applyColorMap(
127
+ np.uint8(255 * cam), cv2.COLORMAP_JET
128
+ )
129
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
130
+
131
+ overlay = 0.4 * heatmap + 0.6 * np.stack([img*255]*3, axis=-1)
132
+ overlay = overlay.astype(np.uint8)
133
+
134
+ st.image(overlay, caption="Grad-CAM", use_container_width=True)
135
+
136
  except Exception as e:
137
+ st.error(f"Grad-CAM failed: {e}")
138
 
139
+ except Exception as e:
140
+ st.error(f"Error loading image: {e}")