# Pneumonia_Detection / streamlit_app.py
# Uploaded by Pushpak21 using huggingface_hub (commit 2443ce1, verified).
# streamlit_app.py
import io
import os
import numpy as np
from PIL import Image
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.cm as cm
# -------- CONFIG --------
MODEL_FILENAME = "Model2_exact_serialized.keras" # model file expected in app folder
# (width, height) passed to PIL's resize() for every uploaded image.
IMG_SIZE = (224, 224)
# Initial value of the decision-threshold input in the UI; a prediction is
# "Pneumonia" when the model probability is >= the chosen threshold.
# NOTE(review): how 0.62 was chosen is not documented here — TODO confirm.
THRESHOLD = 0.62
# When True, a Grad-CAM overlay is rendered after each prediction.
ENABLE_GRADCAM = True
# ------------------------
# Kept as the first Streamlit call of the script (older Streamlit versions
# reject set_page_config after any other st.* command).
st.set_page_config(page_title="Pneumonia Detection (CheXNet)", layout="centered")
st.title("Pneumonia detection (CheXNet)")
st.write("Upload a chest X-ray (DICOM or PNG/JPG). The app predicts probability of pneumonia.")
# ------- utilities -------
def dicom_to_image_array(dicom_bytes):
    """Decode DICOM bytes into a single 2-D uint8 grayscale image.

    Processing steps:
      1. Read the dataset (``force=True`` so files without the 128-byte
         preamble are accepted) and decode its pixel data.
      2. Keep only the first frame of multi-frame data.
      3. Colour data (SamplesPerPixel > 1) is collapsed to one channel by
         averaging; grayscale data gets its VOI LUT / windowing applied when
         pydicom can do so.
      4. MONOCHROME1 images are inverted so that higher values render
         brighter, matching MONOCHROME2.
      5. The result is min-max scaled to 0..255.

    Args:
        dicom_bytes: raw bytes of a DICOM file.

    Returns:
        np.ndarray of shape (H, W) and dtype uint8.

    Raises:
        RuntimeError: if the pixel data cannot be decoded (for example a
            compressed transfer syntax with no installed decoder).
    """
    ds = pydicom.dcmread(io.BytesIO(dicom_bytes), force=True)
    try:
        arr = ds.pixel_array
    except Exception as e:
        raise RuntimeError(f"Could not decode DICOM pixel data: {e}") from e

    samples_per_pixel = int(getattr(ds, "SamplesPerPixel", 1) or 1)
    photometric = str(getattr(ds, "PhotometricInterpretation", "")).strip().upper()

    if samples_per_pixel > 1:
        # Colour pixel data is (H, W, C), or (frames, H, W, C) when multi-frame.
        # Indexing [0] on an (H, W, C) array would select a pixel row, so only
        # drop a leading axis when it really is a frame axis.
        if arr.ndim == 4:
            arr = arr[0]
        arr = arr.astype(np.float32).mean(axis=-1)
    else:
        # Grayscale multi-frame data is (frames, H, W): keep the first frame.
        if arr.ndim == 3:
            arr = arr[0]
        try:
            arr = apply_voi_lut(arr, ds)
        except Exception:
            # Best effort: keep the stored values if the VOI LUT can't be applied.
            pass

    arr = np.asarray(arr, dtype=np.float32)
    if photometric == "MONOCHROME1":
        # MONOCHROME1 stores the minimum value as white; invert to MONOCHROME2.
        arr = arr.max() - arr

    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        arr = (arr - lo) / (hi - lo)
    else:
        # Constant image: avoid a divide-by-zero and return all zeros.
        arr = np.zeros_like(arr)
    return (arr * 255.0).clip(0, 255).astype(np.uint8)
def _looks_like_dicom(raw):
    """Return True if the bytes ``raw`` appear to be a DICOM file.

    Checks the "DICM" magic after the 128-byte preamble first, then falls back
    to a forced header-only parse (for preamble-less files) and treats the data
    as DICOM when a Rows element is present.
    """
    if raw[128:132] == b"DICM":
        return True
    try:
        ds = pydicom.dcmread(io.BytesIO(raw), stop_before_pixels=True, force=True)
        # PixelData is never loaded with stop_before_pixels=True, so Rows is
        # the only useful signal from a header-only read.
        return bool(getattr(ds, "Rows", None))
    except Exception:
        return False


def _pil_to_gray_uint8(pil):
    """Convert a Pillow image to 8-bit grayscale ("L") without clipping.

    Pillow's convert("L") clips high-bit-depth modes (16-bit "I;16*", 32-bit
    "I", float "F") to 0..255 instead of rescaling them. That saturates
    typical 16-bit radiographs, so those modes are min-max scaled here.
    """
    if pil.mode in ("I;16", "I;16B", "I;16L", "I;16N", "I", "F"):
        arr = np.asarray(pil, dtype=np.float32)
        lo, hi = float(arr.min()), float(arr.max())
        if hi > lo:
            arr = (arr - lo) / (hi - lo)
        else:
            arr = np.zeros_like(arr)
        return Image.fromarray((arr * 255.0).round().astype(np.uint8))
    return pil.convert("L")


def to_rgb_uint8_from_upload(uploaded_file):
    """Return RGB uint8 (H,W,3) array resized to IMG_SIZE.

    DICOM uploads are decoded with ``dicom_to_image_array``. Anything else, or
    DICOM that fails to decode, is opened with Pillow and converted to
    grayscale. A grayscale result is replicated into three identical channels.

    Args:
        uploaded_file: file-like object. When it has ``getvalue()`` (e.g.
            Streamlit's UploadedFile / io.BytesIO), the whole buffer is used
            regardless of the current read position.

    Raises:
        RuntimeError: if no file is given, the file is empty, or it cannot be
            decoded as DICOM or as a Pillow-readable image.
    """
    if uploaded_file is None:
        raise RuntimeError("No file")
    getvalue = getattr(uploaded_file, "getvalue", None)
    # getvalue() ignores the stream position, so a buffer that was already
    # read once (e.g. on a rerun) still yields its full contents.
    raw = getvalue() if callable(getvalue) else uploaded_file.read()
    if not raw:
        raise RuntimeError("Uploaded file is empty.")

    dicom_error = None
    if _looks_like_dicom(raw):
        try:
            arr = dicom_to_image_array(raw)
            if arr.ndim == 2:
                arr = np.stack([arr] * 3, axis=-1)
            pil = Image.fromarray(arr).convert("RGB").resize(IMG_SIZE)
            return np.array(pil)
        except Exception as exc:
            # Remember why DICOM decoding failed so it can be reported if the
            # generic image fallback below also fails.
            dicom_error = exc

    # fallback: ordinary raster image (PNG/JPG/TIFF, ...)
    try:
        pil = _pil_to_gray_uint8(Image.open(io.BytesIO(raw))).resize(IMG_SIZE)
        arr = np.stack([np.array(pil)] * 3, axis=-1)
        return arr.astype(np.uint8)
    except Exception as e:
        if dicom_error is not None:
            raise RuntimeError(
                f"File looks like DICOM but could not be decoded: {dicom_error}"
            ) from dicom_error
        raise RuntimeError("Unsupported file format. Upload a DICOM or PNG/JPG.") from e
# -------- model load (cached) --------
@st.cache_resource
def load_predict_model(model_path):
    """Load the Keras model at ``model_path`` for inference.

    The result is cached by ``st.cache_resource`` (keyed on ``model_path``), so
    script reruns reuse the loaded model instead of reading the file again.

    A relative ``model_path`` is looked for in two places, in order:
      1. the current working directory;
      2. the directory containing this script.
    This lets the app find its model no matter where ``streamlit run`` was
    launched from.

    Raises:
        FileNotFoundError: if the model file exists in none of the candidate
            locations.
    """
    candidates = [model_path]
    script_path = globals().get("__file__")
    if not os.path.isabs(model_path) and script_path:
        app_dir = os.path.dirname(os.path.abspath(script_path))
        candidates.append(os.path.join(app_dir, model_path))

    for candidate in candidates:
        if os.path.exists(candidate):
            # compile=False: only inference is done, so optimizer/loss state
            # does not need to be restored.
            return load_model(candidate, compile=False)

    raise FileNotFoundError(
        f"Model file not found: {model_path} (searched: {', '.join(candidates)})"
    )
# Grad-CAM utilities
def find_last_conv_layer(m):
    """Return the name of the Grad-CAM target layer of model ``m``.

    Scans ``m.layers`` from the output end. It returns the first layer whose
    name contains "conv" and whose output has rank 4 (assumed to be
    batch, height, width, channels).

    The output rank is taken from ``layer.output_shape`` (Keras 2) or, when
    that attribute is missing, from ``layer.output.shape`` (Keras 3 layers do
    not expose ``output_shape``).

    Falls back to ``m.layers[-3].name`` when no matching layer is found.

    Raises:
        ValueError: if nothing matches and the model has fewer than 3 layers,
            so there is no fallback layer.
    """

    def _output_rank(layer):
        # Rank of the layer's output tensor, or None when it can't be determined.
        shape = getattr(layer, "output_shape", None)
        if shape is None:
            try:
                shape = layer.output.shape
            except (AttributeError, ValueError, RuntimeError):
                return None
        if isinstance(shape, list):
            # A list of shapes means several outputs: not a single feature map.
            return None
        try:
            return len(shape)
        except (TypeError, ValueError):  # e.g. a shape of unknown rank
            return None

    for layer in reversed(m.layers):
        if "conv" in layer.name and _output_rank(layer) == 4:
            return layer.name

    if len(m.layers) < 3:
        raise ValueError("Could not find a 4-D convolutional layer for Grad-CAM.")
    return m.layers[-3].name
def make_gradcam_image(rgb_uint8, model, last_conv_name=None, alpha=0.4, cmap_name="jet"):
    """Blend a Grad-CAM heatmap for output unit 0 of ``model`` over an image.

    Args:
        rgb_uint8: (H, W, 3) uint8 image, the same array used for prediction.
            It is not modified.
        model: Keras model. Gradients of ``output[:, 0]`` are taken.
        last_conv_name: layer whose activations are weighted. When None it is
            chosen with ``find_last_conv_layer``.
        alpha: heatmap opacity passed to ``Image.blend``
            (0 = image only, 1 = heatmap only).
        cmap_name: Matplotlib colormap name used to colour the heatmap.

    Returns:
        PIL.Image.Image in RGB mode with the same size as ``rgb_uint8``.

    Raises:
        RuntimeError: if no gradient flows from the output to the chosen layer.
    """
    img = rgb_uint8.astype(np.float32)
    if last_conv_name is None:
        last_conv_name = find_last_conv_layer(model)

    # Sub-model returning both the target layer's activations and the prediction.
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_name).output, model.output]
    )
    # Pass a copy: preprocess_input may scale its NumPy input in place, and
    # `img` is reused below as the unmodified base image of the overlay.
    x = densenet_preprocess(np.expand_dims(img.copy(), axis=0))

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(x)
        loss = preds[:, 0]
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise RuntimeError(
            f"No gradient from the model output to layer '{last_conv_name}'."
        )

    # Grad-CAM: weight each channel by its spatially averaged gradient, sum
    # over channels, keep only positive evidence, then normalise to 0..1.
    weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.reduce_sum(weights[:, tf.newaxis, tf.newaxis, :] * conv_outputs, axis=-1)
    cam = cam[0].numpy()  # drop only the batch axis (squeeze could also drop a size-1 spatial axis)
    cam = np.maximum(cam, 0)
    cam_max = cam.max()
    cam = cam / (cam_max if cam_max != 0 else 1e-8)

    height, width = img.shape[0], img.shape[1]
    cam_img = Image.fromarray(np.uint8(cam * 255)).resize(
        (width, height), resample=Image.BILINEAR
    )
    cam_arr = np.asarray(cam_img, dtype=np.float32) / 255.0

    try:
        # Matplotlib >= 3.5 colormap registry; cm.get_cmap was removed in 3.9.
        import matplotlib
        colormap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        colormap = cm.get_cmap(cmap_name)

    heat_uint8 = np.uint8(colormap(cam_arr)[:, :, :3] * 255)
    heat_pil = Image.fromarray(heat_uint8).convert("RGBA")
    base_pil = Image.fromarray(np.uint8(img)).convert("RGBA")
    blended = Image.blend(base_pil, heat_pil, alpha=alpha)
    return blended.convert("RGB")
# -------- UI elements --------
col1, col2 = st.columns([1, 1])
with col1:
    uploaded = st.file_uploader(
        "Upload DICOM or PNG/JPG",
        type=["dcm", "png", "jpg", "jpeg", "tif", "tiff"],
    )
with col2:
    thresh = st.number_input(
        "Decision threshold (probability)",
        min_value=0.0,
        max_value=1.0,
        value=float(THRESHOLD),
        step=0.01,
    )

if uploaded is not None:
    try:
        rgb = to_rgb_uint8_from_upload(uploaded)
    except Exception as e:
        st.error(f"Failed to process file: {e}")
        st.stop()
    # No width argument: use_column_width is deprecated, and its False value
    # only asked for the image's natural size, which a 224 px image already
    # gets by default.
    st.image(rgb, caption="Input (resized)")

    # Load the model (cached). Report a missing or unreadable model file in
    # the UI rather than letting the script die with a raw traceback.
    try:
        model = load_predict_model(MODEL_FILENAME)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()

    # predict. astype() makes a copy, so preprocessing cannot alter `rgb`,
    # which Grad-CAM reuses below.
    x_pre = densenet_preprocess(np.expand_dims(rgb.astype(np.float32), axis=0))
    prob = float(model.predict(x_pre, verbose=0).ravel()[0])
    pred = int(prob >= thresh)
    st.markdown(f"**Pneumonia probability:** `{prob:.4f}`")
    st.markdown(f"**Predicted class (binary):** `{pred}` — **{'Pneumonia' if pred==1 else 'Normal'}**")

    if ENABLE_GRADCAM:
        try:
            cam = make_gradcam_image(rgb, model)
            st.image(cam, caption="Grad-CAM overlay")
        except Exception as e:
            # Grad-CAM is optional: show a warning but keep the prediction.
            st.warning(f"Grad-CAM failed: {e}")
else:
    st.info("Upload a DICOM or PNG/JPG image to run inference.")