Spaces:
Build error
Build error
import json
import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from lifelines import CoxPHFitter
from PIL import Image
print(f"✅ TensorFlow version: {tf.__version__}")

# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
# "hf://<repo_id>" downloads the CNN from the Hugging Face Hub; any other value
# is treated as a local model path.
CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction"  # Hugging Face Hub path
DNN_MODEL_PATH = "survival_model.keras"  # Keras survival (risk) model
SCALER_PATH = "scaler.pkl"  # fitted feature scaler, loaded with joblib
FEATURES_PATH = "features.json"  # JSON list of clinical feature column names
DATASET_PATH = "processed_breast_cancer_data(1).csv"  # reference cohort for the baseline hazard
TIME_COL = "Overall_Survival_Months"  # follow-up time column
EVENT_COL = "Event"  # event-indicator column

# ---------------------------------------------------
# GLOBAL ASSETS (loaded once at startup)
# ---------------------------------------------------
# All of these are populated by load_all_assets().
cnn_model = None  # Keras image classifier
dnn_model = None  # Keras survival model producing a log-risk score
scaler = None  # scaler applied to clinical features before the DNN
feature_cols = None  # ordered clinical feature names
breslow_times = None  # time grid of the baseline cumulative hazard
breslow_H0 = None  # baseline cumulative hazard at each breslow_times entry
def breslow_baseline(risk_scores, durations, events):
    """Return the Breslow estimate of the baseline cumulative hazard.

    The estimate is conditioned on the supplied log-risk scores, so
    ``exp(-H0(t) * exp(risk))`` is on the same scale as the model that
    produced those scores.

    Args:
        risk_scores: per-subject log-risk scores.
        durations: per-subject follow-up times.
        events: per-subject event indicators (truthy = event observed).

    Returns:
        ``(times, H0)``: the sorted distinct event times and the baseline
        cumulative hazard at each of them. Both are empty when no events occur.
    """
    durations = np.asarray(durations, dtype=np.float64).reshape(-1)
    events = np.asarray(events).reshape(-1).astype(bool)
    hazard_ratios = np.exp(np.asarray(risk_scores, dtype=np.float64).reshape(-1))

    # Risk-set sums  sum_{j : T_j >= t} exp(r_j)  via a suffix cumulative sum
    # over subjects ordered by duration.
    order = np.argsort(durations, kind="stable")
    sorted_durations = durations[order]
    suffix_sums = np.cumsum(hazard_ratios[order][::-1])[::-1]

    times, deaths = np.unique(durations[events], return_counts=True)
    # side="left" so that subjects tied at time t are still counted at risk.
    first_at_risk = np.searchsorted(sorted_durations, times, side="left")
    risk_set_sums = suffix_sums[first_at_risk]

    return times, np.cumsum(deaths / risk_set_sums)


def load_all_assets():
    """Load models, scaler, feature list and survival baseline once at startup.

    Populates the module-level globals ``cnn_model``, ``dnn_model``,
    ``scaler``, ``feature_cols``, ``breslow_times`` and ``breslow_H0``.
    """
    global cnn_model, dnn_model, scaler, feature_cols, breslow_times, breslow_H0

    # CNN classifier: downloaded from the Hugging Face Hub for "hf://" paths,
    # otherwise loaded from the local filesystem.
    if CNN_MODEL_PATH.startswith("hf://"):
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(
            repo_id=CNN_MODEL_PATH.removeprefix("hf://"),
            filename="best_breast_cancer_cnn.keras",
        )
        cnn_model = tf.keras.models.load_model(model_path, compile=False)
    else:
        cnn_model = tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)

    # DNN survival model: scaled clinical features -> log-risk score.
    dnn_model = tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)

    scaler = joblib.load(SCALER_PATH)
    with open(FEATURES_PATH, "r", encoding="utf-8") as f:
        feature_cols = json.load(f)

    # Breslow baseline estimated from the DNN's OWN risk scores on the reference
    # cohort. A separately fitted linear Cox model would yield a baseline on a
    # different risk scale than the DNN output that survival_prob() combines it
    # with. Rows with missing values would turn every estimate into NaN, so
    # they are excluded.
    df = pd.read_csv(DATASET_PATH)
    df = df.dropna(subset=[*feature_cols, TIME_COL, EVENT_COL])
    features = scaler.transform(df[feature_cols].to_numpy(dtype=np.float32))
    # First output column, matching how predict_survival reads the model.
    risks = dnn_model.predict(features, verbose=0)[:, 0]
    breslow_times, breslow_H0 = breslow_baseline(
        risks, df[TIME_COL].to_numpy(), df[EVENT_COL].to_numpy()
    )

    print("✅ All assets loaded successfully")
# Populate the module-level model/data globals eagerly, at import time, so they
# already exist when the UI definition reads feature_cols.
load_all_assets()
# ---------------------------------------------------
# IMAGE PREPROCESSING
# ---------------------------------------------------
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image batch for the CNN.

    The image is converted to 8-bit grayscale ("L") when it is not already,
    resized to 224x224 and scaled from 0..255 to [0, 1].

    Returns:
        Array of shape (1, 224, 224, 1).
    """
    gray = image if image.mode == "L" else image.convert("L")
    pixels = np.array(gray.resize((224, 224))) / 255.0
    # Add the batch axis in front and the channel axis at the end.
    return pixels.reshape((1, *pixels.shape, 1))
# ---------------------------------------------------
# CNN PREDICTION
# ---------------------------------------------------
def predict_cancer(image: Image.Image):
    """Classify a histopathology image as malignant or benign.

    Args:
        image: the uploaded PIL image, or None when nothing was uploaded.

    Returns:
        ``(label, confidence_percent, raw_score)``.
        - ``raw_score`` is the first element of the CNN output, rounded to 4
          decimals. Scores above 0.5 are labelled malignant. It is presumably
          a sigmoid probability of malignancy; confirm against the training
          setup.
        - ``confidence_percent`` is ``max(score, 1 - score)`` in percent.
        Without an image: ``("Please upload an image", 0.0, 0.0)``.
    """
    if image is None:
        return "Please upload an image", 0.0, 0.0

    score = float(cnn_model.predict(preprocess_image(image), verbose=0)[0][0])
    label = "🔴 Malignant" if score > 0.5 else "🟢 Benign"
    confidence = max(score, 1 - score)
    return label, round(confidence * 100, 2), round(score, 4)
# ---------------------------------------------------
# SURVIVAL FUNCTIONS
# ---------------------------------------------------
def survival_prob(risk: float, t: float) -> float:
    """Return S(t) = exp(-H0(t) * exp(risk)) using the Breslow baseline.

    H0 is treated as a right-continuous step function over ``breslow_times``.
    Before the first baseline time the probability is exactly 1.0.
    """
    # Number of baseline times <= t; zero means t precedes the whole grid.
    pos = int(np.searchsorted(breslow_times, t, side="right"))
    if pos == 0:
        return 1.0
    cumulative_hazard = breslow_H0[pos - 1]
    hazard_ratio = np.exp(risk)
    return float(np.exp(-cumulative_hazard * hazard_ratio))
def _survival_error(message):
    """Build an error result shaped like predict_survival's 5 UI outputs."""
    return None, message, "--", "--", "--"


def predict_survival(*feature_values):
    """Predict a risk score and 1/3/5-year survival from clinical features.

    Args:
        *feature_values: one numeric value per entry of ``feature_cols``, in
            the same order.

    Returns:
        ``(risk, risk_category, s1, s3, s5)``.
        - ``risk`` is the DNN log-risk score, rounded to 4 decimals.
        - ``risk_category`` is "High Risk" when risk > 0, otherwise "Low Risk".
        - ``s1``/``s3``/``s5`` are survival probabilities at 12, 36 and 60
          months, formatted as percentages.
        On invalid input, returns ``(None, <error message>, "--", "--", "--")``
        so every output component still receives a value of the right kind.
    """
    if len(feature_values) != len(feature_cols):
        return _survival_error("Error: Feature count mismatch")
    if any(value is None for value in feature_values):
        # gr.Number passes None for an empty field; it cannot be cast to float.
        return _survival_error("Error: Please enter a value for every feature")

    row = np.array([list(feature_values)], dtype=np.float32)
    risk = float(dnn_model.predict(scaler.transform(row), verbose=0)[0][0])

    s1, s3, s5 = (survival_prob(risk, months) * 100 for months in (12, 36, 60))
    risk_category = "🔴 High Risk" if risk > 0 else "🟢 Low Risk"
    return round(risk, 4), risk_category, f"{s1:.1f}%", f"{s3:.1f}%", f"{s5:.1f}%"
# ---------------------------------------------------
# GRADIO UI
# ---------------------------------------------------
# Candidate example images for the diagnosis tab. Only files that actually
# exist are offered, so a missing example cannot break UI construction.
EXAMPLE_IMAGE_PATHS = ("example1.jpg", "example2.png")

with gr.Blocks(
    title="🧬 Breast Cancer AI Diagnosis & Survival",
    theme=gr.themes.Soft(primary_hue="rose", secondary_hue="blue"),
    css="""
    .main-title { text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 10px; }
    .subtitle { text-align: center; color: #666; margin-bottom: 30px; }
    .metric-box { text-align: center; padding: 10px; border-radius: 8px; background: #f9f9f9; }
    """,
) as demo:
    gr.Markdown('<p class="main-title">🧬 Breast Cancer AI Diagnosis & Survival System</p>')
    # Show the TensorFlow version actually running rather than a hard-coded one.
    gr.Markdown(
        '<p class="subtitle">Integrates CNN tumor classification + DNN survival prediction'
        f" • TensorFlow {tf.__version__}</p>"
    )

    with gr.Tabs():
        # ===== TAB 1: IMAGE DIAGNOSIS =====
        with gr.TabItem("🔬 Image Diagnosis"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload Histopathology Image",
                        height=300,
                    )
                    analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
                with gr.Column(scale=1):
                    diagnosis_out = gr.Label(label="Diagnosis")
                    confidence_out = gr.Number(label="Confidence (%)", interactive=False)
                    score_out = gr.Number(label="Raw Prediction Score", interactive=False)

            analyze_btn.click(
                fn=predict_cancer,
                inputs=image_input,
                outputs=[diagnosis_out, confidence_out, score_out],
            )

            available_examples = [
                [path] for path in EXAMPLE_IMAGE_PATHS if os.path.exists(path)
            ]
            if available_examples:
                gr.Examples(
                    examples=available_examples,
                    inputs=image_input,
                    label="Try example images (optional)",
                )

        # ===== TAB 2: SURVIVAL ANALYSIS =====
        with gr.TabItem("📊 Survival Analysis"):
            gr.Markdown("### Enter Patient Clinical Features")
            gr.Markdown(f"*Features expected: {', '.join(feature_cols)}*")

            # One numeric input per clinical feature, in feature_cols order,
            # which is the positional order predict_survival expects.
            feature_inputs = []
            with gr.Row():
                for feat in feature_cols:
                    with gr.Column(scale=1):
                        feature_inputs.append(
                            gr.Number(
                                label=feat,
                                value=0.0,
                                step=0.1,
                                interactive=True,
                            )
                        )

            predict_btn = gr.Button("📈 Predict Survival", variant="primary", size="lg")

            with gr.Row():
                with gr.Column():
                    risk_out = gr.Number(label="Risk Score", interactive=False)
                    risk_cat_out = gr.Markdown(label="Risk Category")
                with gr.Column():
                    gr.Markdown("### Survival Probabilities")
                    with gr.Row():
                        s1_out = gr.Textbox(label="1-Year", value="--", interactive=False)
                        s3_out = gr.Textbox(label="3-Year", value="--", interactive=False)
                        s5_out = gr.Textbox(label="5-Year", value="--", interactive=False)

            predict_btn.click(
                fn=predict_survival,
                inputs=feature_inputs,
                outputs=[risk_out, risk_cat_out, s1_out, s3_out, s5_out],
            )

    # ===== FOOTER =====
    gr.Markdown("---")
    gr.Markdown(
        "<center>⚠️ AI-assisted clinical decision support • Not a substitute for professional medical advice</center>"
    )
# ---------------------------------------------------
# LAUNCH
# ---------------------------------------------------
if __name__ == "__main__":
    launch_options = {
        # Listen on all interfaces so the app is reachable externally
        # (cloud deployment).
        "server_name": "0.0.0.0",
        "server_port": 7860,
        # Change to True to get a public share link.
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)