import html
import json
import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd
|
|
| |
# Model artifacts are expected to live in the same directory as this script.
# abspath() keeps the paths valid even if __file__ is relative.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_APP_DIR, "best_surrogate_model.joblib")
FEATURES_PATH = os.path.join(_APP_DIR, "feature_cols.json")

# Loaded once at import time and shared by every prediction request.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only ever point MODEL_PATH at a trusted artifact.
model = joblib.load(MODEL_PATH)

# Ordered list of feature column names; build_features selects its output
# columns in exactly this order. JSON is UTF-8 by specification, so the
# encoding is pinned instead of relying on the platform's locale default.
with open(FEATURES_PATH, encoding="utf-8") as f:
    feature_cols = json.load(f)

# Predicted class index -> risk band name. Presumably these indices match the
# labels the model was trained on; predict_risk indexes probabilities by them.
RISK_BANDS = {0: "Low", 1: "Borderline", 2: "Intermediate", 3: "High"}

# Accent colour used when rendering each band.
RISK_COLORS = {
    "Low": "#27ae60",
    "Borderline": "#f39c12",
    "Intermediate": "#e67e22",
    "High": "#e74c3c",
}

# Plain-language explanation shown under the predicted band.
RISK_DESCRIPTIONS = {
    "Low": "Less than 5% chance of a heart attack or stroke in the next 10 years.",
    "Borderline": "5% to 7.5% chance of a heart attack or stroke in the next 10 years. Lifestyle changes recommended.",
    "Intermediate": "7.5% to 20% chance of a heart attack or stroke in the next 10 years. Discuss treatment options with a doctor.",
    "High": "20% or higher chance of a heart attack or stroke in the next 10 years. Medical intervention strongly recommended.",
}
|
|
|
|
def build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                   bp_meds, diabetes, smoking, fasting_glucose, hba1c):
    """Turn raw clinical form inputs into the model's one-row feature frame.

    Categorical answers become 0/1 flags, unknown optional labs fall back to
    typical values, and log, interaction and ratio terms are derived. All 31
    engineered columns are then selected in the exact order the model
    expects, as listed in the module-level ``feature_cols``.

    Args:
        age, total_chol, hdl, sbp, dbp: Numeric measurements.
        sex: "Male" maps to 1, anything else to 0.
        race: "Black" maps to 1, anything else to 0.
        bp_meds, diabetes, smoking: "Yes" maps to 1, anything else to 0.
        fasting_glucose: Used when truthy and > 0, otherwise 100.0.
        hba1c: Used when truthy and > 0, otherwise 5.5.

    Returns:
        A single-row pandas DataFrame whose columns are ``feature_cols``.
    """
    def yes_flag(answer):
        return int(answer == "Yes")

    def clamped_log(value):
        # Floor at 1e-6 so a zero/negative input cannot yield -inf or NaN.
        return np.log(max(value, 1e-6))

    male = int(sex == "Male")
    black = int(race == "Black")
    on_bp_meds = yes_flag(bp_meds)
    diabetic = yes_flag(diabetes)
    smoker = yes_flag(smoking)

    # Optional labs: None, 0 or a non-positive value means "unknown", so a
    # typical value is substituted.
    if not (fasting_glucose and fasting_glucose > 0):
        fasting_glucose = 100.0
    if not (hba1c and hba1c > 0):
        hba1c = 5.5

    ln_age, ln_tc, ln_hdl, ln_sbp = (
        clamped_log(v) for v in (age, total_chol, hdl, sbp)
    )

    # Raw (possibly defaulted) inputs.
    row = {
        "AGE": age,
        "SEX_CODE": male,
        "RACE_CODE": black,
        "TOTAL_CHOL": total_chol,
        "HDL_UNIFIED": hdl,
        "SBP_AVG": sbp,
        "BP_MEDS_CURRENT": on_bp_meds,
        "DIABETES_PCE": diabetic,
        "CURRENT_SMOKER": smoker,
        "DBP_AVG": dbp,
        "FASTING_GLUCOSE": fasting_glucose,
        "HBA1C": hba1c,
    }

    # Log transforms and the log-scale interaction terms.
    row.update({
        "LOG_AGE": ln_age,
        "LOG_TOTAL_CHOL": ln_tc,
        "LOG_HDL": ln_hdl,
        "LOG_SBP": ln_sbp,
        "LOG_AGE_x_LOG_TOTAL_CHOL": ln_age * ln_tc,
        "LOG_AGE_x_LOG_HDL": ln_age * ln_hdl,
        "LOG_AGE_x_CURRENT_SMOKER": ln_age * smoker,
        # SBP is split by treatment status: exactly one of these is non-zero.
        "LOG_SBP_x_BP_MEDS_CURRENT": ln_sbp * on_bp_meds,
        "LOG_SBP_x_NO_BP_MEDS": ln_sbp * (1 - on_bp_meds),
        "LOG_AGE_SQUARED": ln_age ** 2,
    })

    # Linear-scale derived terms.
    row.update({
        "PULSE_PRESSURE": sbp - dbp,
        "TOTAL_CHOL_TO_HDL_RATIO": total_chol / max(hdl, 1e-6),
        "AGE_x_SBP": age * sbp,
        "AGE_x_TOTAL_CHOL": age * total_chol,
        "AGE_x_HDL": age * hdl,
    })

    # Demographic indicator aliases and their age interactions.
    row.update({
        "IS_MALE": male,
        "IS_BLACK": black,
        "IS_MALE_x_LOG_AGE": male * ln_age,
        "IS_BLACK_x_LOG_AGE": black * ln_age,
    })

    # Indexing by feature_cols both orders the columns and drops unlisted ones.
    return pd.DataFrame([row])[feature_cols]
|
|
|
|
# Range label shown on each probability card, keyed by band name. HTML
# entities are used for "<" and ">=" so the markup is valid and does not
# depend on the file's text encoding.
_BAND_RANGE_LABELS = {
    "Low": "&lt;5%",
    "Borderline": "5-7.5%",
    "Intermediate": "7.5-20%",
    "High": "&ge;20%",
}

# Required numeric inputs, checked in this order so the first problem found
# is the one reported: (label, minimum, maximum, out-of-range message).
_INPUT_LIMITS = (
    ("Age", 40, 79,
     "Age must be between 40 and 79 (valid range for the ASCVD calculator)."),
    ("Total cholesterol", 100, 400,
     "Total cholesterol should be between 100 and 400 mg/dL."),
    ("HDL cholesterol", 10, 150,
     "HDL cholesterol should be between 10 and 150 mg/dL."),
    ("Systolic blood pressure", 80, 250,
     "Systolic blood pressure should be between 80 and 250 mmHg."),
    ("Diastolic blood pressure", 40, 150,
     "Diastolic blood pressure should be between 40 and 150 mmHg."),
)


def _validate_inputs(age, total_chol, hdl, sbp, dbp):
    """Return a message for the first invalid input, or None if all are valid."""
    values = (age, total_chol, hdl, sbp, dbp)
    for value, (label, low, high, range_msg) in zip(values, _INPUT_LIMITS):
        # Gradio number inputs may deliver None when the user clears the
        # field; comparing None with a number would raise TypeError.
        if value is None:
            return f"{label} is required."
        if not low <= value <= high:
            return range_msg
    # Physiologically impossible; it would also produce a negative
    # PULSE_PRESSURE feature.
    if dbp >= sbp:
        return "Diastolic blood pressure must be lower than systolic blood pressure."
    return None


def _probability_card(band, probability):
    """Return the HTML card showing the model's probability for one band."""
    return f"""
            <div style="background:white; padding:15px 25px; border-radius:10px;
                        box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
                <div style="font-size:0.85em; color:#666;">{band} ({_BAND_RANGE_LABELS[band]})</div>
                <div style="font-size:1.4em; font-weight:bold; color:{RISK_COLORS[band]};">
                    {probability * 100:.1f}%
                </div>
            </div>"""


def predict_risk(age, sex, race, total_chol, hdl, sbp, dbp,
                 bp_meds, diabetes, smoking, fasting_glucose, hba1c):
    """Validate the form inputs, run the model and render the result.

    Parameters are the raw values of the UI inputs, in the same order as the
    ``inputs`` list wired to the "Assess Risk" button.

    Returns:
        ``(html, confidence)``. ``html`` is either the result card or an
        error card from ``format_error``. ``confidence`` maps each band name
        to the model's probability for it, or is None when the input was
        rejected.
    """
    error = _validate_inputs(age, total_chol, hdl, sbp, dbp)
    if error is not None:
        return format_error(error), None

    X = build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                       bp_meds, diabetes, smoking, fasting_glucose, hba1c)

    pred_class = int(model.predict(X)[0])
    # Probability column i is read as band RISK_BANDS[i]; this assumes the
    # model's classes_ are exactly [0, 1, 2, 3] (not checked here).
    proba = model.predict_proba(X)[0]

    band = RISK_BANDS[pred_class]
    color = RISK_COLORS[band]
    desc = RISK_DESCRIPTIONS[band]

    cards = "".join(
        _probability_card(name, proba[i]) for i, name in RISK_BANDS.items()
    )

    # "{color}22" / "{color}44" append an alpha channel to the hex colour.
    result_html = f"""
    <div style="text-align:center; padding:30px; border-radius:16px;
                background: linear-gradient(135deg, {color}22, {color}44);
                border: 2px solid {color};">
        <h1 style="color:{color}; margin:0 0 10px 0; font-size:2.5em;">
            {band} Risk of Heart Attack or Stroke
        </h1>
        <p style="font-size:1.2em; color:{color}; font-weight:bold; margin:0 0 20px 0;">
            {desc}
        </p>
        <div style="display:flex; justify-content:center; gap:20px; flex-wrap:wrap;">
            {cards}
        </div>
    </div>
    """

    # Feeds the gr.Label component: band name -> probability.
    confidence = {name: float(proba[i]) for i, name in RISK_BANDS.items()}

    return result_html, confidence
|
|
|
|
def format_error(msg):
    """Return an HTML card reporting an input-validation error.

    Args:
        msg: Plain-text message. It is converted to ``str`` and HTML-escaped,
            so characters such as ``<`` or ``&`` are displayed literally
            rather than interpreted as markup.

    Returns:
        An HTML fragment suitable for a ``gr.HTML`` component.
    """
    # The warning icon is written as an HTML entity so the heading cannot be
    # garbled by a wrong source-file encoding.
    return f"""
    <div style="text-align:center; padding:20px; border-radius:12px;
                background:#fee; border:2px solid #e74c3c;">
        <h3 style="color:#e74c3c; margin:0;">&#9888; Invalid Input</h3>
        <p style="color:#333;">{html.escape(str(msg))}</p>
    </div>
    """
|
|
|
|
| |
|
|
# ---------------------------------------------------------------------------
# Gradio user interface: inputs on the left, assessment result on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Cardiovascular Risk Assessment",
    theme=gr.themes.Soft(
        primary_hue="red",
        secondary_hue="orange",
    ),
    css="""
    .main-header {
        text-align: center;
        padding: 20px 0 10px 0;
    }
    .main-header h1 {
        font-size: 2em;
        margin: 0;
    }
    .main-header p {
        color: #666;
        margin: 5px 0 0 0;
    }
    """,
) as app:

    # Page title and disclaimer banner. The warning icon is an HTML entity
    # so it renders correctly regardless of the source-file encoding.
    gr.HTML("""
    <div class="main-header">
        <h1>🫀 Cardiovascular Risk Assessment</h1>
    </div>
    <div style="text-align:center; padding:12px 20px; margin:0 0 15px 0;
                background:#fff3cd; border:2px solid #ffc107; border-radius:10px;">
        <span style="font-size:1.3em; font-weight:bold; color:#856404;">
            &#9888; Important Disclaimer
        </span>
        <br>
        <span style="font-size:1em; color:#856404;">
            This application is an educational and demonstration project only. It is
            <span style="font-weight:bold; color:#856404;">NOT</span> a certified medical device,
            diagnostic tool, or a substitute for professional medical advice, diagnosis, or treatment.
            The developer assumes <span style="font-weight:bold; color:#856404;">NO RESPONSIBILITY OR LIABILITY</span> for any decisions, actions, or health outcomes
            based on the results provided by this tool. Always consult a qualified healthcare professional
            for actual cardiovascular risk assessment and medical guidance.
        </span>
    </div>
    """)

    with gr.Row():
        # ---- Left column: patient inputs ----
        with gr.Column(scale=1):
            gr.Markdown("### Patient Information")

            # The slider bounds match the age range predict_risk accepts.
            age = gr.Slider(
                minimum=40, maximum=79, value=55, step=1,
                label="Age (years)",
                info="Valid range: 40–79",
            )
            sex = gr.Radio(
                choices=["Male", "Female"],
                value="Male",
                label="Gender",
            )
            race = gr.Radio(
                choices=["White", "Black"],
                value="White",
                label="Race",
            )

            gr.Markdown("### Clinical Measurements")

            with gr.Row():
                total_chol = gr.Number(
                    value=200, label="Total Cholesterol (mg/dL)",
                    info="Typical range: 125–300",
                )
                hdl = gr.Number(
                    value=50, label="HDL Cholesterol (mg/dL)",
                    info="Typical range: 20–100",
                )

            with gr.Row():
                sbp = gr.Number(
                    value=130, label="Systolic BP (mmHg)",
                    info="Top number. Normal: <120, Elevated: 120–129, High: 130+",
                )
                dbp = gr.Number(
                    value=80, label="Diastolic BP (mmHg)",
                    info="Bottom number. Normal: <80, High: 80+",
                )

            gr.Markdown("### Medical History")

            with gr.Row():
                bp_meds = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Taking BP Medication?",
                )
                diabetes = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Has Diabetes?",
                )
                smoking = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Currently Smoking?",
                )

            gr.Markdown("### Optional Lab Results")

            # build_features substitutes a typical value when these are 0/empty.
            with gr.Row():
                fasting_glucose = gr.Number(
                    value=100, label="Fasting Glucose (mg/dL)",
                    info="Typical range: 70–130. Leave as 0 if unknown",
                )
                hba1c = gr.Number(
                    value=5.5, label="HbA1c (%)",
                    step=0.1,
                    info="Typical range: 4.0–6.5%. Leave as 0 if unknown",
                )

            predict_btn = gr.Button(
                "🔍 Assess Risk", variant="primary", size="lg"
            )

        # ---- Right column: assessment output ----
        with gr.Column(scale=1):
            gr.Markdown("### Assessment Result")
            result_html = gr.HTML(
                value="<div style='text-align:center; padding:40px; color:#999;'>"
                      "Enter patient information and click <b>Assess Risk</b>.</div>"
            )
            confidence_chart = gr.Label(
                label="Model Confidence by Risk Category",
                num_top_classes=4,
            )

    # The inputs list is passed positionally, so its order must match
    # predict_risk's parameter order; outputs match its (html, confidence)
    # return tuple. The event is wired inside the Blocks context.
    predict_btn.click(
        fn=predict_risk,
        inputs=[age, sex, race, total_chol, hdl, sbp, dbp,
                bp_meds, diabetes, smoking, fasting_glucose, hba1c],
        outputs=[result_html, confidence_chart],
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    app.launch()
|
|