Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import tensorflow as tf | |
| import numpy as np | |
| import joblib | |
| import json | |
| from PIL import Image | |
| import pandas as pd | |
| import huggingface_hub | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------------------------------- | |
| # CONFIG | |
| # --------------------------------------------------- | |
# Page-level Streamlit settings.
st.set_page_config(
    layout="wide",
    page_icon="🧬",
    page_title="Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ",
)

# Model / artifact locations.
# NOTE(review): load_cnn() passes its own repo id and filename to
# hf_hub_download, so CNN_MODEL_PATH appears unused in the visible code.
CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction"
DNN_MODEL_PATH = "src/survival_model.keras"
SCALER_PATH = "src/scaler.pkl"
FEATURES_PATH = "src/features.json"

# NOTE(review): the dataset path and the column names below are not referenced
# anywhere in the visible part of this file; presumably leftovers from the
# training/preprocessing code -- confirm before removing.
DATASET_PATH = "src/processed_breast_cancer_data(1).csv"
TIME_COL = "Overall_Survival_Months"
EVENT_COL = "Event"
ID_COL = "Patient_ID"
| # --------------------------------------------------- | |
| # LOAD MODELS | |
| # --------------------------------------------------- | |
@st.cache_resource
def load_cnn():
    """Fetch the image model from the Hugging Face Hub and load it with Keras.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the model file would be resolved and
    deserialised on each rerun. The decorator keeps one loaded model per
    server process and hands the same object back on later calls.

    Returns:
        The Keras model stored as ``final_combined_model.keras`` in the
        ``MohammedAH/BreastCancerPrediction`` repository.
    """
    model_path = hf_hub_download(
        repo_id="MohammedAH/BreastCancerPrediction",
        filename="final_combined_model.keras",
    )
    # compile=False: the model is only used for inference, so the training
    # configuration (optimizer, loss) does not need to be restored.
    return tf.keras.models.load_model(model_path, compile=False)
@st.cache_resource
def load_dnn():
    """Load the survival-risk network from ``DNN_MODEL_PATH``.

    Cached with ``st.cache_resource`` so the model is deserialised once per
    server process instead of on every Streamlit rerun.

    Returns:
        The Keras model, loaded without its training configuration
        (``compile=False``) because it is only used for prediction.
    """
    return tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)
# ---------------------------------------------------
# LOAD SURVIVAL ASSETS (PRECOMPUTED BRESLOW BASELINE)
# ---------------------------------------------------
@st.cache_resource
def load_survival_assets():
    """Load the preprocessing and baseline-hazard artifacts for survival prediction.

    Cached with ``st.cache_resource`` so the files are read once per server
    process rather than on every Streamlit rerun; callers receive the same
    objects each time and must treat them as read-only.

    Returns:
        A 5-tuple ``(scaler, features, breslow_times, breslow_H0, median_risk)``:
        the object unpickled from ``SCALER_PATH``, the parsed content of
        ``FEATURES_PATH``, the arrays stored in ``src/breslow_times.npy`` and
        ``src/breslow_H0.npy``, and the value of ``src/median_risk.npy`` as a
        Python float.
    """
    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only ship scaler files from a trusted source.
    scaler = joblib.load(SCALER_PATH)

    # Context manager closes the handle deterministically; JSON is UTF-8 by
    # specification, so do not depend on the platform's default encoding.
    with open(FEATURES_PATH, encoding="utf-8") as features_file:
        features = json.load(features_file)

    breslow_times = np.load("src/breslow_times.npy")
    breslow_H0 = np.load("src/breslow_H0.npy")
    median_risk = float(np.load("src/median_risk.npy"))
    return scaler, features, breslow_times, breslow_H0, median_risk
# Module-level handles used by the prediction helpers below. Streamlit runs
# this script top to bottom on each interaction, so these calls execute on
# every rerun (CNN first, then the survival network, then the assets).
cnn_model, dnn_model = load_cnn(), load_dnn()
(
    scaler,
    feature_cols,
    breslow_times,
    breslow_H0,
    median_risk,
) = load_survival_assets()
| # --------------------------------------------------- | |
| # IMAGE PREPROCESSING | |
| # --------------------------------------------------- | |
def preprocess_image(image):
    """Convert an image into the batch array fed to the CNN.

    The image is converted to mode "L" (grayscale) unless it already is,
    resized to 224x224, turned into an array, and divided by 255. A leading
    and a trailing axis of length one are then added.

    Args:
        image: An object exposing ``mode``, ``convert`` and ``resize`` the way
            a PIL image does.

    Returns:
        A float array; for a PIL image its shape is ``(1, 224, 224, 1)``.
    """
    grayscale = image if image.mode == "L" else image.convert("L")
    resized = grayscale.resize((224, 224))
    pixels = np.array(resized) / 255.0
    # Axis 0 is the batch dimension, the last axis the single channel.
    return np.expand_dims(pixels, axis=(0, -1))
| # --------------------------------------------------- | |
| # CNN PREDICTION | |
| # --------------------------------------------------- | |
def predict_cancer(image):
    """Classify an uploaded image as malignant or benign with the CNN.

    Args:
        image: Image accepted by ``preprocess_image``.

    Returns:
        ``(label, confidence, score)`` where ``score`` is element ``[0][0]``
        of the model output, ``label`` is "Malignant" when ``score > 0.5`` and
        "Benign" otherwise, and ``confidence`` is ``score`` for a malignant
        call or ``1 - score`` for a benign one.
    """
    # Assumes the model emits one score per image that behaves like a
    # probability of malignancy -- TODO confirm against the trained model.
    batch = preprocess_image(image)
    score = cnn_model.predict(batch, verbose=0)[0][0]
    if score > 0.5:
        return "Malignant", score, score
    return "Benign", 1 - score, score
| # --------------------------------------------------- | |
| # SURVIVAL FUNCTION | |
| # --------------------------------------------------- | |
def survival_prob(risk, t):
    """Return the survival probability at time ``t`` for a given risk score.

    Computes ``exp(-H0 * exp(risk))`` where ``H0`` is the entry of
    ``breslow_H0`` at the last position whose ``breslow_times`` value is
    ``<= t``.

    Args:
        risk: Risk score (log relative hazard) for the patient.
        t: Time horizon, in the same unit as ``breslow_times``.

    Returns:
        A Python float; ``1.0`` when ``t`` precedes every baseline time.
    """
    # searchsorted presumes breslow_times is sorted ascending -- verify when
    # regenerating the .npy files.
    n_times_reached = np.searchsorted(breslow_times, t, side="right")
    if n_times_reached == 0:
        # No baseline time has been reached yet, so no hazard has accumulated.
        return 1.0
    cumulative_hazard = breslow_H0[n_times_reached - 1]
    return float(np.exp(-cumulative_hazard * np.exp(risk)))
| # --------------------------------------------------- | |
| # SURVIVAL PREDICTION | |
| # --------------------------------------------------- | |
def predict_survival(feature_values):
    """Score one patient and derive 1-, 3- and 5-year survival percentages.

    Args:
        feature_values: Sequence of numeric features for a single patient, in
            the order the scaler and the survival network expect.

    Returns:
        ``(risk, s1, s3, s5)``: the network's risk score as a float, followed
        by the survival probabilities at 12, 36 and 60 time units, each
        multiplied by 100.
    """
    single_row = np.array([feature_values], dtype=np.float32)
    scaled_row = scaler.transform(single_row)
    risk = float(dnn_model.predict(scaled_row, verbose=0)[0][0])
    s1, s3, s5 = (survival_prob(risk, months) * 100 for months in (12, 36, 60))
    return risk, s1, s3, s5
| # --------------------------------------------------- | |
| # FEATURE ENGINEERING | |
| # --------------------------------------------------- | |
def build_feature_vector(inputs_dict):
    """Assemble the model input row in the order given by ``feature_cols``.

    Each column is filled from ``inputs_dict`` when the key is present there,
    otherwise from one of the four engineered features, otherwise with ``0``.

    Args:
        inputs_dict: Mapping that must contain the seven base clinical keys
            read below; a missing key raises ``KeyError``.

    Returns:
        A list with one value per entry of ``feature_cols``.
    """
    age = inputs_dict["Age at Diagnosis"]
    tumor_size = inputs_dict["Tumor Size"]
    nodes = inputs_dict["Lymph nodes examined positive"]
    stage = inputs_dict["Tumor Stage_encoded"]
    er = inputs_dict["ER Status_encoded"]
    pr = inputs_dict["PR Status_encoded"]
    her2 = inputs_dict["HER2 Status_encoded"]

    # Engineered features, keyed by the column names they fill.
    derived = {
        "tumor_size_log": np.log1p(tumor_size),
        "lymph_node_ratio": nodes / (nodes + 1),
        "age_stage_interaction": age * stage,
        "favorable_biomarker": int(er == 1 and pr == 1 and her2 == 0),
    }

    # Raw inputs take precedence over engineered values; any column known to
    # neither is zero-filled.
    return [
        inputs_dict[col] if col in inputs_dict else derived.get(col, 0)
        for col in feature_cols
    ]
# ---------------------------------------------------
# UI
# ---------------------------------------------------
# Page heading followed by a short description of the three-step workflow.
st.title("🧬 Deep learning-based multi-modal data integration enhancing breast cancer disease-free survival prediction ")
workflow_markdown = """
Workflow:
1️⃣ Upload histopathology image
2️⃣ Enter patient clinical features
3️⃣ AI predicts tumor malignancy and survival probability
"""
st.markdown(workflow_markdown)
| # --------------------------------------------------- | |
| # STEP 1 — Upload Image | |
| # --------------------------------------------------- | |
st.header("Step 1: Upload Tumor Image")

uploaded = st.file_uploader("Upload Histopathology Image", type=["png", "jpg", "jpeg"])

# `image` stays None until a file is provided; the analysis step checks for that.
image = Image.open(uploaded) if uploaded else None
if image is not None:
    st.image(image, width=300)
| # --------------------------------------------------- | |
| # STEP 2 — Enter Clinical Features | |
| # --------------------------------------------------- | |
st.header("Step 2: Enter Patient Clinical Features")

# Base clinical inputs only; engineered features are derived later in
# build_feature_vector().
age = st.number_input("Age at Diagnosis", min_value=20, max_value=100, value=50)
tumor_size = st.number_input("Tumor Size (mm)", min_value=0.0, max_value=200.0, value=20.0)
nodes = st.number_input("Positive Lymph Nodes", min_value=0, max_value=50, value=0)

stage = st.selectbox("Tumor Stage", [0, 1, 2, 3, 4])
er = st.selectbox("ER Status", [0, 1])
pr = st.selectbox("PR Status", [0, 1])
her2 = st.selectbox("HER2 Status", [0, 1])

# Keys are the names build_feature_vector() looks up.
user_inputs = {
    "Age at Diagnosis": age,
    "Tumor Size": tumor_size,
    "Lymph nodes examined positive": nodes,
    "Tumor Stage_encoded": stage,
    "ER Status_encoded": er,
    "PR Status_encoded": pr,
    "HER2 Status_encoded": her2,
}
| # --------------------------------------------------- | |
| # STEP 3 — Run AI Analysis | |
| # --------------------------------------------------- | |
st.header("Step 3: Run AI Diagnosis")

if st.button("Run Full AI Analysis"):
    # st.stop() ends the current script run, so nothing below (including the
    # footer) is rendered when no image has been uploaded.
    if image is None:
        st.error("Please upload an image first.")
        st.stop()

    # Image branch: CNN malignancy prediction.
    result, conf, score = predict_cancer(image)

    # Clinical branch: feature vector -> risk score -> survival percentages.
    features = build_feature_vector(user_inputs)
    risk, s1, s3, s5 = predict_survival(features)

    st.markdown("---")
    st.header("AI Analysis Results")

    # Diagnosis
    st.subheader("Tumor Diagnosis")
    diagnosis_col, confidence_col = st.columns(2)
    diagnosis_col.metric("Diagnosis", result)
    confidence_col.metric("Confidence", f"{conf*100:.2f}%")
    st.write("Prediction Score:", round(score, 4))

    # Survival
    st.subheader("Patient Survival Prediction")
    st.metric("Risk Score", round(risk, 4))
    horizon_labels = ("1-Year Survival", "3-Year Survival", "5-Year Survival")
    for column, label, percent in zip(st.columns(3), horizon_labels, (s1, s3, s5)):
        column.metric(label, f"{percent:.1f}%")

    # Risk at or above median_risk is reported as high risk.
    if risk >= median_risk:
        st.error("High Risk Category")
    else:
        st.success("Low Risk Category")

# ---------------------------------------------------
# FOOTER
# ---------------------------------------------------
st.markdown("---")
st.caption("AI-assisted clinical decision support system")