"""Streamlit app: tumor image classification plus survival estimation.

The script loads a Keras CNN (downloaded from the Hugging Face Hub) and a
Keras DNN risk model together with a fitted scaler, a feature list and
precomputed baseline-hazard arrays.  The UI collects an image and a handful
of clinical inputs, then shows the CNN prediction and survival percentages
derived from the DNN risk score.
"""

import json

import huggingface_hub
import joblib
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image

# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------

st.set_page_config(
    page_title="Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ",
    page_icon="🧬",
    layout="wide",
)

# NOTE(review): CNN_MODEL_PATH is not read anywhere in this file; the CNN is
# fetched via CNN_REPO_ID / CNN_MODEL_FILENAME below.  Kept for compatibility.
CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction"
CNN_REPO_ID = "MohammedAH/BreastCancerPrediction"
CNN_MODEL_FILENAME = "final_combined_model.keras"

DNN_MODEL_PATH = "src/survival_model.keras"
SCALER_PATH = "src/scaler.pkl"
FEATURES_PATH = "src/features.json"
BRESLOW_TIMES_PATH = "src/breslow_times.npy"
BRESLOW_H0_PATH = "src/breslow_H0.npy"
MEDIAN_RISK_PATH = "src/median_risk.npy"

# NOTE(review): the four constants below are not referenced in this file;
# they may be used elsewhere or be leftovers from training code.
DATASET_PATH = 'src/processed_breast_cancer_data(1).csv'
TIME_COL = "Overall_Survival_Months"
EVENT_COL = "Event"
ID_COL = "Patient_ID"


# ---------------------------------------------------
# LOAD MODELS
# ---------------------------------------------------

@st.cache_resource
def load_cnn():
    """Download the CNN file from the Hugging Face Hub and load it.

    The model is loaded with ``compile=False``.  ``st.cache_resource`` wraps
    the function so the result is cached by Streamlit.
    """
    model_path = hf_hub_download(
        repo_id=CNN_REPO_ID,
        filename=CNN_MODEL_FILENAME,
    )
    return tf.keras.models.load_model(model_path, compile=False)


@st.cache_resource
def load_dnn():
    """Load the risk model stored at ``DNN_MODEL_PATH`` with ``compile=False``."""
    return tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)


# ---------------------------------------------------
# LOAD SURVIVAL ASSETS (PRECOMPUTED BRESLOW BASELINE)
# ---------------------------------------------------

@st.cache_resource
def load_survival_assets():
    """Load the scaler, feature list and baseline arrays from disk.

    Returns:
        A tuple ``(scaler, features, breslow_times, breslow_H0, median_risk)``
        where ``scaler`` comes from ``joblib.load``, ``features`` is the
        parsed JSON content of ``FEATURES_PATH``, the two ``breslow_*``
        values are NumPy arrays and ``median_risk`` is a ``float``.
    """
    # NOTE(security): joblib.load unpickles the file; only load scaler files
    # from a trusted source.
    scaler = joblib.load(SCALER_PATH)

    # Context manager so the file handle is closed deterministically.
    with open(FEATURES_PATH, encoding="utf-8") as fh:
        features = json.load(fh)

    breslow_times = np.load(BRESLOW_TIMES_PATH)
    breslow_H0 = np.load(BRESLOW_H0_PATH)
    median_risk = float(np.load(MEDIAN_RISK_PATH))

    return scaler, features, breslow_times, breslow_H0, median_risk


cnn_model = load_cnn()
dnn_model = load_dnn()
scaler, feature_cols, breslow_times, breslow_H0, median_risk = load_survival_assets()


# ---------------------------------------------------
# IMAGE PREPROCESSING
# ---------------------------------------------------

def preprocess_image(image):
    """Turn a PIL image into a ``(1, 224, 224, 1)`` array scaled by 1/255.

    The image is converted to mode ``"L"`` when it is not already in that
    mode, resized to 224x224, divided by 255.0, and given a leading and a
    trailing axis.
    """
    if image.mode != "L":
        image = image.convert("L")

    image = image.resize((224, 224))
    img = np.array(image) / 255.0

    # Add the leading (batch) and trailing (channel) axes.
    return img[np.newaxis, ..., np.newaxis]


# ---------------------------------------------------
# CNN PREDICTION
# ---------------------------------------------------

def predict_cancer(image):
    """Classify ``image`` with the CNN.

    Returns:
        ``(result, confidence, pred)`` where ``pred`` is the first element
        of the first row of the model output, ``result`` is ``"Malignant"``
        when ``pred > 0.5`` and ``"Benign"`` otherwise, and ``confidence``
        is ``pred`` for a malignant call and ``1 - pred`` for a benign one.
    """
    img = preprocess_image(image)
    pred = cnn_model.predict(img, verbose=0)[0][0]

    if pred > 0.5:
        return "Malignant", pred, pred
    return "Benign", 1 - pred, pred


# ---------------------------------------------------
# SURVIVAL FUNCTION
# ---------------------------------------------------

def survival_prob(risk, t):
    """Return ``exp(-H0 * exp(risk))`` for the baseline entry at time ``t``.

    ``H0`` is the ``breslow_H0`` value at the index of the last
    ``breslow_times`` entry that is ``<= t``; when no such entry exists the
    function returns ``1.0``.  Assumes ``breslow_times`` is sorted in
    ascending order, as ``np.searchsorted`` requires — TODO confirm.
    """
    idx = np.searchsorted(breslow_times, t, side="right") - 1
    if idx < 0:
        return 1.0

    h0 = breslow_H0[idx]
    return float(np.exp(-h0 * np.exp(risk)))


# ---------------------------------------------------
# SURVIVAL PREDICTION
# ---------------------------------------------------

def predict_survival(feature_values):
    """Score one feature vector and derive three survival percentages.

    The vector is wrapped into a single-row ``float32`` array, passed
    through ``scaler.transform`` and then through the DNN.

    Returns:
        ``(risk, s1, s3, s5)`` where ``risk`` is the model output as a
        ``float`` and ``s1``/``s3``/``s5`` are ``survival_prob`` at times
        12, 36 and 60, multiplied by 100 (the UI labels them as 1-, 3- and
        5-year values, so the time unit is presumably months).
    """
    row = np.array([feature_values], dtype=np.float32)
    row = scaler.transform(row)

    risk = float(dnn_model.predict(row, verbose=0)[0][0])

    s1, s3, s5 = (survival_prob(risk, t) * 100 for t in (12, 36, 60))
    return risk, s1, s3, s5


# ---------------------------------------------------
# FEATURE ENGINEERING
# ---------------------------------------------------

def build_feature_vector(inputs_dict):
    """Build the model input list in the order given by ``feature_cols``.

    For each column name the value is taken from ``inputs_dict`` when the
    key is present there, otherwise from the engineered features computed
    below, otherwise ``0``.

    Args:
        inputs_dict: Mapping that must contain the keys
            ``"Age at Diagnosis"``, ``"Tumor Size"``,
            ``"Lymph nodes examined positive"``, ``"Tumor Stage_encoded"``,
            ``"ER Status_encoded"``, ``"PR Status_encoded"`` and
            ``"HER2 Status_encoded"``.

    Returns:
        A list with one value per entry of ``feature_cols``.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    age = inputs_dict["Age at Diagnosis"]
    tumor_size = inputs_dict["Tumor Size"]
    nodes = inputs_dict["Lymph nodes examined positive"]
    stage = inputs_dict["Tumor Stage_encoded"]
    er = inputs_dict["ER Status_encoded"]
    pr = inputs_dict["PR Status_encoded"]
    her2 = inputs_dict["HER2 Status_encoded"]

    # NOTE(review): these formulas presumably mirror the training pipeline;
    # verify against the code that produced the scaler and the DNN.
    engineered = {
        "tumor_size_log": np.log1p(tumor_size),
        "lymph_node_ratio": nodes / (nodes + 1),
        "age_stage_interaction": age * stage,
        "favorable_biomarker": 1 if (er == 1 and pr == 1 and her2 == 0) else 0,
    }

    # Raw inputs take precedence over engineered values; columns known to
    # neither source are filled with 0.
    return [
        inputs_dict[col] if col in inputs_dict else engineered.get(col, 0)
        for col in feature_cols
    ]


# ---------------------------------------------------
# UI
# ---------------------------------------------------

st.title("🧬 Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ")

st.markdown(
    """
Workflow:

1️⃣ Upload histopathology image
2️⃣ Enter patient clinical features
3️⃣ AI predicts tumor malignancy and survival probability
"""
)

# ---------------------------------------------------
# STEP 1 — Upload Image
# ---------------------------------------------------

st.header("Step 1: Upload Tumor Image")

uploaded = st.file_uploader(
    "Upload Histopathology Image",
    type=["png", "jpg", "jpeg"],
)

image = None

if uploaded:
    image = Image.open(uploaded)
    st.image(image, width=300)

# ---------------------------------------------------
# STEP 2 — Enter Clinical Features
# ---------------------------------------------------

st.header("Step 2: Enter Patient Clinical Features")

# Base clinical inputs only; derived features are computed in
# build_feature_vector.
age = st.number_input("Age at Diagnosis", 20, 100, 50)
tumor_size = st.number_input("Tumor Size (mm)", 0.0, 200.0, 20.0)
nodes = st.number_input("Positive Lymph Nodes", 0, 50, 0)

stage = st.selectbox("Tumor Stage", [0, 1, 2, 3, 4])
er = st.selectbox("ER Status", [0, 1])
pr = st.selectbox("PR Status", [0, 1])
her2 = st.selectbox("HER2 Status", [0, 1])

user_inputs = {
    "Age at Diagnosis": age,
    "Tumor Size": tumor_size,
    "Lymph nodes examined positive": nodes,
    "Tumor Stage_encoded": stage,
    "ER Status_encoded": er,
    "PR Status_encoded": pr,
    "HER2 Status_encoded": her2,
}

# ---------------------------------------------------
# STEP 3 — Run AI Analysis
# ---------------------------------------------------

st.header("Step 3: Run AI Diagnosis")

if st.button("Run Full AI Analysis"):

    if image is None:
        st.error("Please upload an image first.")
        st.stop()

    # ---- CNN Prediction ----
    result, conf, score = predict_cancer(image)

    # ---- Survival Prediction ----
    features = build_feature_vector(user_inputs)
    risk, s1, s3, s5 = predict_survival(features)

    st.markdown("---")
    st.header("AI Analysis Results")

    # ----------------------------
    # Diagnosis
    # ----------------------------

    st.subheader("Tumor Diagnosis")

    col1, col2 = st.columns(2)
    col1.metric("Diagnosis", result)
    col2.metric("Confidence", f"{conf*100:.2f}%")

    st.write("Prediction Score:", round(score, 4))

    # ----------------------------
    # Survival
    # ----------------------------

    st.subheader("Patient Survival Prediction")

    st.metric("Risk Score", round(risk, 4))

    c1, c2, c3 = st.columns(3)
    c1.metric("1-Year Survival", f"{s1:.1f}%")
    c2.metric("3-Year Survival", f"{s3:.1f}%")
    c3.metric("5-Year Survival", f"{s5:.1f}%")

    # Risk at or above the stored median is reported as high risk.
    if risk >= median_risk:
        st.error("High Risk Category")
    else:
        st.success("Low Risk Category")

# ---------------------------------------------------
# FOOTER
# ---------------------------------------------------

st.markdown("---")
st.caption("AI-assisted clinical decision support system")