# Source: Hugging Face Space "BreastCancerSurvival", file src/streamlit_app.py
# Uploaded by MohammedAH; commit 830657c (verified): "Update src/streamlit_app.py"
import streamlit as st
import tensorflow as tf
import numpy as np
import joblib
import json
from PIL import Image
import pandas as pd
import huggingface_hub
from huggingface_hub import hf_hub_download
# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
# Page-level Streamlit settings: browser-tab title, icon and wide layout.
st.set_page_config(
page_title="Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ",
page_icon="🧬",
layout="wide"
)
# Earlier local-file location of the CNN, kept for reference:
# CNN_MODEL_PATH = "best_breast_cancer_cnn.keras"
# NOTE(review): load_cnn() passes its own repo_id/filename to hf_hub_download
# and does not read this constant.
CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction"
# Local artefacts read by load_dnn() and load_survival_assets().
DNN_MODEL_PATH = "src/survival_model.keras"
SCALER_PATH = "src/scaler.pkl"
FEATURES_PATH = "src/features.json"
# NOTE(review): the dataset path and the three column names below are not
# referenced anywhere in the visible part of this file — presumably left over
# from the training pipeline; confirm before removing.
DATASET_PATH = 'src/processed_breast_cancer_data(1).csv'
TIME_COL = "Overall_Survival_Months"
EVENT_COL = "Event"
ID_COL = "Patient_ID"
# ---------------------------------------------------
# LOAD MODELS
# ---------------------------------------------------
@st.cache_resource
def load_cnn():
    """Download the image model from the Hugging Face Hub and load it with Keras.

    The function is wrapped in ``st.cache_resource`` so the loaded model is
    kept by Streamlit's resource cache instead of being rebuilt on each call.
    ``compile=False`` is passed to ``load_model``; the app only calls
    ``predict`` on the result.

    Returns:
        The Keras model stored as ``final_combined_model.keras`` in the
        ``MohammedAH/BreastCancerPrediction`` repository.
    """
    # hf_hub_download returns the local path of the downloaded file.
    weights_file = hf_hub_download(
        repo_id="MohammedAH/BreastCancerPrediction",
        filename="final_combined_model.keras",
    )
    return tf.keras.models.load_model(weights_file, compile=False)
@st.cache_resource
def load_dnn():
    """Load the survival-risk network from the local file ``DNN_MODEL_PATH``.

    Wrapped in ``st.cache_resource`` so the model object is kept in
    Streamlit's resource cache. Loaded with ``compile=False``.

    Returns:
        The Keras model read from ``DNN_MODEL_PATH``.
    """
    survival_net = tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)
    return survival_net
# ---------------------------------------------------
# LOAD SURVIVAL ASSETS (PRECOMPUTED BRESLOW BASELINE)
# ---------------------------------------------------
@st.cache_resource
def load_survival_assets():
    """Load the preprocessing and baseline artefacts used for survival prediction.

    Nothing is computed here: every artefact is read from disk as saved.

    Returns:
        A 5-tuple ``(scaler, features, breslow_times, breslow_H0, median_risk)``:
        the object unpickled from ``SCALER_PATH``, the parsed JSON content of
        ``FEATURES_PATH``, the arrays stored in ``src/breslow_times.npy`` and
        ``src/breslow_H0.npy``, and the value of ``src/median_risk.npy``
        converted to a Python ``float``.
    """
    # joblib.load unpickles the file, so SCALER_PATH must be a trusted artefact.
    scaler = joblib.load(SCALER_PATH)
    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 rather than the platform's locale encoding.
    with open(FEATURES_PATH, encoding="utf-8") as features_file:
        features = json.load(features_file)
    breslow_times = np.load("src/breslow_times.npy")
    breslow_H0 = np.load("src/breslow_H0.npy")
    median_risk = float(np.load("src/median_risk.npy"))
    return scaler, features, breslow_times, breslow_H0, median_risk
# Load both models and the survival artefacts when the script runs; the
# prediction helpers below read these module-level names.
cnn_model = load_cnn()
dnn_model = load_dnn()
scaler, feature_cols, breslow_times, breslow_H0, median_risk = load_survival_assets()
# ---------------------------------------------------
# IMAGE PREPROCESSING
# ---------------------------------------------------
def preprocess_image(image):
    """Turn an uploaded image into a single-item batch array for the image model.

    Args:
        image: An image object exposing ``mode``, ``convert`` and ``resize``
            (the app passes the result of ``PIL.Image.open``).

    Returns:
        A NumPy array holding the 224x224 single-channel image divided by
        255.0, with one axis added in front and one at the end.
    """
    # Only convert when the image is not already in mode "L".
    grayscale = image if image.mode == "L" else image.convert("L")
    resized = grayscale.resize((224, 224))
    pixels = np.array(resized) / 255.0
    # Add a leading and a trailing axis around the 2-D pixel grid.
    return pixels[np.newaxis, ..., np.newaxis]
# ---------------------------------------------------
# CNN PREDICTION
# ---------------------------------------------------
def predict_cancer(image):
    """Classify an image as malignant or benign with the loaded image model.

    Args:
        image: Image object accepted by ``preprocess_image``.

    Returns:
        A tuple ``(label, confidence, score)`` where ``score`` is the first
        element of the model output for the image, ``label`` is
        ``"Malignant"`` when ``score > 0.5`` and ``"Benign"`` otherwise, and
        ``confidence`` is ``score`` for malignant or ``1 - score`` for benign.
    """
    batch = preprocess_image(image)
    score = cnn_model.predict(batch, verbose=0)[0][0]
    if score > 0.5:
        return "Malignant", score, score
    return "Benign", 1 - score, score
# ---------------------------------------------------
# SURVIVAL FUNCTION
# ---------------------------------------------------
def survival_prob(risk, t, times=None, cum_hazard=None):
    """Return the survival probability ``exp(-H0(t) * exp(risk))`` at time ``t``.

    ``H0(t)`` is looked up as a step function: the value of ``cum_hazard`` at
    the last entry of ``times`` that is ``<= t``.

    Args:
        risk: Risk score placed in the exponent (the app passes the output of
            the survival network).
        t: Time at which to evaluate, in the same unit as ``times``.
        times: Optional sorted-ascending array of time points. Defaults to the
            module-level ``breslow_times``.
        cum_hazard: Optional array aligned with ``times``. Defaults to the
            module-level ``breslow_H0`` (presumably the Breslow baseline
            cumulative hazard — confirm against the training code).

    Returns:
        A Python ``float``; ``1.0`` when ``t`` is earlier than every entry of
        ``times`` (including when ``times`` is empty).
    """
    # Fall back to the baseline loaded at import time so existing two-argument
    # callers keep working unchanged.
    if times is None:
        times = breslow_times
    if cum_hazard is None:
        cum_hazard = breslow_H0

    # side="right" makes an exact match on a time point use that point's value.
    idx = int(np.searchsorted(times, t, side="right")) - 1
    if idx < 0:
        # Before the first recorded time nothing has accumulated yet.
        return 1.0
    return float(np.exp(-cum_hazard[idx] * np.exp(risk)))
# ---------------------------------------------------
# SURVIVAL PREDICTION
# ---------------------------------------------------
def predict_survival(feature_values):
    """Score one patient and derive survival percentages at 12, 36 and 60 months.

    Args:
        feature_values: Sequence of numeric features for a single patient, in
            the order the scaler and network expect (the app builds it with
            ``build_feature_vector``).

    Returns:
        A tuple ``(risk, s1, s3, s5)``: the network output as a ``float`` and
        ``survival_prob(risk, t) * 100`` for ``t`` = 12, 36 and 60.
    """
    # Wrap the features in a list to form a one-row float32 batch, then scale.
    scaled_row = scaler.transform(np.array([feature_values], dtype=np.float32))
    risk = float(dnn_model.predict(scaled_row, verbose=0)[0][0])
    s1, s3, s5 = (survival_prob(risk, months) * 100 for months in (12, 36, 60))
    return risk, s1, s3, s5
# ---------------------------------------------------
# FEATURE ENGINEERING
# ---------------------------------------------------
def build_feature_vector(inputs_dict, feature_order=None):
    """Assemble the model input row from raw clinical inputs.

    For each column name in ``feature_order`` the value is taken, in order of
    precedence, from ``inputs_dict`` itself, then from the engineered features
    computed here, and finally defaults to ``0``.

    Args:
        inputs_dict: Mapping that must contain the keys
            ``"Age at Diagnosis"``, ``"Tumor Size"``,
            ``"Lymph nodes examined positive"``, ``"Tumor Stage_encoded"``,
            ``"ER Status_encoded"``, ``"PR Status_encoded"`` and
            ``"HER2 Status_encoded"``.
        feature_order: Optional iterable of column names defining the output
            order. Defaults to the module-level ``feature_cols``.

    Returns:
        A list with one value per entry of ``feature_order``.

    Raises:
        KeyError: If any of the required keys is missing from ``inputs_dict``.
    """
    # Default to the column list loaded at import time so existing
    # single-argument callers keep working unchanged.
    if feature_order is None:
        feature_order = feature_cols

    age = inputs_dict["Age at Diagnosis"]
    tumor_size = inputs_dict["Tumor Size"]
    nodes = inputs_dict["Lymph nodes examined positive"]
    stage = inputs_dict["Tumor Stage_encoded"]
    er = inputs_dict["ER Status_encoded"]
    pr = inputs_dict["PR Status_encoded"]
    her2 = inputs_dict["HER2 Status_encoded"]

    # Engineered features, keyed by the column names they fill.
    engineered = {
        "tumor_size_log": np.log1p(tumor_size),
        "lymph_node_ratio": nodes / (nodes + 1),
        "age_stage_interaction": age * stage,
        "favorable_biomarker": 1 if (er == 1 and pr == 1 and her2 == 0) else 0,
    }

    # Raw inputs win over engineered values; columns known to neither are
    # filled with 0.
    # NOTE(review): the zero fill is silent — confirm it is intended for every
    # column in the feature list that the form does not collect.
    return [
        inputs_dict[col] if col in inputs_dict else engineered.get(col, 0)
        for col in feature_order
    ]
# ---------------------------------------------------
# UI
# ---------------------------------------------------
# Page title followed by a short description of the three-step workflow.
st.title("🧬 Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ")
st.markdown(
"""
Workflow:
1️⃣ Upload histopathology image
2️⃣ Enter patient clinical features
3️⃣ AI predicts tumor malignancy and survival probability
"""
)
# ---------------------------------------------------
# STEP 1 — Upload Image
# ---------------------------------------------------
st.header("Step 1: Upload Tumor Image")
uploaded = st.file_uploader("Upload Histopathology Image", type=["png", "jpg", "jpeg"])
# `image` stays None until a file has been provided; the analysis step checks
# for that before running the models.
image = Image.open(uploaded) if uploaded else None
if image is not None:
    # Show a preview of the uploaded image.
    st.image(image, width=300)
# ---------------------------------------------------
# STEP 2 — Enter Clinical Features
# ---------------------------------------------------
st.header("Step 2: Enter Patient Clinical Features")
# Only the base clinical values are collected here; derived features are
# computed later by build_feature_vector.
age = st.number_input("Age at Diagnosis", min_value=20, max_value=100, value=50)
tumor_size = st.number_input("Tumor Size (mm)", min_value=0.0, max_value=200.0, value=20.0)
nodes = st.number_input("Positive Lymph Nodes", min_value=0, max_value=50, value=0)
stage = st.selectbox("Tumor Stage", options=[0, 1, 2, 3, 4])
# The three receptor-status selectors share the same 0/1 options and are
# created in this order.
er, pr, her2 = (
    st.selectbox(label, options=[0, 1])
    for label in ("ER Status", "PR Status", "HER2 Status")
)
# Keys are the column names that build_feature_vector reads.
user_inputs = {
    "Age at Diagnosis": age,
    "Tumor Size": tumor_size,
    "Lymph nodes examined positive": nodes,
    "Tumor Stage_encoded": stage,
    "ER Status_encoded": er,
    "PR Status_encoded": pr,
    "HER2 Status_encoded": her2,
}
# ---------------------------------------------------
# STEP 3 — Run AI Analysis
# ---------------------------------------------------
st.header("Step 3: Run AI Diagnosis")
if st.button("Run Full AI Analysis"):
    # Without an image there is nothing to analyse: report it and halt this run.
    if image is None:
        st.error("Please upload an image first.")
        st.stop()

    # Image model first, then the clinical survival model.
    diagnosis, confidence, raw_score = predict_cancer(image)
    risk, s1, s3, s5 = predict_survival(build_feature_vector(user_inputs))

    st.markdown("---")
    st.header("AI Analysis Results")

    # Diagnosis: label and confidence side by side, raw score underneath.
    st.subheader("Tumor Diagnosis")
    diagnosis_col, confidence_col = st.columns(2)
    diagnosis_col.metric("Diagnosis", diagnosis)
    confidence_col.metric("Confidence", f"{confidence*100:.2f}%")
    st.write("Prediction Score:", round(raw_score, 4))

    # Survival: risk score, then one column per time horizon.
    st.subheader("Patient Survival Prediction")
    st.metric("Risk Score", round(risk, 4))
    horizons = (
        ("1-Year Survival", s1),
        ("3-Year Survival", s3),
        ("5-Year Survival", s5),
    )
    for column, (label, percent) in zip(st.columns(3), horizons):
        column.metric(label, f"{percent:.1f}%")

    # Patients at or above the stored median risk are flagged as high risk.
    if risk >= median_risk:
        st.error("High Risk Category")
    else:
        st.success("Low Risk Category")
# ---------------------------------------------------
# FOOTER
# ---------------------------------------------------
# Horizontal rule followed by a caption line at the bottom of the page.
st.markdown("---")
st.caption("AI-assisted clinical decision support system")