# Author: Mohammad-El-Majzoub
# Commit: 91f4c21 (verified) — "Upload 2 files"
import gradio as gr
import joblib
import numpy as np
import pandas as pd
import json
import os
# ── Load model + feature list ──────────────────────────────────
# Resolve both artifacts relative to this file (made absolute) so the app
# finds them regardless of the current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_APP_DIR, "best_surrogate_model.joblib")
FEATURES_PATH = os.path.join(_APP_DIR, "feature_cols.json")

# NOTE(review): joblib.load unpickles the file, so only ship model artifacts
# from a trusted source.
model = joblib.load(MODEL_PATH)

# Feature column names, in the order the model expects; build_features()
# selects its output columns with this list. Read as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.
with open(FEATURES_PATH, encoding="utf-8") as f:
    feature_cols = json.load(f)
# One row per risk band, ordered by the model's integer class index:
# (band name, display colour, plain-language description).
_BAND_TABLE = (
    (
        "Low",
        "#27ae60",
        "Less than 5% chance of a heart attack or stroke in the next 10 years.",
    ),
    (
        "Borderline",
        "#f39c12",
        "5% to 7.5% chance of a heart attack or stroke in the next 10 years. Lifestyle changes recommended.",
    ),
    (
        "Intermediate",
        "#e67e22",
        "7.5% to 20% chance of a heart attack or stroke in the next 10 years. Discuss treatment options with a doctor.",
    ),
    (
        "High",
        "#e74c3c",
        "20% or higher chance of a heart attack or stroke in the next 10 years. Medical intervention strongly recommended.",
    ),
)

# Class index -> band name.
RISK_BANDS = {index: row[0] for index, row in enumerate(_BAND_TABLE)}
# Band name -> colour used for headings, borders and probability tiles.
RISK_COLORS = {name: colour for name, colour, _ in _BAND_TABLE}
# Band name -> description shown under the result heading.
RISK_DESCRIPTIONS = {name: text for name, _, text in _BAND_TABLE}
def _safe_log(value, floor=1e-6):
    """Return the natural log of *value*, clipped below at *floor* so 0 or negatives never hit log()."""
    return np.log(max(value, floor))


def build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                   bp_meds, diabetes, smoking, fasting_glucose, hba1c,
                   columns=None):
    """Turn raw clinical inputs into the one-row feature frame the model scores.

    All 31 engineered features are computed (raw values, log transforms,
    PCE-style interactions, BP-treatment interactions, clinical ratios and
    subgroup indicators), then the requested columns are selected in order.

    Args:
        age: Age in years.
        sex: "Male" encodes to 1; any other value encodes to 0.
        race: "Black" encodes to 1; any other value encodes to 0.
        total_chol: Total cholesterol (mg/dL).
        hdl: HDL cholesterol (mg/dL).
        sbp: Systolic blood pressure (mmHg).
        dbp: Diastolic blood pressure (mmHg).
        bp_meds, diabetes, smoking: "Yes" encodes to 1; anything else to 0.
        fasting_glucose: Fasting glucose (mg/dL). ``None``, 0 or negative
            values are treated as unknown and replaced with 100.0.
        hba1c: HbA1c (%). ``None``, 0 or negative values are treated as
            unknown and replaced with 5.5.
        columns: Feature names to return, in order. Defaults to the
            module-level ``feature_cols`` (the model's training columns).
            Any iterable of names is accepted.

    Returns:
        A single-row ``pandas.DataFrame`` whose columns are exactly ``columns``.

    Raises:
        KeyError: if ``columns`` names a feature that is not computed here.
    """
    # ── Base features ──
    sex_code = 1 if sex == "Male" else 0
    race_code = 1 if race == "Black" else 0
    bp_meds_val = 1 if bp_meds == "Yes" else 0
    diabetes_val = 1 if diabetes == "Yes" else 0
    smoking_val = 1 if smoking == "Yes" else 0

    # Optional labs: blank (None), 0 or negative means "unknown", so
    # substitute a typical value instead of feeding the model nonsense.
    fasting_glucose = fasting_glucose if fasting_glucose and fasting_glucose > 0 else 100.0
    hba1c = hba1c if hba1c and hba1c > 0 else 5.5

    # ── Log transforms (clipped to stay finite) ──
    log_age = _safe_log(age)
    log_tc = _safe_log(total_chol)
    log_hdl = _safe_log(hdl)
    log_sbp = _safe_log(sbp)

    features = {
        "AGE": age,
        "SEX_CODE": sex_code,
        "RACE_CODE": race_code,
        "TOTAL_CHOL": total_chol,
        "HDL_UNIFIED": hdl,
        "SBP_AVG": sbp,
        "BP_MEDS_CURRENT": bp_meds_val,
        "DIABETES_PCE": diabetes_val,
        "CURRENT_SMOKER": smoking_val,
        "DBP_AVG": dbp,
        "FASTING_GLUCOSE": fasting_glucose,
        "HBA1C": hba1c,
        # Log transforms
        "LOG_AGE": log_age,
        "LOG_TOTAL_CHOL": log_tc,
        "LOG_HDL": log_hdl,
        "LOG_SBP": log_sbp,
        # PCE-like interactions
        "LOG_AGE_x_LOG_TOTAL_CHOL": log_age * log_tc,
        "LOG_AGE_x_LOG_HDL": log_age * log_hdl,
        "LOG_AGE_x_CURRENT_SMOKER": log_age * smoking_val,
        # BP treatment interactions: log SBP lands in exactly one of the two
        # columns depending on whether the patient is treated.
        "LOG_SBP_x_BP_MEDS_CURRENT": log_sbp * bp_meds_val,
        "LOG_SBP_x_NO_BP_MEDS": log_sbp * (1 - bp_meds_val),
        # Clinical helpers
        "PULSE_PRESSURE": sbp - dbp,
        "TOTAL_CHOL_TO_HDL_RATIO": total_chol / max(hdl, 1e-6),
        "AGE_x_SBP": age * sbp,
        "AGE_x_TOTAL_CHOL": age * total_chol,
        "AGE_x_HDL": age * hdl,
        # Age squared (on the log scale)
        "LOG_AGE_SQUARED": log_age ** 2,
        # Subgroup indicators
        "IS_MALE": sex_code,
        "IS_BLACK": race_code,
        "IS_MALE_x_LOG_AGE": sex_code * log_age,
        "IS_BLACK_x_LOG_AGE": race_code * log_age,
    }

    if columns is None:
        columns = feature_cols
    # list() so a tuple of names is treated as a column list, not a single key.
    return pd.DataFrame([features])[list(columns)]
# Range label shown on each band's probability tile. HTML entities are used
# for "<" and ">=" so the markup cannot be broken by a mis-decoded character.
_BAND_RANGE_LABELS = {
    "Low": "Low (&lt;5%)",
    "Borderline": "Borderline (5-7.5%)",
    "Intermediate": "Intermediate (7.5-20%)",
    "High": "High (&ge;20%)",
}


def _validation_error(age, total_chol, hdl, sbp, dbp):
    """Return a user-facing message for the first invalid input, or None if all are valid.

    Checks age, total cholesterol, HDL, systolic BP and diastolic BP against
    inclusive ranges (in that order), then requires systolic > diastolic.
    """
    checks = (
        ("Age", age, 40, 79,
         "Age must be between 40 and 79 (valid range for the ASCVD calculator)."),
        ("Total cholesterol", total_chol, 100, 400,
         "Total cholesterol should be between 100 and 400 mg/dL."),
        ("HDL cholesterol", hdl, 10, 150,
         "HDL cholesterol should be between 10 and 150 mg/dL."),
        ("Systolic blood pressure", sbp, 80, 250,
         "Systolic blood pressure should be between 80 and 250 mmHg."),
        ("Diastolic blood pressure", dbp, 40, 150,
         "Diastolic blood pressure should be between 40 and 150 mmHg."),
    )
    for label, value, low, high, message in checks:
        # A cleared number field arrives as None; comparing None to an int
        # would raise TypeError, so report it as missing instead.
        if value is None:
            return f"{label} is required."
        # Negated chained comparison so NaN is rejected as well.
        if not low <= value <= high:
            return message
    # PULSE_PRESSURE (SBP - DBP) is a model feature; it must be positive.
    if sbp <= dbp:
        return "Systolic blood pressure must be higher than diastolic blood pressure."
    return None


def _probability_card(label, color, probability):
    """Render one band's probability tile (label plus percentage) as HTML."""
    return f"""
        <div style="background:white; padding:15px 25px; border-radius:10px;
                    box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
            <div style="font-size:0.85em; color:#666;">{label}</div>
            <div style="font-size:1.4em; font-weight:bold; color:{color};">
                {probability * 100:.1f}%
            </div>
        </div>"""


def predict_risk(age, sex, race, total_chol, hdl, sbp, dbp,
                 bp_meds, diabetes, smoking, fasting_glucose, hba1c):
    """Validate the form inputs, score them with the model, and render the result.

    Parameters are the raw Gradio input values, in the same order as
    ``build_features``.

    Returns:
        ``(html, confidence)``: the result panel HTML and a dict mapping each
        band name in ``RISK_BANDS`` to its predicted probability (for the
        ``gr.Label`` output). On invalid input, returns
        ``(format_error(message), None)``.
    """
    error = _validation_error(age, total_chol, hdl, sbp, dbp)
    if error is not None:
        return format_error(error), None

    # ── Build features + predict ──
    X = build_features(age, sex, race, total_chol, hdl, sbp, dbp,
                       bp_meds, diabetes, smoking, fasting_glucose, hba1c)
    pred_class = int(model.predict(X)[0])
    proba = model.predict_proba(X)[0]

    # predict_proba columns follow model.classes_ (scikit-learn convention).
    # Map them explicitly instead of assuming the column order is exactly 0..3.
    # A band the model never saw gets 0.0.
    classes = getattr(model, "classes_", range(len(proba)))
    prob_by_class = {int(c): float(p) for c, p in zip(classes, proba)}
    confidence = {name: prob_by_class.get(idx, 0.0) for idx, name in RISK_BANDS.items()}

    band = RISK_BANDS[pred_class]
    color = RISK_COLORS[band]
    desc = RISK_DESCRIPTIONS[band]

    # ── Format result ──
    cards = "".join(
        _probability_card(_BAND_RANGE_LABELS[name], RISK_COLORS[name], prob)
        for name, prob in confidence.items()
    )
    result_html = f"""
    <div style="text-align:center; padding:30px; border-radius:16px;
                background: linear-gradient(135deg, {color}22, {color}44);
                border: 2px solid {color};">
        <h1 style="color:{color}; margin:0 0 10px 0; font-size:2.5em;">
            {band} Risk of Heart Attack or Stroke
        </h1>
        <p style="font-size:1.2em; color:{color}; font-weight:bold; margin:0 0 20px 0;">
            {desc}
        </p>
        <div style="display:flex; justify-content:center; gap:20px; flex-wrap:wrap;">{cards}
        </div>
    </div>
    """
    return result_html, confidence
def format_error(msg):
    """Wrap *msg* in a red-bordered "Invalid Input" HTML panel for the result area."""
    lines = (
        "",
        '<div style="text-align:center; padding:20px; border-radius:12px;',
        'background:#fee; border:2px solid #e74c3c;">',
        '<h3 style="color:#e74c3c; margin:0;">⚠ Invalid Input</h3>',
        f'<p style="color:#333;">{msg}</p>',
        "</div>",
        "",
    )
    # The leading and trailing empty entries reproduce the newlines at both
    # ends of the panel markup.
    return "\n".join(lines)
# ── Gradio Interface ──────────────────────────────────────────
# The two emoji below are written as escapes (HTML entity / Python \U escape).
# A previous upload mis-decoded them into mojibake ("πŸ«€", "πŸ”"); escapes
# keep the source ASCII-safe.
with gr.Blocks(
    title="Cardiovascular Risk Assessment",
    theme=gr.themes.Soft(
        primary_hue="red",
        secondary_hue="orange",
    ),
    css="""
    .main-header {
        text-align: center;
        padding: 20px 0 10px 0;
    }
    .main-header h1 {
        font-size: 2em;
        margin: 0;
    }
    .main-header p {
        color: #666;
        margin: 5px 0 0 0;
    }
    """
) as app:
    # Title banner followed by a prominent "not a medical device" disclaimer.
    gr.HTML("""
    <div class="main-header">
        <h1>&#x1FAC0; Cardiovascular Risk Assessment</h1>
    </div>
    <div style="text-align:center; padding:12px 20px; margin:0 0 15px 0;
                background:#fff3cd; border:2px solid #ffc107; border-radius:10px;">
        <span style="font-size:1.3em; font-weight:bold; color:#856404;">
            ⚠ Important Disclaimer
        </span>
        <br>
        <span style="font-size:1em; color:#856404;">
            This application is an educational and demonstration project only. It is
            <span style="font-weight:bold; color:#856404;">NOT</span> a certified medical device,
            diagnostic tool, or a substitute for professional medical advice, diagnosis, or treatment.
            The developer assumes <span style="font-weight:bold; color:#856404;">NO RESPONSIBILITY OR LIABILITY</span> for any decisions, actions, or health outcomes
            based on the results provided by this tool. Always consult a qualified healthcare professional
            for actual cardiovascular risk assessment and medical guidance.
        </span>
    </div>
    """)

    with gr.Row():
        # ── Left column: Inputs ──
        with gr.Column(scale=1):
            gr.Markdown("### Patient Information")
            # Slider bounds match the 40-79 age check in predict_risk.
            age = gr.Slider(
                minimum=40, maximum=79, value=55, step=1,
                label="Age (years)",
                info="Valid range: 40–79"
            )
            sex = gr.Radio(
                choices=["Male", "Female"],
                value="Male",
                label="Gender"
            )
            race = gr.Radio(
                choices=["White", "Black"],
                value="White",
                label="Race"
            )

            gr.Markdown("### Clinical Measurements")
            with gr.Row():
                total_chol = gr.Number(
                    value=200, label="Total Cholesterol (mg/dL)",
                    info="Typical range: 125–300"
                )
                hdl = gr.Number(
                    value=50, label="HDL Cholesterol (mg/dL)",
                    info="Typical range: 20–100"
                )
            with gr.Row():
                sbp = gr.Number(
                    value=130, label="Systolic BP (mmHg)",
                    info="Top number. Normal: <120, Elevated: 120–129, High: 130+"
                )
                dbp = gr.Number(
                    value=80, label="Diastolic BP (mmHg)",
                    info="Bottom number. Normal: <80, High: 80+"
                )

            gr.Markdown("### Medical History")
            with gr.Row():
                bp_meds = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Taking BP Medication?"
                )
                diabetes = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Has Diabetes?"
                )
                smoking = gr.Radio(
                    choices=["No", "Yes"], value="No",
                    label="Currently Smoking?"
                )

            # build_features treats 0 (or blank) here as "unknown" and
            # substitutes a default value.
            gr.Markdown("### Optional Lab Results")
            with gr.Row():
                fasting_glucose = gr.Number(
                    value=100, label="Fasting Glucose (mg/dL)",
                    info="Typical range: 70–130. Leave as 0 if unknown"
                )
                hba1c = gr.Number(
                    value=5.5, label="HbA1c (%)",
                    step=0.1,
                    info="Typical range: 4.0–6.5%. Leave as 0 if unknown"
                )

            predict_btn = gr.Button(
                "\U0001F50D Assess Risk", variant="primary", size="lg"
            )

        # ── Right column: Results ──
        with gr.Column(scale=1):
            gr.Markdown("### Assessment Result")
            result_html = gr.HTML(
                value="<div style='text-align:center; padding:40px; color:#999;'>"
                      "Enter patient information and click <b>Assess Risk</b>.</div>"
            )
            confidence_chart = gr.Label(
                label="Model Confidence by Risk Category",
                num_top_classes=4,
            )

    # ── Connect ──
    # The input order must match predict_risk's positional parameters.
    predict_btn.click(
        fn=predict_risk,
        inputs=[age, sex, race, total_chol, hdl, sbp, dbp,
                bp_meds, diabetes, smoking, fasting_glucose, hba1c],
        outputs=[result_html, confidence_chart],
    )

if __name__ == "__main__":
    app.launch()