"""Gradio demo: heart disease risk prediction on the Cleveland dataset.

On page load (and on demand via "Retrain Models") the app loads the CSV at
``DATA_PATH``, trains the models via ``fit_all_models`` and shows the returned
metrics table. Users then enter patient features (or pick an example patient)
and compare every model's prediction with the soft-voting ensemble.
"""
import os
import re

import gradio as gr
import pandas as pd
import plotly.graph_objects as go

import vlai_template
from src.heart_disease_core import (
    CLEVELAND_FEATURES_ORDER,
    load_cleveland_dataframe,
    fit_all_models,
    predict_all,
    example_patient,
    get_example_labels,
)

APP_PRIMARY = vlai_template.PRIMARY_COLOR
APP_ACCENT = vlai_template.ACCENT_COLOR
APP_BG = "#F7FAFC"

# Module-level cache written by init_page() and read by run_predict().
# Because it is a module global, every session served by this process shares
# it: a retrain triggered by one user replaces the models for everyone.
STATE = {
    "df": None,
    "models": None,
    "metrics": None,
}

DATA_PATH = "data/cleveland.csv"

# Key under which predict_all() reports the soft-voting ensemble.
ENSEMBLE_NAME = "Ensemble (Soft Voting)"

# Colors used for the two predicted classes in the chart.
DISEASE_COLOR = "#C4314B"
NO_DISEASE_COLOR = "#2E7D32"

# Extracts the 1-based example number from labels like "Example 2 (Heart Disease)".
_EXAMPLE_RE = re.compile(r"Example (\d+)")

vlai_template.set_meta(
    project_name="Heart Disease Diagnosis Project",
    year="2025",
    module="03",
    description="Predict heart disease risk from patient data with optimized ML models trained on the Cleveland dataset.",
    meta_items=[
        ("Dataset", "Cleveland Heart Disease"),
        ("Models", "Decision Tree, k-NN, Naive Bayes, Random Forest, AdaBoost, Gradient Boosting, XGBoost"),
    ],
)

# Reloads the page with ?__theme=light unless a theme is already requested in
# the query string, so the app always opens in light mode.
force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""

NOTES_MD = """
## 📋 **Notes**
- **Models are trained at launch** on `data/cleveland.csv` with customizable train/validation split (default 80/20).
- **Target is binarized automatically** (0 = no disease, >0 = disease).
- **Retrain functionality**: Adjust the split ratio and click "🔄 Retrain Models" to see how data size affects performance.
- **Seven optimized models are compared**: Decision Tree, k-NN, Naive Bayes, Random Forest, AdaBoost, Gradient Boosting, and XGBoost.
- **Hyperparameters are optimized** for heart disease prediction tasks using best practices.
- **Ensemble uses weighted soft voting** with optimized weights based on model performance.
- **Best performing model** on the validation set is highlighted with 🏆 in the validation metrics table.
- **Optimization highlights**:
  - Decision Tree: entropy criterion, balanced classes, optimal depth
  - k-NN: distance weighting, Manhattan metric, optimized neighbors
  - Random Forest: 200 trees, class balancing, feature sampling
  - Gradient Boosting: regularization, subsampling, lower learning rate
  - AdaBoost: SAMME algorithm, increased estimators
  - XGBoost: L1/L2 regularization, optimal depth and learning rate
- **Feature descriptions**:
  - `age`: Patient age in years
  - `sex`: Gender (0=female, 1=male)
  - `cp`: Chest pain type (1-4)
  - `trestbps`: Resting blood pressure (mmHg)
  - `chol`: Serum cholesterol (mg/dl)
  - `fbs`: Fasting blood sugar >120 mg/dl (1=true, 0=false)
  - `restecg`: Resting ECG results (0-2)
  - `thalach`: Maximum heart rate achieved
  - `exang`: Exercise induced angina (1=yes, 0=no)
  - `oldpeak`: ST depression induced by exercise
  - `slope`: Slope of peak exercise ST segment (1-3)
  - `ca`: Number of major vessels colored by fluoroscopy (0-3)
  - `thal`: Thalassemia (3=normal, 6=fixed defect, 7=reversible defect)
"""


def init_page(train_split):
    """Load the dataset, train all models and return data for the status panel.

    Args:
        train_split: Percentage of rows used for training (slider value,
            e.g. 80). The remainder is passed to ``fit_all_models`` as the
            ``test_size`` fraction.

    Returns:
        ``(status_markdown, preview_dataframe, metrics)``. When the CSV is
        missing, both tables are empty DataFrames and ``STATE`` is untouched.
    """
    if not os.path.exists(DATA_PATH):
        msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
        return msg, pd.DataFrame(), pd.DataFrame()

    df = load_cleveland_dataframe(file_path=DATA_PATH)
    # Convert the train percentage into the held-out fraction (sklearn test_size).
    test_size = (100 - train_split) / 100
    models, metrics = fit_all_models(df, test_size=test_size)

    STATE["df"] = df
    STATE["models"] = models
    STATE["metrics"] = metrics

    head = df.head(8)
    # ':g' renders both 80 and 80.0 as "80" in the status line.
    msg = (
        f"✅ Cleveland dataset loaded from `{DATA_PATH}` and models trained "
        f"({train_split:g}/{100 - train_split:g} split)."
    )
    return msg, head, metrics


def fill_example(idx_text: str):
    """Return the feature values of the example patient chosen in the dropdown.

    ``idx_text`` is a label such as ``"Example 2 (Heart Disease)"``. Its
    1-based number is converted to a 0-based index for ``example_patient``.
    A cleared (``None``/empty) or unparseable selection falls back to the
    first example instead of raising.

    Returns:
        One value per feature, ordered like ``CLEVELAND_FEATURES_ORDER``.
    """
    match = _EXAMPLE_RE.search(idx_text or "")
    idx = int(match.group(1)) - 1 if match else 0
    ex = example_patient(idx)
    return [ex[c] for c in CLEVELAND_FEATURES_ORDER]


def _confidence(r: dict) -> float:
    """Probability the model assigned to the class it actually predicted."""
    return r["prob_1"] if r["label"] == 1 else r["prob_0"]


def _headline(r: dict) -> str:
    """Markdown verdict line for one model result."""
    return "🫀 **Heart Disease Detected**" if r["label"] == 1 else "✅ **No Heart Disease**"


def _bar_for_models(results: dict):
    """Build a bar chart of each model's confidence in its own prediction.

    Args:
        results: Mapping of model name to a dict with ``label`` (0/1),
            ``prob_0`` and ``prob_1``, as returned by ``predict_all``.

    Returns:
        A plotly Figure with one bar per model. Bars are colored by the
        predicted class, and the ensemble bar gets a thick black outline.
    """
    names = list(results)
    disease = [results[n]["label"] == 1 for n in names]

    fig = go.Figure()
    fig.add_bar(
        x=names,
        y=[_confidence(results[n]) for n in names],
        text=["🫀 Heart Disease" if d else "✅ No Heart Disease" for d in disease],
        textposition="auto",
        marker=dict(
            color=[DISEASE_COLOR if d else NO_DISEASE_COLOR for d in disease],
            # Outline the ensemble bar so the final answer stands out.
            line=dict(
                color=["#000000" if n == ENSEMBLE_NAME else "rgba(0,0,0,0.15)" for n in names],
                width=[2.5 if n == ENSEMBLE_NAME else 1.0 for n in names],
            ),
        ),
    )
    fig.update_layout(
        title="Model Predictions",
        yaxis_title="Prediction Confidence",
        xaxis_title="Model",
        yaxis=dict(range=[0, 1]),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(color="black", size=12),
        height=420,
        margin=dict(l=30, r=20, t=60, b=40),
    )
    return fig


def _model_card(name: str, r: dict) -> str:
    """Markdown summary of one model's prediction and class probabilities."""
    return (
        f"\n\n{name}\n\n"
        f"Prediction: {_headline(r)}\n\n"
        f"Confidence: {_confidence(r):.1%}\n\n"
        f"P(No disease): {r['prob_0']:.3f} | P(Heart disease): {r['prob_1']:.3f}\n\n"
    )


def _ensemble_card(r: dict) -> str:
    """Markdown summary of the final (ensemble) verdict."""
    return (
        "\n\n🎯 Ensemble Prediction (Final Result)\n\n"
        f"{_headline(r)}\n\n"
        f"Confidence: {_confidence(r):.1%}\n\n"
    )


def run_predict(*vals):
    """Run every trained model on one patient and format the results.

    Args:
        *vals: Feature values positionally aligned with
            ``CLEVELAND_FEATURES_ORDER``, in the order the UI inputs are wired.

    Returns:
        ``(figure, markdown, table)``. The figure and the table are ``None``
        and an empty DataFrame if the models have not been trained yet.
    """
    if STATE["df"] is None or STATE["models"] is None:
        return None, "❌ Models not initialized. Reload the app.", pd.DataFrame()

    input_dict = {col: vals[i] for i, col in enumerate(CLEVELAND_FEATURES_ORDER)}
    results = predict_all(STATE["models"], input_dict)

    cards = [_model_card(name, r) for name, r in results.items()]
    ensemble = results.get(ENSEMBLE_NAME)
    if ensemble is not None:
        # Lead with the final verdict. Its confidence uses the predicted class,
        # consistent with the per-model cards below it.
        cards.insert(0, _ensemble_card(ensemble))
    details_md = "\n".join(cards)

    table_df = pd.DataFrame([
        {
            "Model": name,
            "Prediction": "Heart Disease" if r["label"] == 1 else "No Heart Disease",
            "Confidence": f"{_confidence(r):.1%}",
            "P(No disease)": round(r["prob_0"], 3),
            "P(Heart disease)": round(r["prob_1"], 3),
        }
        for name, r in results.items()
    ])

    fig = _bar_for_models(results)
    return fig, details_md, table_df


def _example_choices():
    """Return dropdown labels for the first two example patients.

    Labels come from ``get_example_labels()`` (0 = no disease, otherwise
    disease). If that call fails or yields nothing, static labels are used so
    the dropdown is never empty.
    """
    fallback = ["Example 1 (No Heart Disease)", "Example 2 (Heart Disease)"]
    try:
        labels = get_example_labels()
        choices = []
        for i in range(min(2, len(labels))):
            label_text = "No Heart Disease" if labels[i] == 0 else "Heart Disease"
            choices.append(f"Example {i + 1} ({label_text})")
    except Exception:  # best-effort: the demo must still start without labels
        return fallback
    return choices or fallback


with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True,
               js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.HTML(vlai_template.render_info_card(icon="🫀", title="About this demo"))
    gr.HTML(vlai_template.render_disclaimer(
        text=(
            "This interactive heart disease prediction demo is provided strictly for educational purposes. "
            "It is not intended for clinical use and must not be relied upon for medical advice, diagnosis, "
            "treatment, or decision-making. Always consult a qualified healthcare professional."
        )
    ))
    gr.Markdown("### 🫀 **How to Use**: Enter patient features → Run prediction → View ensemble results!")

    with gr.Row(equal_height=False, variant="panel"):
        # LEFT: data preview + inputs
        with gr.Column(scale=45):
            with gr.Accordion("📁 Dataset & Model Status", open=True):
                with gr.Row():
                    train_split = gr.Slider(
                        minimum=60, maximum=90, value=80, step=5,
                        label="Training Split (%)",
                        info="Percentage of data used for training (remaining for validation)",
                    )
                    retrain_btn = gr.Button("🔄 Retrain Models", variant="secondary")
                status_md = gr.Markdown("Loading dataset and training models...")
                preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
                metrics_df = gr.DataFrame(label="Model Performance Comparison (Validation Set Results)", interactive=False)

            with gr.Accordion("✍️ Enter Patient Features", open=True):
                with gr.Row():
                    age = gr.Number(label="age (years)", value=58)
                    sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0, 1], value=1)
                    cp = gr.Dropdown(label="cp (chest pain type 1..4)", choices=[1, 2, 3, 4], value=2)
                    trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
                with gr.Row():
                    chol = gr.Number(label="chol (serum cholesterol mg/dl)", value=250)
                    fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0, 1], value=0)
                    restecg = gr.Dropdown(label="restecg (0..2)", choices=[0, 1, 2], value=1)
                    thalach = gr.Number(label="thalach (max heart rate)", value=150)
                with gr.Row():
                    exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0, 1], value=0)
                    oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
                    slope = gr.Dropdown(label="slope (1..3)", choices=[1, 2, 3], value=1)
                    ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0, 1, 2, 3], value=0)
                    thal = gr.Dropdown(label="thal (3=normal, 6=fixed, 7=reversible)", choices=[3, 6, 7], value=3)
                with gr.Row():
                    choices = _example_choices()
                    default_choice = choices[0]
                    ex_selector = gr.Dropdown(
                        label="Select Example Patient",
                        choices=choices,
                        value=default_choice,
                    )
                    predict_btn = gr.Button("🔍 Predict", variant="primary")

        # RIGHT: outputs
        with gr.Column(scale=55):
            gr.Markdown("### 📈 Model Predictions")
            bar_out = gr.Plot(label="Model Predictions Overview")
            sub_md = gr.Markdown("**Individual Model Results**")
            table_out = gr.DataFrame(label="All Model Predictions", interactive=False)

    gr.Markdown(NOTES_MD)
    vlai_template.create_footer()

    # NOTE(review): fill_example returns values and run_predict consumes them
    # positionally, so this order must match CLEVELAND_FEATURES_ORDER — confirm.
    feature_inputs = [age, sex, cp, trestbps, chol, fbs, restecg, thalach,
                      exang, oldpeak, slope, ca, thal]
    status_outputs = [status_md, preview, metrics_df]

    # Bind events. Every page load calls init_page and therefore retrains.
    demo.load(fn=init_page, inputs=[train_split], outputs=status_outputs)
    # Retraining happens only when the button is clicked; moving the slider
    # alone does not retrain.
    retrain_btn.click(fn=init_page, inputs=[train_split], outputs=status_outputs)
    # Auto-fill the feature inputs when an example patient is selected.
    ex_selector.change(fn=fill_example, inputs=[ex_selector], outputs=feature_inputs)
    predict_btn.click(fn=run_predict, inputs=feature_inputs, outputs=[bar_out, sub_md, table_out])


if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])