# Pneumonia_Detection / src/streamlit_app.py
# Author: VJBharathkumar
# Last commit: "Update src/streamlit_app.py" (e74e049, verified)
import io
import os
import json
from datetime import datetime
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from tensorflow import keras
import pydicom
from fpdf import FPDF
# -----------------------------
# Page config
# -----------------------------
# One title string shared by the browser tab and the page heading.
APP_TITLE = "Pneumonia Detection (Chest X-ray) - Clinical Decision Support"

# st.set_page_config must stay the first Streamlit call in the script.
st.set_page_config(page_title=APP_TITLE, layout="centered")
st.title(APP_TITLE)
st.caption(
    "Upload one or more Chest X-ray DICOM files (.dcm). "
    "Adjust the decision threshold and click Submit. "
    "This tool is for decision support only and does not replace clinical judgment."
)
# -----------------------------
# Paths / Model Loading
# -----------------------------
# Repository root = the directory one level above the directory of this file.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODEL_PATH = os.path.join(REPO_ROOT, "model.keras")  # required
VERSION_PATH = os.path.join(REPO_ROOT, "model_version.json")  # optional metadata file
@st.cache_resource
def load_model():
    """Load and return the Keras model stored at MODEL_PATH.

    Wrapped in st.cache_resource, so Streamlit reruns reuse the already
    loaded model instead of deserializing it again.

    A normal (safe-mode) load is tried first; if that raises anything, the
    load is retried with Keras safe mode disabled.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
        Whatever the unsafe-mode retry raises; the first failure remains
        attached to it as ``__context__`` because the retry runs inside the
        ``except`` block.
    """
    if os.path.exists(MODEL_PATH):
        try:
            return keras.models.load_model(MODEL_PATH)
        except Exception:
            # SECURITY NOTE(review): enable_unsafe_deserialization() turns off
            # Keras safe mode for the whole process. This is acceptable only
            # because the model file ships with the repository and is trusted
            # ("If you trained it, it's safe"); never point MODEL_PATH at a
            # user-supplied file.
            keras.config.enable_unsafe_deserialization()
            return keras.models.load_model(MODEL_PATH, safe_mode=False)
    raise FileNotFoundError(f"model.keras not found at: {MODEL_PATH}")
# load_model() is wrapped in st.cache_resource, so reruns reuse the cached model.
model = load_model()

# Expected input geometry, read from the model itself.
# NOTE(review): assumes a single-input model whose input_shape is
# (None, H, W, C) and square images (only H is read) — TODO confirm for
# other architectures.
input_shape = model.input_shape  # (None, H, W, C)
_height_dim = input_shape[1] if input_shape else None
_channel_dim = input_shape[-1] if input_shape else None
img_size = 256 if not _height_dim else int(_height_dim)  # 256 when H is unknown
exp_ch = 1 if not _channel_dim else int(_channel_dim)  # 1 when C is unknown
# Version label used when no usable version metadata is available.
_DEFAULT_MODEL_VERSION = "ResNet50_v1"


def get_model_version(path=None) -> str:
    """Return the model version label from the optional version file.

    The file is expected to hold a JSON object such as
    ``{"version": "ResNet50_v2"}``.

    Args:
        path: JSON file to read. Defaults to VERSION_PATH when None.

    Returns:
        The "version" value as a stripped string. "ResNet50_v1" is returned
        when any of the following holds:
        - the file is missing or unreadable;
        - the file is not valid UTF-8 JSON;
        - the top-level value is not an object;
        - "version" is absent, null, or blank.
    """
    if path is None:
        path = VERSION_PATH
    try:
        # Explicit UTF-8 keeps the result independent of the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file. ValueError covers
        # JSONDecodeError and UnicodeDecodeError.
        return _DEFAULT_MODEL_VERSION
    version = meta.get("version") if isinstance(meta, dict) else None
    if version is None:
        return _DEFAULT_MODEL_VERSION
    text = str(version).strip()
    return text or _DEFAULT_MODEL_VERSION
# Model version label, stamped into every prediction row (success and error).
MODEL_VERSION = get_model_version()
# -----------------------------
# Text safety (PDF + error messages)
# -----------------------------
# One-pass translation table:
# - Typographic quotes are folded to ASCII.
# - "-", "_" and "/" get a trailing space so long tokens (UUIDs, filenames)
#   can wrap.
# - En/em dashes map straight to "- ", which is what they become when
#   dash-folding is followed by the break-opportunity step.
_SAFE_TEXT_TABLE = str.maketrans({
    "–": "- ",
    "—": "- ",
    "’": "'",
    "“": '"',
    "”": '"',
    "-": "- ",
    "_": "_ ",
    "/": "/ ",
})


def safe_text(s: str, max_len: int = 200) -> str:
    """Normalize *s* for the default FPDF fonts and for on-screen messages.

    Processing steps:
    - None becomes "".
    - Anything else goes through str().
    - Typographic dashes and quotes are replaced, and break spaces are added
      after "-", "_" and "/".
    - Characters outside latin-1 become "?".
    - The result is cut to *max_len* characters plus "..." when longer.
    """
    if s is None:
        return ""
    cleaned = str(s).translate(_SAFE_TEXT_TABLE)
    cleaned = cleaned.encode("latin-1", "replace").decode("latin-1")
    return cleaned if len(cleaned) <= max_len else cleaned[:max_len] + "..."
# -----------------------------
# Confidence interpretation
# -----------------------------
def interpret_confidence(prob: float) -> str:
    """Map a probability to a coarse likelihood band for display.

    Bands:
    - below 0.30: low;
    - 0.30 to 0.60 inclusive: borderline;
    - anything else: high.
    """
    low_cutoff, high_cutoff = 0.30, 0.60
    if prob < low_cutoff:
        return "Low likelihood (<30%)"
    if prob <= high_cutoff:
        return "Borderline suspicion (30-60%)"
    # Everything above the upper cutoff lands here. NaN also lands here,
    # because every comparison with NaN is False.
    return "High likelihood (>60%)"
# -----------------------------
# DICOM + preprocessing
# -----------------------------
def dicom_bytes_to_img(data: bytes) -> np.ndarray:
    """Decode DICOM bytes into a single 2-D float32 image min-max scaled to [0, 1].

    Processing steps:
    - Colour pixel data (SamplesPerPixel > 1, 3-D array) is reduced to
      grayscale by averaging the first three channels. This is an
      approximation; no colour-space conversion is done.
    - MONOCHROME1 images, where the minimum stored value is displayed white,
      are inverted after scaling. Bright therefore always means high, as in
      MONOCHROME2.

    Raises:
        ValueError: if the pixel data is not a single 2-D image after the
            colour reduction (e.g. multi-frame studies).
        Errors from pydicom for unreadable files or missing pixel data
        propagate unchanged.
    """
    dcm = pydicom.dcmread(io.BytesIO(data))
    img = dcm.pixel_array.astype(np.float32)

    samples_per_pixel = int(getattr(dcm, "SamplesPerPixel", 1) or 1)
    if samples_per_pixel > 1 and img.ndim == 3:
        img = img[..., :3].mean(axis=-1)
    if img.ndim != 2:
        # preprocess() assumes (H, W); fail here with a readable message.
        raise ValueError(
            f"Unsupported pixel array shape {img.shape}; expected a single 2-D image."
        )

    img_min = float(np.min(img))
    img_max = float(np.max(img))
    # The epsilon keeps constant images from dividing by zero (they become all 0).
    img = (img - img_min) / (img_max - img_min + 1e-8)  # 0..1

    photometric = str(getattr(dcm, "PhotometricInterpretation", "")).strip().upper()
    if photometric == "MONOCHROME1":
        img = 1.0 - img
    return img.astype(np.float32)
def preprocess(img_2d: np.ndarray) -> np.ndarray:
    """Turn a 2-D [0, 1] image into a model-ready batch of one.

    Processing steps:
    - The image is resized to (img_size, img_size), which does not preserve
      aspect ratio.
    - Values are clipped to [0, 1].
    - The single grey channel is replicated to exp_ch channels when the model
      expects more than one.

    Returns:
        float32 array of shape (1, img_size, img_size, max(exp_ch, 1)).

    Raises:
        ValueError: if *img_2d* is not two-dimensional.
    """
    img_2d = np.asarray(img_2d, dtype=np.float32)
    if img_2d.ndim != 2:
        raise ValueError(f"preprocess expects a 2-D image, got shape {img_2d.shape}.")

    x = tf.convert_to_tensor(img_2d[..., np.newaxis], dtype=tf.float32)  # (H, W, 1)
    x = tf.image.resize(x, (img_size, img_size))
    x = tf.clip_by_value(x, 0.0, 1.0)
    x = x.numpy()  # (img_size, img_size, 1)

    # The tensor is always single-channel here, so only replication is needed.
    # Covers the usual exp_ch == 3 (ImageNet-style backbones) and any other count.
    if exp_ch > 1:
        x = np.repeat(x, exp_ch, axis=-1)

    return np.expand_dims(x, axis=0).astype(np.float32)  # (1, img_size, img_size, C)
def predict_prob(x: np.ndarray) -> float:
    """Run the model on a preprocessed batch and return P(pneumonia) in [0, 1].

    For models that return a list/tuple of outputs, the last output is used.
    The first element of the chosen output is taken as the probability.

    NOTE(review): assumes a single sigmoid unit. With a 2-unit softmax head,
    element 0 would be the *negative* class — confirm against the trained
    model.

    Raises:
        ValueError: if the model returns an empty or non-finite prediction.
    """
    pred = model.predict(x, verbose=0)
    if isinstance(pred, (list, tuple)):
        pred = pred[-1]
    flat = np.ravel(pred)
    if flat.size == 0:
        raise ValueError("Model returned an empty prediction.")
    prob = float(flat[0])
    # NaN compares False against everything, so max(0, min(1, nan)) yields
    # 1.0 — a silent "100% Pneumonia". Reject it so the caller records an error.
    if not np.isfinite(prob):
        raise ValueError(f"Model returned a non-finite probability ({prob}).")
    return max(0.0, min(1.0, prob))
# -----------------------------
# UI
# -----------------------------
st.subheader("Model Parameters")
threshold = st.slider(
    "Decision Threshold",
    min_value=0.01,
    max_value=0.99,
    value=0.37,  # default: the author's best threshold for the ResNet model
    step=0.01,
    help="If predicted probability is greater than or equal to the threshold, output is Pneumonia. Otherwise Not Pneumonia.",
)

# Streamlit keeps a widget's value across reruns as long as its key is
# unchanged, so st.rerun() alone does not empty the uploader. The key embeds
# a counter stored in session state. Clear bumps the counter, and the rerun
# then builds a fresh, empty uploader.
if "uploader_generation" not in st.session_state:
    st.session_state["uploader_generation"] = 0

st.subheader("Upload Chest X-ray DICOM Files")
uploaded_files = st.file_uploader(
    "Select one or multiple DICOM files (.dcm)",
    type=["dcm"],
    accept_multiple_files=True,
    key=f"dicom_uploader_{st.session_state['uploader_generation']}",
)

col1, col2 = st.columns(2)
with col1:
    submit = st.button("Submit", type="primary", use_container_width=True)
with col2:
    clear = st.button("Clear", use_container_width=True)

if clear:
    st.session_state["uploader_generation"] += 1
    st.rerun()
st.subheader("Prediction Results")

if submit:
    if not uploaded_files:
        st.warning("Please upload at least one DICOM file before submitting.")
    else:
        # Read each upload's bytes once. Keep (name, bytes) pairs in a list
        # rather than a dict keyed by file name: two uploads that share a name
        # (e.g. picked from different folders) must both be scored, not
        # collapsed into one.
        file_bytes = [(f.name, f.getvalue()) for f in uploaded_files]

        rows = []
        with st.spinner("Running inference..."):
            for name, data in file_bytes:
                ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                try:
                    img = dicom_bytes_to_img(data)
                    x = preprocess(img)
                    prob = predict_prob(x)
                    pred_label = "Pneumonia" if prob >= threshold else "Not Pneumonia"
                    conf_level = interpret_confidence(prob)
                    rows.append({
                        "timestamp": ts,
                        "model_version": MODEL_VERSION,
                        "file_name": name,
                        "probability": prob,
                        "prediction": pred_label,
                        "confidence_level": conf_level,
                        "error": "",
                    })
                except Exception as e:
                    # UI boundary: a failure on one file becomes an "Error"
                    # row, so the remaining files are still processed.
                    rows.append({
                        "timestamp": ts,
                        "model_version": MODEL_VERSION,
                        "file_name": name,
                        "probability": np.nan,
                        "prediction": "Error",
                        "confidence_level": "",
                        "error": safe_text(str(e), max_len=140),
                    })

        df = pd.DataFrame(rows)

        # Sentence-style outputs, one per uploaded file.
        for _, r in df.iterrows():
            if r["prediction"] == "Error":
                st.error(
                    f"For the uploaded file '{r['file_name']}', the system could not generate a prediction. "
                    f"Reason: {r['error']}."
                )
                continue
            prob_pct = float(r["probability"]) * 100.0
            st.write(
                f"For the uploaded file '{r['file_name']}', the model estimates a pneumonia probability of "
                f"{prob_pct:.2f}%. This falls under '{r['confidence_level']}'. "
                f"Based on the selected decision threshold of {threshold:.2f}, the predicted outcome is "
                f"'{r['prediction']}'."
            )

# Always shown, whether or not a prediction was run.
st.divider()
st.caption(
    "Clinical note: This application is designed for decision support only. Final diagnosis and treatment decisions "
    "must be made by qualified healthcare professionals."
)