"""Gradio app that maps clinical inputs to a 10-year cardiovascular risk band.

A pre-trained surrogate classifier and its expected feature order are loaded
from files that sit next to this script; the UI collects the raw inputs,
derives the engineered features, and shows the predicted band together with
the per-band probabilities.
"""

import json
import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd

# ── Load model + feature list ──────────────────────────────────
MODEL_PATH = os.path.join(os.path.dirname(__file__), "best_surrogate_model.joblib")
FEATURES_PATH = os.path.join(os.path.dirname(__file__), "feature_cols.json")

model = joblib.load(MODEL_PATH)
with open(FEATURES_PATH, encoding="utf-8") as f:
    feature_cols = json.load(f)

# Class label -> band name, in ascending order of risk.
RISK_BANDS = {0: "Low", 1: "Borderline", 2: "Intermediate", 3: "High"}

RISK_COLORS = {
    "Low": "#27ae60",
    "Borderline": "#f39c12",
    "Intermediate": "#e67e22",
    "High": "#e74c3c",
}

RISK_DESCRIPTIONS = {
    "Low": "Less than 5% chance of a heart attack or stroke in the next 10 years.",
    "Borderline": "5% to 7.5% chance of a heart attack or stroke in the next 10 years. Lifestyle changes recommended.",
    "Intermediate": "7.5% to 20% chance of a heart attack or stroke in the next 10 years. Discuss treatment options with a doctor.",
    "High": "20% or higher chance of a heart attack or stroke in the next 10 years. Medical intervention strongly recommended.",
}

# (label used when the value is missing, lower bound, upper bound, range message).
# Keys are the predict_risk parameter names; order is the order checks run in.
_RANGE_CHECKS = {
    "age": (
        "Age", 40, 79,
        "Age must be between 40 and 79 (valid range for the ASCVD calculator).",
    ),
    "total_chol": (
        "Total cholesterol", 100, 400,
        "Total cholesterol should be between 100 and 400 mg/dL.",
    ),
    "hdl": (
        "HDL cholesterol", 10, 150,
        "HDL cholesterol should be between 10 and 150 mg/dL.",
    ),
    "sbp": (
        "Systolic blood pressure", 80, 250,
        "Systolic blood pressure should be between 80 and 250 mmHg.",
    ),
    "dbp": (
        "Diastolic blood pressure", 40, 150,
        "Diastolic blood pressure should be between 40 and 150 mmHg.",
    ),
}


def build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                   bp_meds, diabetes, smoking, fasting_glucose, hba1c) -> pd.DataFrame:
    """Build the single-row feature frame the model is fed.

    Args:
        age, total_chol, hdl, sbp, dbp: numeric clinical measurements.
        sex: encoded as 1 when equal to "Male", otherwise 0.
        race: encoded as 1 when equal to "Black", otherwise 0.
        bp_meds, diabetes, smoking: encoded as 1 when equal to "Yes", otherwise 0.
        fasting_glucose: optional; replaced by 100.0 when falsy or not positive.
        hba1c: optional; replaced by 5.5 when falsy or not positive.

    Returns:
        A one-row DataFrame holding the 31 raw and engineered features,
        with columns selected and ordered according to ``feature_cols``.
    """
    # ── Base features ──
    sex_code = 1 if sex == "Male" else 0
    race_code = 1 if race == "Black" else 0
    bp_meds_val = 1 if bp_meds == "Yes" else 0
    diabetes_val = 1 if diabetes == "Yes" else 0
    smoking_val = 1 if smoking == "Yes" else 0

    # Optional labs: None / 0 / negative all mean "unknown" and get a default.
    fasting_glucose = fasting_glucose if fasting_glucose and fasting_glucose > 0 else 100.0
    hba1c = hba1c if hba1c and hba1c > 0 else 5.5

    # ── Log transforms (clipped so log never sees a non-positive value) ──
    log_age = np.log(max(age, 1e-6))
    log_tc = np.log(max(total_chol, 1e-6))
    log_hdl = np.log(max(hdl, 1e-6))
    log_sbp = np.log(max(sbp, 1e-6))

    features = {
        "AGE": age,
        "SEX_CODE": sex_code,
        "RACE_CODE": race_code,
        "TOTAL_CHOL": total_chol,
        "HDL_UNIFIED": hdl,
        "SBP_AVG": sbp,
        "BP_MEDS_CURRENT": bp_meds_val,
        "DIABETES_PCE": diabetes_val,
        "CURRENT_SMOKER": smoking_val,
        "DBP_AVG": dbp,
        "FASTING_GLUCOSE": fasting_glucose,
        "HBA1C": hba1c,
        # Log transforms
        "LOG_AGE": log_age,
        "LOG_TOTAL_CHOL": log_tc,
        "LOG_HDL": log_hdl,
        "LOG_SBP": log_sbp,
        # PCE-like interactions
        "LOG_AGE_x_LOG_TOTAL_CHOL": log_age * log_tc,
        "LOG_AGE_x_LOG_HDL": log_age * log_hdl,
        "LOG_AGE_x_CURRENT_SMOKER": log_age * smoking_val,
        # BP treatment interactions
        "LOG_SBP_x_BP_MEDS_CURRENT": log_sbp * bp_meds_val,
        "LOG_SBP_x_NO_BP_MEDS": log_sbp * (1 - bp_meds_val),
        # Clinical helpers
        "PULSE_PRESSURE": sbp - dbp,
        "TOTAL_CHOL_TO_HDL_RATIO": total_chol / max(hdl, 1e-6),
        "AGE_x_SBP": age * sbp,
        "AGE_x_TOTAL_CHOL": age * total_chol,
        "AGE_x_HDL": age * hdl,
        # Age squared
        "LOG_AGE_SQUARED": log_age ** 2,
        # Subgroup indicators
        "IS_MALE": sex_code,
        "IS_BLACK": race_code,
        "IS_MALE_x_LOG_AGE": sex_code * log_age,
        "IS_BLACK_x_LOG_AGE": race_code * log_age,
    }

    # Selecting by feature_cols fixes the column order the model receives.
    return pd.DataFrame([features])[feature_cols]


def _validation_error(values):
    """Return the first validation message for *values*, or None if all pass.

    *values* maps the keys of ``_RANGE_CHECKS`` to the submitted numbers.
    A missing (None) or non-finite value is reported before its range is
    checked, so the comparisons below never see None or NaN.
    """
    for key, (label, low, high, range_message) in _RANGE_CHECKS.items():
        value = values[key]
        if value is None or not np.isfinite(value):
            return f"{label} is required and must be a number."
        if value < low or value > high:
            return range_message
    return None


def _band_probabilities(proba):
    """Map one row of predict_proba output onto the four band names.

    Probability columns are matched to class labels through ``model.classes_``
    when the model exposes it (scikit-learn-style estimators order the columns
    that way); otherwise columns are taken positionally as classes 0, 1, 2, ...
    Bands the model does not report get probability 0.0.
    """
    classes = getattr(model, "classes_", None)
    if classes is None:
        classes = range(len(proba))

    by_band = {name: 0.0 for name in RISK_BANDS.values()}
    for label, probability in zip(classes, proba):
        by_band[RISK_BANDS[int(label)]] = float(probability)
    return by_band


def predict_risk(age, sex, race, total_chol, hdl, sbp, dbp,
                 bp_meds, diabetes, smoking, fasting_glucose, hba1c):
    """Validate the inputs, run the model, and format the result for the UI.

    Returns:
        A ``(html, confidence)`` tuple. On invalid input, ``html`` is the
        error panel from ``format_error`` and ``confidence`` is None.
        Otherwise ``html`` is the result panel and ``confidence`` maps each
        band name to its predicted probability.
    """
    # ── Validation ──
    problem = _validation_error({
        "age": age,
        "total_chol": total_chol,
        "hdl": hdl,
        "sbp": sbp,
        "dbp": dbp,
    })
    if problem is not None:
        return format_error(problem), None

    # ── Build features + predict ──
    X = build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                       bp_meds, diabetes, smoking, fasting_glucose, hba1c)

    pred_class = int(model.predict(X)[0])
    confidence = _band_probabilities(model.predict_proba(X)[0])

    band = RISK_BANDS[pred_class]
    # NOTE(review): color is looked up but not referenced by the template
    # below — presumably styled markup once used it; confirm before removing.
    color = RISK_COLORS[band]
    desc = RISK_DESCRIPTIONS[band]

    low_pct = confidence["Low"] * 100
    borderline_pct = confidence["Borderline"] * 100
    intermediate_pct = confidence["Intermediate"] * 100
    high_pct = confidence["High"] * 100

    # ── Format result ──
    result_html = f"""

{band} Risk of Heart Attack or Stroke

{desc}

Low (<5%)
{low_pct:.1f}%
Borderline (5-7.5%)
{borderline_pct:.1f}%
Intermediate (7.5-20%)
{intermediate_pct:.1f}%
High (≥20%)
{high_pct:.1f}%
"""

    return result_html, confidence


def format_error(msg) -> str:
    """Wrap *msg* in the "Invalid Input" panel shown in the result area."""
    return f"""

⚠ Invalid Input

{msg}

"""


# ── Gradio Interface ──────────────────────────────────────────
with gr.Blocks(
    title="Cardiovascular Risk Assessment",
    theme=gr.themes.Soft(
        primary_hue="red",
        secondary_hue="orange",
    ),
    css=(
        " .main-header { text-align: center; padding: 20px 0 10px 0; }"
        " .main-header h1 { font-size: 2em; margin: 0; }"
        " .main-header p { color: #666; margin: 5px 0 0 0; } "
    ),
) as app:

    gr.HTML("""

🫀 Cardiovascular Risk Assessment

⚠ Important Disclaimer
This application is an educational and demonstration project only. It is NOT a certified medical device, diagnostic tool, or a substitute for professional medical advice, diagnosis, or treatment. The developer assumes NO RESPONSIBILITY OR LIABILITY for any decisions, actions, or health outcomes based on the results provided by this tool. Always consult a qualified healthcare professional for actual cardiovascular risk assessment and medical guidance.
""")

    with gr.Row():
        # ── Left column: Inputs ──
        with gr.Column(scale=1):
            gr.Markdown("### Patient Information")

            age = gr.Slider(
                minimum=40, maximum=79, value=55, step=1,
                label="Age (years)",
                info="Valid range: 40–79"
            )
            sex = gr.Radio(
                choices=["Male", "Female"],
                value="Male",
                label="Gender"
            )
            race = gr.Radio(
                choices=["White", "Black"],
                value="White",
                label="Race"
            )

            gr.Markdown("### Clinical Measurements")

            with gr.Row():
                total_chol = gr.Number(
                    value=200,
                    label="Total Cholesterol (mg/dL)",
                    info="Typical range: 125–300"
                )
                hdl = gr.Number(
                    value=50,
                    label="HDL Cholesterol (mg/dL)",
                    info="Typical range: 20–100"
                )

            with gr.Row():
                sbp = gr.Number(
                    value=130,
                    label="Systolic BP (mmHg)",
                    info="Top number. Normal: <120, Elevated: 120–129, High: 130+"
                )
                dbp = gr.Number(
                    value=80,
                    label="Diastolic BP (mmHg)",
                    info="Bottom number. Normal: <80, High: 80+"
                )

            gr.Markdown("### Medical History")

            with gr.Row():
                bp_meds = gr.Radio(
                    choices=["No", "Yes"],
                    value="No",
                    label="Taking BP Medication?"
                )
                diabetes = gr.Radio(
                    choices=["No", "Yes"],
                    value="No",
                    label="Has Diabetes?"
                )
                smoking = gr.Radio(
                    choices=["No", "Yes"],
                    value="No",
                    label="Currently Smoking?"
                )

            gr.Markdown("### Optional Lab Results")

            with gr.Row():
                fasting_glucose = gr.Number(
                    value=100,
                    label="Fasting Glucose (mg/dL)",
                    info="Typical range: 70–130. Leave as 0 if unknown"
                )
                hba1c = gr.Number(
                    value=5.5,
                    label="HbA1c (%)",
                    step=0.1,
                    info="Typical range: 4.0–6.5%. Leave as 0 if unknown"
                )

            predict_btn = gr.Button(
                "🔍 Assess Risk",
                variant="primary",
                size="lg"
            )

        # ── Right column: Results ──
        with gr.Column(scale=1):
            gr.Markdown("### Assessment Result")

            result_html = gr.HTML(
                value="\n"
                      "Enter patient information and click Assess Risk.\n"
            )
            confidence_chart = gr.Label(
                label="Model Confidence by Risk Category",
                num_top_classes=4,
            )

    # ── Connect ──
    predict_btn.click(
        fn=predict_risk,
        inputs=[age, sex, race, total_chol, hdl, sbp, dbp,
                bp_meds, diabetes, smoking, fasting_glucose, hba1c],
        outputs=[result_html, confidence_chart],
    )


if __name__ == "__main__":
    app.launch()