import tensorflow as tf import numpy as np import joblib import json import pandas as pd from PIL import Image from lifelines import CoxPHFitter import gradio as gr print(f"✓ TensorFlow version: {tf.__version__}") # --------------------------------------------------- # CONFIG # --------------------------------------------------- CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction" # Hugging Face Hub path DNN_MODEL_PATH = "survival_model.keras" SCALER_PATH = "scaler.pkl" FEATURES_PATH = "features.json" DATASET_PATH = 'processed_breast_cancer_data(1).csv' TIME_COL = "Overall_Survival_Months" EVENT_COL = "Event" # --------------------------------------------------- # GLOBAL ASSETS (loaded once at startup) # --------------------------------------------------- cnn_model = None dnn_model = None scaler = None feature_cols = None breslow_times = None breslow_H0 = None def load_all_assets(): """Load models and survival assets once at startup""" global cnn_model, dnn_model, scaler, feature_cols, breslow_times, breslow_H0 # Load CNN model (from Hugging Face Hub or local) if CNN_MODEL_PATH.startswith("hf://"): from huggingface_hub import hf_hub_download model_path = hf_hub_download( repo_id=CNN_MODEL_PATH.replace("hf://", ""), filename="best_breast_cancer_cnn.keras" ) cnn_model = tf.keras.models.load_model(model_path, compile=False) else: cnn_model = tf.keras.models.load_model(CNN_MODEL_PATH, compile=False) # Load DNN survival model dnn_model = tf.keras.models.load_model(DNN_MODEL_PATH, compile=False) # Load scaler and features scaler = joblib.load(SCALER_PATH) with open(FEATURES_PATH, 'r') as f: feature_cols = json.load(f) # Compute Breslow baseline hazard df = pd.read_csv(DATASET_PATH) feature_df = df[feature_cols].copy() feature_df["duration"] = df[TIME_COL] feature_df["event"] = df[EVENT_COL] cox = CoxPHFitter() cox.fit(feature_df, duration_col="duration", event_col="event") baseline = cox.baseline_cumulative_hazard_ breslow_times = baseline.index.values breslow_H0 = baseline.values.flatten() print("✓ All assets loaded successfully") # Load everything at module import load_all_assets() # --------------------------------------------------- # IMAGE PREPROCESSING # --------------------------------------------------- def preprocess_image(image: Image.Image) -> np.ndarray: """Convert PIL image to model-ready tensor""" if image.mode != "L": image = image.convert("L") image = image.resize((224, 224)) img = np.array(image) / 255.0 img = img[np.newaxis, ..., np.newaxis] # (1, 224, 224, 1) return img # --------------------------------------------------- # CNN PREDICTION # --------------------------------------------------- def predict_cancer(image: Image.Image): """Predict malignancy from histopathology image""" if image is None: return "Please upload an image", 0.0, 0.0 img = preprocess_image(image) pred = float(cnn_model.predict(img, verbose=0)[0][0]) result = "🔴 Malignant" if pred > 0.5 else "🟢 Benign" confidence = max(pred, 1 - pred) return result, round(confidence * 100, 2), round(pred, 4) # --------------------------------------------------- # SURVIVAL FUNCTIONS # --------------------------------------------------- def survival_prob(risk: float, t: float) -> float: """Compute survival probability at time t using Breslow baseline""" idx = np.searchsorted(breslow_times, t, side="right") - 1 if idx < 0: return 1.0 h0 = breslow_H0[idx] return float(np.exp(-h0 * np.exp(risk))) def predict_survival(*feature_values): """Predict survival probabilities from clinical features""" if len(feature_values) != len(feature_cols): return "Error: Feature count mismatch", 0, 0, 0 row = np.array([list(feature_values)], dtype=np.float32) row_scaled = scaler.transform(row) risk = float(dnn_model.predict(row_scaled, verbose=0)[0][0]) s1 = survival_prob(risk, 12) * 100 s3 = survival_prob(risk, 36) * 100 s5 = survival_prob(risk, 60) * 100 risk_category = "🔴 High Risk" if risk > 0 else "🟢 Low Risk" return ( round(risk, 4), f"{risk_category}", f"{s1:.1f}%", f"{s3:.1f}%", f"{s5:.1f}%" ) # --------------------------------------------------- # GRADIO UI # --------------------------------------------------- with gr.Blocks( title="🧬 Breast Cancer AI Diagnosis & Survival", theme=gr.themes.Soft(primary_hue="rose", secondary_hue="blue"), css=""" .main-title { text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 10px; } .subtitle { text-align: center; color: #666; margin-bottom: 30px; } .metric-box { text-align: center; padding: 10px; border-radius: 8px; background: #f9f9f9; } """ ) as demo: gr.Markdown('
🧬 Breast Cancer AI Diagnosis & Survival System
') gr.Markdown( 'Integrates CNN tumor classification + DNN survival prediction • TensorFlow 2.18
' ) with gr.Tabs(): # ===== TAB 1: IMAGE DIAGNOSIS ===== with gr.TabItem("🔬 Image Diagnosis"): with gr.Row(): with gr.Column(scale=1): image_input = gr.Image( type="pil", label="Upload Histopathology Image", height=300 ) analyze_btn = gr.Button("🔍 Analyze Image", variant="primary") with gr.Column(scale=1): diagnosis_out = gr.Label(label="Diagnosis") confidence_out = gr.Number(label="Confidence (%)", interactive=False) score_out = gr.Number(label="Raw Prediction Score", interactive=False) analyze_btn.click( fn=predict_cancer, inputs=image_input, outputs=[diagnosis_out, confidence_out, score_out] ) gr.Examples( examples=[["example1.jpg"], ["example2.png"]], inputs=image_input, label="Try example images (optional)" ) # ===== TAB 2: SURVIVAL ANALYSIS ===== with gr.TabItem("📈 Survival Analysis"): gr.Markdown("### Enter Patient Clinical Features") gr.Markdown(f"*Features expected: {', '.join(feature_cols)}*") # Dynamically create feature inputs feature_inputs = [] with gr.Row(): for i, feat in enumerate(feature_cols): with gr.Column(scale=1): inp = gr.Number( label=feat, value=0.0, step=0.1, interactive=True ) feature_inputs.append(inp) predict_btn = gr.Button("📊 Predict Survival", variant="primary", size="lg") with gr.Row(): with gr.Column(): risk_out = gr.Number(label="Risk Score", interactive=False) risk_cat_out = gr.Markdown(label="Risk Category") with gr.Column(): gr.Markdown("### Survival Probabilities") with gr.Row(): s1_out = gr.Textbox(label="1-Year", value="--", interactive=False) s3_out = gr.Textbox(label="3-Year", value="--", interactive=False) s5_out = gr.Textbox(label="5-Year", value="--", interactive=False) predict_btn.click( fn=predict_survival, inputs=feature_inputs, outputs=[risk_out, risk_cat_out, s1_out, s3_out, s5_out] ) # ===== FOOTER ===== gr.Markdown("---") gr.Markdown( "