Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import json | |
| from datetime import datetime | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import pydicom | |
| from fpdf import FPDF | |
| # ----------------------------- | |
| # Page config | |
| # ----------------------------- | |
# Single source for the heading shown in both the browser tab and the page.
APP_TITLE = "Pneumonia Detection (Chest X-ray) - Clinical Decision Support"

st.set_page_config(page_title=APP_TITLE, layout="centered")
st.title(APP_TITLE)

# Short usage instructions plus the decision-support disclaimer.
st.caption(
    "Upload one or more Chest X-ray DICOM files (.dcm). Adjust the decision threshold and click Submit. "
    "This tool is for decision support only and does not replace clinical judgment."
)
| # ----------------------------- | |
| # Paths / Model Loading | |
| # ----------------------------- | |
# The repository root is the parent of the directory containing this script.
_APP_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(os.path.join(_APP_DIR, os.pardir))

# Model artefacts are looked up directly under the repository root.
MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")
VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json")  # optional
@st.cache_resource
def load_model():
    """Load the Keras model stored at ``MODEL_PATH``.

    The result is cached with ``st.cache_resource`` so the model is read from
    disk once and then reused, instead of being reloaded every time Streamlit
    re-executes the script in response to a widget interaction.

    Returns:
        The deserialized Keras model.

    Raises:
        FileNotFoundError: If no file exists at ``MODEL_PATH``.
        Exception: Whatever the second (unsafe-mode) load attempt raises.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"model.keras not found at: {MODEL_PATH}")
    try:
        m = keras.models.load_model(MODEL_PATH)
    except Exception:
        # SECURITY NOTE: this fallback disables Keras' safe mode, which allows
        # arbitrary code embedded in the model file (e.g. Lambda layers) to
        # run. It is only acceptable because the model file is assumed to be
        # one you trained yourself; never point MODEL_PATH at an untrusted
        # file.
        keras.config.enable_unsafe_deserialization()
        m = keras.models.load_model(MODEL_PATH, safe_mode=False)
    return m
model = load_model()

# Derive the expected input geometry from the model itself. input_shape is
# read as (None, H, W, C); a missing/unknown dimension falls back to 256 px
# and a single channel.
input_shape = model.input_shape
_side = input_shape[1] if input_shape else None
_channels = input_shape[-1] if input_shape else None
img_size = int(_side) if _side else 256
exp_ch = int(_channels) if _channels else 1
def get_model_version():
    """Return the model version label read from ``VERSION_PATH``.

    The file is optional. When it is missing, unreadable, not valid JSON, or
    has no "version" entry, the label "ResNet50_v1" is returned instead.
    """
    fallback = "ResNet50_v1"
    try:
        # A missing file raises here and lands in the same fallback path as a
        # malformed one.
        with open(VERSION_PATH, "r") as fh:
            metadata = json.load(fh)
        return metadata.get("version", fallback)
    except Exception:
        return fallback


MODEL_VERSION = get_model_version()
| # ----------------------------- | |
| # Text safety (PDF + error messages) | |
| # ----------------------------- | |
def safe_text(s: str, max_len: int = 200) -> str:
    """Return a latin-1-safe, length-limited rendering of ``s``.

    ``None`` becomes an empty string and any other value is converted with
    ``str()``. Typographic dashes and quotes are replaced by ASCII ones, a
    space is inserted after every "-", "_" and "/" so long tokens can wrap,
    characters outside latin-1 become "?", and text longer than ``max_len``
    characters is cut to ``max_len`` characters followed by "...".
    """
    if s is None:
        return ""

    # One translation pass: typographic punctuation is normalised and break
    # opportunities are added after separators in long tokens.
    substitutions = str.maketrans({
        "–": "- ",
        "—": "- ",
        "-": "- ",
        "_": "_ ",
        "/": "/ ",
        "’": "'",
        "“": '"',
        "”": '"',
    })
    text = str(s).translate(substitutions)

    # Default FPDF fonts only cover latin-1; anything else becomes "?".
    text = text.encode("latin-1", "replace").decode("latin-1")

    if len(text) <= max_len:
        return text
    return text[:max_len] + "..."
| # ----------------------------- | |
| # Confidence interpretation | |
| # ----------------------------- | |
def interpret_confidence(prob: float) -> str:
    """Translate a probability into one of three likelihood band labels.

    Values below 0.30 are "Low likelihood", values from 0.30 up to and
    including 0.60 are "Borderline suspicion", and everything else is
    "High likelihood".
    """
    below_low_cutoff = prob < 0.30
    within_borderline_cutoff = prob <= 0.60

    if below_low_cutoff:
        return "Low likelihood (<30%)"
    return (
        "Borderline suspicion (30-60%)"
        if within_borderline_cutoff
        else "High likelihood (>60%)"
    )
| # ----------------------------- | |
| # DICOM + preprocessing | |
| # ----------------------------- | |
def dicom_bytes_to_img(data: bytes) -> np.ndarray:
    """Decode DICOM bytes into a float32 image min-max scaled to 0..1.

    Args:
        data: Raw contents of a DICOM (.dcm) file.

    Returns:
        The pixel array as float32, scaled so the smallest pixel maps to 0
        and the largest to (just under) 1. For files whose
        PhotometricInterpretation is MONOCHROME1 the scaled image is
        inverted, so that higher values always mean brighter pixels.

    Raises:
        Exception: Whatever pydicom raises for unreadable files or pixel data.
    """
    dcm = pydicom.dcmread(io.BytesIO(data))
    img = dcm.pixel_array.astype(np.float32)

    img_min = float(np.min(img))
    img_max = float(np.max(img))
    # The epsilon keeps a constant-valued image from dividing by zero.
    img = (img - img_min) / (img_max - img_min + 1e-8)  # 0..1

    # In MONOCHROME1 the minimum stored value is displayed as white, i.e. the
    # raw pixels are a negative of the usual MONOCHROME2 presentation, and
    # pixel_array does not correct for that. Invert so both encodings reach
    # the model with the same polarity.
    # NOTE(review): assumes the model was trained on MONOCHROME2-style
    # (bright = dense) images - confirm against the training pipeline.
    photometric = str(getattr(dcm, "PhotometricInterpretation", "")).strip().upper()
    if photometric == "MONOCHROME1":
        img = 1.0 - img

    return img
def preprocess(img_2d: np.ndarray) -> np.ndarray:
    """Turn a 2D grayscale image into a single-item model input batch.

    Args:
        img_2d: Image of shape (H, W), expected to be scaled to 0..1.

    Returns:
        float32 array of shape (1, img_size, img_size, C) with values clipped
        to 0..1. C is 3 when the model expects three channels (the grayscale
        plane is repeated), otherwise 1.

    Raises:
        ValueError: If ``img_2d`` is not two-dimensional (e.g. a multi-frame
            or colour DICOM). Without this check the extra axis would be
            treated as a batch dimension by the resize and fail later with an
            unrelated shape error.
    """
    img_2d = np.asarray(img_2d)
    if img_2d.ndim != 2:
        raise ValueError(
            f"Expected a 2D grayscale image, got array with shape {img_2d.shape}"
        )

    x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32)  # (H,W,1)
    x = tf.image.resize(x, (img_size, img_size))
    # Interpolation can overshoot slightly; keep values inside 0..1.
    x = tf.clip_by_value(x, 0.0, 1.0)
    x = x.numpy()  # (img_size,img_size,1)

    # The tensor always has exactly one channel at this point, so the only
    # adaptation ever needed is replicating it for a 3-channel model.
    if exp_ch == 3:
        x = np.repeat(x, 3, axis=-1)  # (img_size,img_size,3)

    x = np.expand_dims(x, axis=0)  # (1,img_size,img_size,C)
    return x.astype(np.float32)
def predict_prob(x: np.ndarray) -> float:
    """Run the model on one preprocessed batch and return a probability.

    Args:
        x: Model input batch as produced by ``preprocess``.

    Returns:
        The first value of the model output (of the last output when the
        model returns several), clamped to the range 0.0..1.0.

    Raises:
        ValueError: If the model returns no values, or returns NaN. NaN must
            be rejected explicitly: ``min(1.0, nan)`` evaluates to 1.0, so
            clamping alone would report a broken prediction as a 100%
            pneumonia probability.
    """
    pred = model.predict(x, verbose=0)
    if isinstance(pred, (list, tuple)):
        # Multi-output model: the last output is taken as the probability.
        pred = pred[-1]

    scores = np.ravel(pred)
    if scores.size == 0:
        raise ValueError("Model returned an empty prediction")

    prob = float(scores[0])
    if np.isnan(prob):
        raise ValueError("Model returned NaN instead of a probability")

    return max(0.0, min(1.0, prob))
| # ----------------------------- | |
| # UI | |
| # ----------------------------- | |
# Decision threshold applied to the predicted probability.
st.subheader("Model Parameters")
threshold = st.slider(
    "Decision Threshold",
    min_value=0.01,
    max_value=0.99,
    value=0.37,  # your ResNet best threshold default
    step=0.01,
    help="If predicted probability is greater than or equal to the threshold, output is Pneumonia. Otherwise Not Pneumonia."
)

st.subheader("Upload Chest X-ray DICOM Files")

# The uploader keeps its files across reruns, so a plain rerun cannot empty
# it. Its widget key carries a counter held in session state; bumping the
# counter makes Streamlit build a brand-new, empty uploader.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

uploaded_files = st.file_uploader(
    "Select one or multiple DICOM files (.dcm)",
    type=["dcm"],
    accept_multiple_files=True,
    key=f"dicom_uploader_{st.session_state['uploader_key']}"
)

col1, col2 = st.columns(2)
with col1:
    submit = st.button("Submit", type="primary", use_container_width=True)
with col2:
    clear = st.button("Clear", use_container_width=True)

if clear:
    # Rotate the uploader key so the selected files are actually discarded,
    # then rerun so the fresh widget is rendered immediately.
    st.session_state["uploader_key"] += 1
    st.rerun()
st.subheader("Prediction Results")

if submit:
    if not uploaded_files:
        st.warning("Please upload at least one DICOM file before submitting.")
    else:
        rows = []
        with st.spinner("Running inference..."):
            # Walk the uploads themselves instead of a dict keyed by file
            # name: two uploads that share a name would overwrite each other
            # in a dict and one of them would silently get no result.
            for uploaded in uploaded_files:
                name = uploaded.name
                ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

                # Start from the failure record; it is completed with the
                # real values only when the whole pipeline succeeds.
                row = {
                    "timestamp": ts,
                    "model_version": MODEL_VERSION,
                    "file_name": name,
                    "probability": np.nan,
                    "prediction": "Error",
                    "confidence_level": "",
                    "error": ""
                }
                try:
                    img = dicom_bytes_to_img(uploaded.getvalue())
                    x = preprocess(img)
                    prob = predict_prob(x)
                except Exception as e:
                    # Per-file boundary: one unreadable file must not abort
                    # the rest of the batch; the reason is reported instead.
                    row["error"] = safe_text(str(e), max_len=140)
                else:
                    row["probability"] = prob
                    row["prediction"] = "Pneumonia" if prob >= threshold else "Not Pneumonia"
                    row["confidence_level"] = interpret_confidence(prob)
                rows.append(row)

        df = pd.DataFrame(rows)

        # Sentence-style outputs, one message per uploaded file.
        for _, r in df.iterrows():
            if r["prediction"] == "Error":
                st.error(
                    f"For the uploaded file '{r['file_name']}', the system could not generate a prediction. "
                    f"Reason: {r['error']}."
                )
                continue

            prob_pct = float(r["probability"]) * 100.0
            st.write(
                f"For the uploaded file '{r['file_name']}', the model estimates a pneumonia probability of "
                f"{prob_pct:.2f}%. This falls under '{r['confidence_level']}'. "
                f"Based on the selected decision threshold of {threshold:.2f}, the predicted outcome is "
                f"'{r['prediction']}'."
            )

# Footer shown on every run, regardless of whether results were produced.
st.divider()
st.caption(
    "Clinical note: This application is designed for decision support only. Final diagnosis and treatment decisions "
    "must be made by qualified healthcare professionals."
)