# Spaces:
# Sleeping
# Sleeping
# streamlit_app.py
import io
import os

import matplotlib
import matplotlib.cm as cm
import numpy as np
import pydicom
import streamlit as st
import tensorflow as tf
from PIL import Image
from pydicom.pixel_data_handlers.util import apply_voi_lut
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
from tensorflow.keras.models import load_model
# -------- CONFIG --------
# Keras model file passed to load_predict_model by the UI.
MODEL_FILENAME = "Model2_exact_serialized.keras"  # model file expected in app folder

# Size handed to PIL's resize() for every upload, i.e. (width, height).
IMG_SIZE = (224, 224)

# Initial value of the "Decision threshold" number input in the UI.
THRESHOLD = 0.62

# When True, a Grad-CAM overlay is rendered below the prediction.
ENABLE_GRADCAM = True
# ------------------------
# Page chrome; set_page_config is issued before any other Streamlit call.
st.set_page_config(
    layout="centered",
    page_title="Pneumonia Detection (CheXNet)",
)
st.title("Pneumonia detection (CheXNet)")
st.write(
    "Upload a chest X-ray (DICOM or PNG/JPG). "
    "The app predicts probability of pneumonia."
)
# ------- utilities -------
def dicom_to_image_array(dicom_bytes):
    """Decode DICOM bytes into a single 2-D uint8 grayscale image.

    Steps:
    1. Decode the pixels with pydicom.
    2. Keep only the first frame of multi-frame data.
    3. Average colour samples into one channel.
    4. Apply the VOI LUT / windowing (grayscale only, best effort).
    5. Invert MONOCHROME1 data.
    6. Min-max scale to 0..255.

    Args:
        dicom_bytes: Raw file bytes. They are parsed with ``force=True`` so
            files lacking the standard DICOM preamble are still attempted.

    Returns:
        np.ndarray of dtype uint8 with shape (rows, cols).

    Raises:
        RuntimeError: If the pixel data cannot be decoded, or if the decoded
            array does not reduce to a 2-D image.
    """
    ds = pydicom.dcmread(io.BytesIO(dicom_bytes), force=True)
    try:
        arr = ds.pixel_array
    except Exception as e:
        raise RuntimeError(f"Could not decode DICOM pixel data: {e}") from e

    samples_per_pixel = int(getattr(ds, "SamplesPerPixel", 1) or 1)
    is_color = samples_per_pixel > 1 and arr.ndim in (3, 4)
    if is_color:
        # Colour arrays end in a samples axis: (rows, cols, S) for one frame,
        # or (frames, rows, cols, S) for several. Indexing arr[0] on a single
        # colour frame would select one pixel row. Instead, drop the frame
        # axis (if present) and average the samples into one channel.
        if arr.ndim == 4:
            arr = arr[0]
        arr = arr.astype(np.float32).mean(axis=-1)
    elif arr.ndim == 3:
        # Grayscale multi-frame (frames, rows, cols): keep the first frame.
        arr = arr[0]
    if arr.ndim != 2:
        raise RuntimeError(f"Unsupported DICOM pixel array shape: {arr.shape}")

    if not is_color:
        try:
            arr = apply_voi_lut(arr, ds)
        except Exception:
            # Best effort: keep the raw values when no usable VOI LUT or
            # window is available.
            pass

    arr = arr.astype(np.float32)
    # MONOCHROME1 stores low values as white; flip so bright means high.
    if getattr(ds, "PhotometricInterpretation", "").upper() == "MONOCHROME1":
        arr = np.max(arr) - arr

    mn, mx = arr.min(), arr.max()
    if mx > mn:
        arr = (arr - mn) / (mx - mn)
    else:
        # Constant image: avoid dividing by zero; the result is all zeros.
        arr = arr - mn
    return (arr * 255.0).clip(0, 255).astype(np.uint8)
def to_rgb_uint8_from_upload(uploaded_file):
    """Return RGB uint8 (H,W,3) array resized to IMG_SIZE.

    DICOM decoding is tried first; anything else is opened with PIL. Both
    paths yield a 3-channel image. The PIL path repeats the gray plane, and a
    grayscale DICOM is stacked the same way.

    Args:
        uploaded_file: File-like object exposing ``read()``, e.g. a
            Streamlit ``UploadedFile``. ``getvalue()`` is used when available
            so the result does not depend on the current stream position.

    Raises:
        RuntimeError: If no file is given; if a file with DICOM image tags
            cannot be decoded by either path (the DICOM error is reported);
            or if the bytes are neither DICOM nor a PIL-readable image.
    """
    if uploaded_file is None:
        raise RuntimeError("No file")
    getvalue = getattr(uploaded_file, "getvalue", None)
    raw = getvalue() if callable(getvalue) else uploaded_file.read()

    # Try DICOM first. Only the header is parsed here: with
    # stop_before_pixels=True, PixelData is never loaded, so the Rows tag is
    # what marks an image-bearing DICOM.
    dicom_error = None
    try:
        header = pydicom.dcmread(io.BytesIO(raw), stop_before_pixels=True, force=True)
        looks_like_dicom = bool(getattr(header, "Rows", None))
    except Exception:
        looks_like_dicom = False

    if looks_like_dicom:
        try:
            arr = dicom_to_image_array(raw)
            if arr.ndim == 2:
                arr = np.stack([arr] * 3, axis=-1)
            pil = Image.fromarray(arr).convert("RGB").resize(IMG_SIZE)
            return np.array(pil)
        except Exception as e:
            # Remember the cause, but still give the PIL fallback a chance.
            dicom_error = e

    # Fallback: a normal raster image.
    try:
        pil = Image.open(io.BytesIO(raw)).convert("L").resize(IMG_SIZE)
    except Exception as e:
        if dicom_error is not None:
            raise RuntimeError(f"Could not decode DICOM file: {dicom_error}") from dicom_error
        raise RuntimeError("Unsupported file format. Upload a DICOM or PNG/JPG.") from e
    return np.stack([np.array(pil)] * 3, axis=-1).astype(np.uint8)
# -------- model load (cached) --------
@st.cache_resource
def load_predict_model(model_path):
    """Load the Keras model stored at *model_path*, without compiling it.

    The function is wrapped in ``st.cache_resource``, so each distinct
    ``model_path`` is loaded once and the same model object is reused.
    Streamlit re-executes this script on every widget interaction, such as
    changing the threshold.

    A relative path is tried against the working directory first. If it is
    not found there, it is tried against this script's directory, so a model
    shipped next to the app is found even when launched from elsewhere.

    Raises:
        FileNotFoundError: If the model file exists in neither location.
    """
    if not os.path.isabs(model_path) and not os.path.exists(model_path):
        # __file__ is looked up defensively in case the runner does not set it.
        script_file = globals().get("__file__")
        if script_file:
            candidate = os.path.join(os.path.dirname(os.path.abspath(script_file)), model_path)
            if os.path.exists(candidate):
                model_path = candidate
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    return load_model(model_path, compile=False)
# Grad-CAM utilities
def _layer_output_rank(layer):
    """Return the rank of *layer*'s output, or None if it cannot be determined."""
    shape = getattr(layer, "output_shape", None)
    if shape is None:
        # Keras 3 layers no longer expose ``output_shape``; read the shape
        # from the symbolic output tensor instead.
        try:
            shape = layer.output.shape
        except (AttributeError, ValueError, RuntimeError):
            return None
    if not shape:
        return None
    try:
        return len(shape)
    except (TypeError, ValueError):  # e.g. a shape of unknown rank
        return None


def find_last_conv_layer(m):
    """Return the name of the layer whose feature map Grad-CAM should use.

    Layers are scanned from the output backwards. The rules are:
    1. Return the first layer with a 4-D output whose name contains "conv".
    2. Otherwise, return the last 4-D layer of any name.
    3. Otherwise, fall back to ``m.layers[-3]``.

    Args:
        m: A Keras model (anything exposing a ``layers`` sequence of objects
            with a ``name``).

    Raises:
        ValueError: If no 4-D layer exists and the model has fewer than three
            layers.
    """
    layers = list(m.layers)
    spatial = [layer for layer in reversed(layers) if _layer_output_rank(layer) == 4]
    for layer in spatial:
        if "conv" in layer.name:
            return layer.name
    if spatial:
        return spatial[0].name
    if len(layers) >= 3:
        return layers[-3].name
    raise ValueError("Could not find a 4-D (spatial) layer for Grad-CAM in the model.")
def make_gradcam_image(rgb_uint8, model, last_conv_name=None, alpha=0.4, cmap_name="jet"):
    """Overlay a Grad-CAM heatmap for output column 0 of *model* on the image.

    Args:
        rgb_uint8: (H, W, 3) uint8 image, the same array used for prediction.
            It is run through the DenseNet preprocessing here.
        model: Keras model. ``preds[:, 0]`` is the score being explained.
        last_conv_name: Layer whose feature map is used. When None, it is
            auto-detected with ``find_last_conv_layer``.
        alpha: Blend weight of the heatmap (0 = image only, 1 = heatmap only).
        cmap_name: Name of the matplotlib colormap for the heatmap.

    Returns:
        A PIL RGB image of size (W, H).

    Raises:
        ValueError: If the chosen layer's output is not a 4-D
            (batch, H, W, C) feature map, or if no gradient flows from the
            model output to it.
    """
    img = rgb_uint8.astype(np.float32)
    height, width = img.shape[:2]
    if last_conv_name is None:
        last_conv_name = find_last_conv_layer(model)

    # model.inputs is already a list. Wrapping it again ([model.inputs]) nests
    # the input structure, which Keras 3 does not match against one array.
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_name).output, model.output]
    )
    # Copy before preprocessing: preprocess_input may rescale a float NumPy
    # array in place, and ``img`` is reused below as the overlay base.
    x = densenet_preprocess(np.expand_dims(img.copy(), axis=0))

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(x)
        loss = preds[:, 0]
    if len(conv_outputs.shape) != 4:
        raise ValueError(
            f"Layer {last_conv_name!r} has output shape {tuple(conv_outputs.shape)}; "
            "Grad-CAM needs a 4-D (batch, H, W, C) feature map."
        )
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(
            f"No gradient flows from the model output to layer {last_conv_name!r}."
        )

    # Channel weights are the spatially averaged gradients.
    weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.reduce_sum(weights[:, tf.newaxis, tf.newaxis, :] * conv_outputs, axis=-1)
    # Take the single batch item explicitly. tf.squeeze would also collapse
    # a 1x1 map to a scalar, which PIL cannot turn into an image.
    cam = np.nan_to_num(cam[0].numpy())
    cam = np.maximum(cam, 0)
    peak = cam.max()
    cam = cam / (peak if peak > 0 else 1e-8)

    cam_img = Image.fromarray(np.uint8(cam * 255)).resize((width, height), resample=Image.BILINEAR)
    cam_arr = np.asarray(cam_img, dtype=np.float32) / 255.0

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; the colormaps
    # registry (3.5+) is the supported lookup. Older versions fall back.
    try:
        colormap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        colormap = cm.get_cmap(cmap_name)
    heat_uint8 = np.uint8(colormap(cam_arr)[:, :, :3] * 255)

    # cam_img is already (width, height), so the heatmap needs no extra resize.
    heat_pil = Image.fromarray(heat_uint8).convert("RGBA")
    base_pil = Image.fromarray(np.uint8(img)).convert("RGBA")
    return Image.blend(base_pil, heat_pil, alpha=alpha).convert("RGB")
# -------- UI elements --------
col1, col2 = st.columns([1, 1])
with col1:
    uploaded = st.file_uploader(
        "Upload DICOM or PNG/JPG",
        type=["dcm", "png", "jpg", "jpeg", "tif", "tiff"],
    )
with col2:
    thresh = st.number_input(
        "Decision threshold (probability)",
        min_value=0.0,
        max_value=1.0,
        value=float(THRESHOLD),
        step=0.01,
    )

if uploaded is not None:
    try:
        rgb = to_rgb_uint8_from_upload(uploaded)
    except Exception as e:
        st.error(f"Failed to process file: {e}")
        st.stop()
    # The deprecated ``use_column_width`` flag is not passed. Streamlit's
    # default keeps an image at its natural size unless it is wider than the
    # column.
    st.image(rgb, caption="Input (resized)")

    # Load the model (cached). A missing or unreadable model file is
    # reported in the page instead of crashing the script with a traceback.
    try:
        model = load_predict_model(MODEL_FILENAME)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()

    # Predict, using the same DenseNet preprocessing as the Grad-CAM path.
    try:
        x_pre = densenet_preprocess(np.expand_dims(rgb.astype(np.float32), axis=0))
        prob = float(model.predict(x_pre, verbose=0).ravel()[0])
    except Exception as e:
        st.error(f"Prediction failed: {e}")
        st.stop()
    pred = int(prob >= thresh)
    st.markdown(f"**Pneumonia probability:** `{prob:.4f}`")
    st.markdown(f"**Predicted class (binary):** `{pred}` — **{'Pneumonia' if pred==1 else 'Normal'}**")

    if ENABLE_GRADCAM:
        try:
            cam = make_gradcam_image(rgb, model)
            st.image(cam, caption="Grad-CAM overlay")
        except Exception as e:
            st.warning(f"Grad-CAM failed: {e}")
else:
    st.info("Upload a DICOM or PNG/JPG image to run inference.")