# streamlit_app.py
import io
import os

import numpy as np
from PIL import Image
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib
import matplotlib.cm as cm

# -------- CONFIG --------
MODEL_FILENAME = "Model2_exact_serialized.keras"  # model file expected in app folder
IMG_SIZE = (224, 224)
THRESHOLD = 0.62
ENABLE_GRADCAM = True
# ------------------------

st.set_page_config(page_title="Pneumonia Detection (CheXNet)", layout="centered")
st.title("Pneumonia detection (CheXNet)")
st.write("Upload a chest X-ray (DICOM or PNG/JPG). The app predicts probability of pneumonia.")

# ------- utilities -------
# DICOM Part 10 files carry the magic bytes "DICM" right after a 128-byte preamble.
_DICOM_MAGIC_OFFSET = 128
_DICOM_MAGIC = b"DICM"


def dicom_to_image_array(dicom_bytes):
    """Decode DICOM bytes into a single-channel uint8 image.

    Multi-frame data keeps only the first frame. Colour data (SamplesPerPixel > 1)
    is collapsed to grey by averaging its channels. Greyscale data gets the VOI LUT /
    windowing applied when possible, and MONOCHROME1 is inverted so that higher
    values are brighter. The result is min-max scaled to 0..255.

    Args:
        dicom_bytes: raw bytes of the DICOM file.

    Returns:
        2-D ``np.uint8`` array of shape (rows, cols).

    Raises:
        RuntimeError: if the pixel data cannot be decoded.
    """
    ds = pydicom.dcmread(io.BytesIO(dicom_bytes), force=True)
    try:
        arr = ds.pixel_array
    except Exception as e:
        raise RuntimeError(f"Could not decode DICOM pixel data: {e}") from e

    samples_per_pixel = int(getattr(ds, "SamplesPerPixel", 1) or 1)
    if samples_per_pixel > 1:
        # Colour: (rows, cols, samples) or (frames, rows, cols, samples).
        # Indexing [0] on the 3-D case would keep a single pixel row, so only the
        # 4-D (multi-frame) case drops to the first frame.
        if arr.ndim == 4:
            arr = arr[0]
        arr = arr.astype(np.float32).mean(axis=-1)
    else:
        # Greyscale: (rows, cols) or (frames, rows, cols); keep the first frame.
        if arr.ndim == 3:
            arr = arr[0]
        try:
            arr = apply_voi_lut(arr, ds)
        except Exception:
            # Best effort: windowing metadata is optional and may be malformed.
            pass

    arr = arr.astype(np.float32)
    if str(getattr(ds, "PhotometricInterpretation", "")).upper() == "MONOCHROME1":
        # MONOCHROME1 stores inverted intensities (low value = bright).
        arr = np.max(arr) - arr

    lo, hi = arr.min(), arr.max()
    arr = (arr - lo) / (hi - lo) if hi > lo else arr - lo
    return (arr * 255.0).clip(0, 255).astype(np.uint8)


def _looks_like_dicom(raw):
    """Return True when ``raw`` appears to be a DICOM file.

    Checks the Part 10 "DICM" magic first. Otherwise it falls back to a forced
    header-only parse (for preamble-less files) and accepts it when a non-zero
    ``Rows`` element is present.
    """
    if raw[_DICOM_MAGIC_OFFSET:_DICOM_MAGIC_OFFSET + len(_DICOM_MAGIC)] == _DICOM_MAGIC:
        return True
    try:
        ds = pydicom.dcmread(io.BytesIO(raw), stop_before_pixels=True, force=True)
        return bool(getattr(ds, "Rows", None))
    except Exception:
        # force=True parsing of arbitrary bytes can fail in many ways: not DICOM.
        return False


def _gray_pil_to_rgb_array(pil_gray):
    """Resize a mode-"L" PIL image to IMG_SIZE and replicate it into 3 uint8 channels."""
    gray = np.array(pil_gray.resize(IMG_SIZE))
    return np.stack([gray] * 3, axis=-1).astype(np.uint8)


def to_rgb_uint8_from_upload(uploaded_file):
    """Return RGB uint8 (H,W,3) array resized to IMG_SIZE.

    DICOM input is tried first. If it is not DICOM (or DICOM decoding fails), the
    bytes are opened with PIL and converted to greyscale. Either way the three
    channels are identical.

    Raises:
        RuntimeError: if ``uploaded_file`` is None or no decoder can read it. When
            the file looked like DICOM, the message carries the DICOM decode error.
    """
    if uploaded_file is None:
        raise RuntimeError("No file")
    # getvalue() returns the whole buffer regardless of the current read position,
    # so an upload object that was already read earlier still yields its bytes.
    if hasattr(uploaded_file, "getvalue"):
        raw = uploaded_file.getvalue()
    else:
        raw = uploaded_file.read()

    dicom_error = None
    if _looks_like_dicom(raw):
        try:
            return _gray_pil_to_rgb_array(Image.fromarray(dicom_to_image_array(raw)).convert("L"))
        except Exception as e:
            # Remember the real reason, but still give PIL a chance below.
            dicom_error = e

    # fallback normal image
    try:
        pil = Image.open(io.BytesIO(raw)).convert("L")
    except Exception as e:
        if dicom_error is not None:
            raise RuntimeError(
                f"File looks like DICOM but could not be decoded: {dicom_error}"
            ) from dicom_error
        raise RuntimeError("Unsupported file format. Upload a DICOM or PNG/JPG.") from e
    return _gray_pil_to_rgb_array(pil)


# -------- model load (cached) --------
def _resolve_model_path(model_path):
    """Return ``model_path`` if it exists as given, else the same name next to this script.

    Falls back to ``model_path`` unchanged when neither location exists, so the
    caller's error message still names the path it asked for.
    """
    if os.path.isabs(model_path) or os.path.exists(model_path):
        return model_path
    script = globals().get("__file__")
    if not script:
        return model_path
    candidate = os.path.join(os.path.dirname(os.path.abspath(script)), model_path)
    return candidate if os.path.exists(candidate) else model_path


@st.cache_resource
def load_predict_model(model_path):
    """Load the Keras model (uncompiled); Streamlit caches the result across reruns.

    Relative paths are looked up in the current working directory first, then in
    the directory containing this script.

    Raises:
        FileNotFoundError: if the model file cannot be found in either place.
    """
    resolved = _resolve_model_path(model_path)
    if not os.path.exists(resolved):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    return load_model(resolved, compile=False)


# Grad-CAM utilities
def _output_rank(layer):
    """Return the rank of ``layer``'s output tensor, or None when it cannot be determined."""
    try:
        return len(layer.output.shape)
    except (AttributeError, TypeError, ValueError, RuntimeError):
        # Multi-output, never-called or shared layer, or unknown rank.
        return None


def find_last_conv_layer(m):
    """Return the name of the layer whose activations Grad-CAM should weight.

    Preference order:
    1. The last layer with a 4-D output whose name contains "conv".
    2. Otherwise the last layer with any 4-D output.
    3. Otherwise the third-from-last layer.

    The rank is read from ``layer.output.shape``. ``layer.output_shape`` was removed
    in Keras 3, so relying on it made this always fall through to the last resort.
    """
    spatial = [layer for layer in reversed(m.layers) if _output_rank(layer) == 4]
    for layer in spatial:
        if "conv" in layer.name:
            return layer.name
    if spatial:
        return spatial[0].name
    return m.layers[-min(3, len(m.layers))].name


def _get_colormap(name):
    """Look up a Matplotlib colormap by name.

    ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the
    ``matplotlib.colormaps`` registry (3.5+) is preferred when available.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    return cm.get_cmap(name)


def make_gradcam_image(rgb_uint8, model, last_conv_name=None, alpha=0.4, cmap_name="jet"):
    """Return a PIL RGB image of the Grad-CAM heatmap blended over ``rgb_uint8``.

    Args:
        rgb_uint8: (H, W, 3) uint8 image, the same array that is fed to the model.
        model: Keras model; column 0 of its output is used as the target score.
        last_conv_name: layer whose activations are weighted; auto-detected if None.
        alpha: weight of the heatmap in the blend (0 keeps only the original image).
        cmap_name: Matplotlib colormap name used to colour the heatmap.

    Raises:
        RuntimeError: if no gradient flows from the output back to the chosen layer.
    """
    img = rgb_uint8.astype(np.float32)
    height, width = img.shape[:2]
    if last_conv_name is None:
        last_conv_name = find_last_conv_layer(model)

    # model.inputs is already a list. Wrapping it again ([model.inputs]) builds a
    # nested input structure that Keras 3 warns about or rejects.
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_name).output, model.output],
    )
    # astype() copies, so preprocessing that works in place cannot alter img.
    x = densenet_preprocess(np.expand_dims(img.astype(np.float32), axis=0))

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(x, training=False)
        # Slice inside the tape so the op is recorded; otherwise the gradient is None.
        score = preds[:, 0]
    grads = tape.gradient(score, conv_outputs)
    if grads is None:
        raise RuntimeError(
            f"No gradient flows from the model output to layer '{last_conv_name}'."
        )

    # Channel weights = spatially averaged gradients; CAM = ReLU(weighted sum of maps).
    weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.reduce_sum(weights[:, tf.newaxis, tf.newaxis, :] * conv_outputs, axis=-1)
    cam = np.maximum(cam.numpy()[0], 0)
    peak = cam.max()
    cam = cam / (peak if peak != 0 else 1e-8)

    cam_img = Image.fromarray(np.uint8(cam * 255)).resize((width, height), resample=Image.BILINEAR)
    cam_arr = np.asarray(cam_img, dtype=np.float32) / 255.0
    heatmap = _get_colormap(cmap_name)(cam_arr)[:, :, :3]  # drop the alpha channel
    heat_pil = Image.fromarray(np.uint8(heatmap * 255)).convert("RGBA")
    base_pil = Image.fromarray(np.uint8(img)).convert("RGBA")
    blended = Image.blend(base_pil, heat_pil, alpha=alpha)
    return blended.convert("RGB")
# -------- UI elements --------
col1, col2 = st.columns([1, 1])
with col1:
    uploaded = st.file_uploader(
        "Upload DICOM or PNG/JPG",
        type=["dcm", "dicom", "png", "jpg", "jpeg", "tif", "tiff"],
    )
with col2:
    thresh = st.number_input(
        "Decision threshold (probability)",
        min_value=0.0,
        max_value=1.0,
        value=float(THRESHOLD),
        step=0.01,
    )

if uploaded is not None:
    try:
        rgb = to_rgb_uint8_from_upload(uploaded)
    except Exception as e:
        st.error(f"Failed to process file: {e}")
        st.stop()

    # use_column_width is deprecated in recent Streamlit releases; omitting it keeps
    # the image at its natural width, which is what False requested.
    st.image(rgb, caption="Input (resized)")

    # Load the model (cached across reruns). Show a missing or corrupt model file as
    # a UI error instead of an uncaught traceback.
    try:
        model = load_predict_model(MODEL_FILENAME)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()

    # predict
    # astype() copies, so preprocessing that works in place leaves `rgb` untouched
    # for the Grad-CAM overlay below.
    x_pre = densenet_preprocess(np.expand_dims(rgb.astype(np.float32), axis=0))
    # NOTE(review): assumes the model emits the pneumonia probability in output column 0
    # (e.g. a single sigmoid unit) — confirm against the training code.
    prob = float(model.predict(x_pre, verbose=0).ravel()[0])
    pred = int(prob >= thresh)
    st.markdown(f"**Pneumonia probability:** `{prob:.4f}`")
    st.markdown(f"**Predicted class (binary):** `{pred}` — **{'Pneumonia' if pred==1 else 'Normal'}**")

    if ENABLE_GRADCAM:
        try:
            cam = make_gradcam_image(rgb, model)
            st.image(cam, caption="Grad-CAM overlay")
        except Exception as e:
            # Grad-CAM is auxiliary: report the failure without hiding the prediction.
            st.warning(f"Grad-CAM failed: {e}")
else:
    st.info("Upload a DICOM or PNG/JPG image to run inference.")