import json
import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd

# ── Load model + feature list ──────────────────────────────────
MODEL_PATH = os.path.join(os.path.dirname(__file__), "best_surrogate_model.joblib")
FEATURES_PATH = os.path.join(os.path.dirname(__file__), "feature_cols.json")

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code. Only ever load a model artifact shipped with the app, never a
# user-supplied one.
model = joblib.load(MODEL_PATH)

# Column names the model was trained on. build_features selects its output
# columns in exactly this order. Read with an explicit encoding so loading does
# not depend on the host's locale default.
with open(FEATURES_PATH, encoding="utf-8") as f:
    feature_cols = json.load(f)

# Maps the integer class returned by model.predict to a risk-band label.
RISK_BANDS = {0: "Low", 1: "Borderline", 2: "Intermediate", 3: "High"}

# Display colour (hex) for each risk band.
RISK_COLORS = {
    "Low": "#27ae60",
    "Borderline": "#f39c12",
    "Intermediate": "#e67e22",
    "High": "#e74c3c",
}

# Plain-language explanation shown to the user for each risk band.
RISK_DESCRIPTIONS = {
    "Low": "Less than 5% chance of a heart attack or stroke in the next 10 years.",
    "Borderline": "5% to 7.5% chance of a heart attack or stroke in the next 10 years. Lifestyle changes recommended.",
    "Intermediate": "7.5% to 20% chance of a heart attack or stroke in the next 10 years. Discuss treatment options with a doctor.",
    "High": "20% or higher chance of a heart attack or stroke in the next 10 years. Medical intervention strongly recommended.",
}


def _safe_log(x):
    """Return ``np.log(x)`` with ``x`` clipped to at least 1e-6, so zero or negative inputs never yield -inf/NaN."""
    return np.log(max(x, 1e-6))


def build_features(age, sex, race, total_chol, hdl, sbp, dbp, bp_meds,
                   diabetes, smoking, fasting_glucose, hba1c):
    """
    Take raw clinical inputs and build all 31 features in the exact order
    the model expects.

    This function does not range-check its inputs. Callers are expected to
    validate them first, as predict_risk does.

    Args:
        age: Age in years.
        sex: "Male" is encoded as 1; any other value as 0.
        race: "Black" is encoded as 1; any other value as 0.
        total_chol: Total cholesterol (mg/dL, per predict_risk's messages).
        hdl: HDL cholesterol (mg/dL).
        sbp: Systolic blood pressure (mmHg).
        dbp: Diastolic blood pressure (mmHg).
        bp_meds: "Yes" is encoded as 1; any other value as 0.
        diabetes: "Yes" is encoded as 1; any other value as 0.
        smoking: "Yes" is encoded as 1; any other value as 0.
        fasting_glucose: Optional. None, 0, negative or NaN falls back to 100.0.
        hba1c: Optional. None, 0, negative or NaN falls back to 5.5.

    Returns:
        A one-row pandas DataFrame whose columns are ``feature_cols``, in order.

    Raises:
        KeyError: If ``feature_cols`` names a feature this function does not
            compute (i.e. feature_cols.json is out of sync with this code).
    """
    # ── Base features: categorical inputs encoded as 0/1 flags ──
    sex_code = 1 if sex == "Male" else 0
    race_code = 1 if race == "Black" else 0
    bp_meds_val = 1 if bp_meds == "Yes" else 0
    diabetes_val = 1 if diabetes == "Yes" else 0
    smoking_val = 1 if smoking == "Yes" else 0

    # Handle optional fields. A falsy or non-positive value (None, 0, negative;
    # NaN also fails `> 0`) is treated as "not provided" and replaced by a
    # fixed fallback.
    fasting_glucose = fasting_glucose if fasting_glucose and fasting_glucose > 0 else 100.0
    hba1c = hba1c if hba1c and hba1c > 0 else 5.5

    # ── Log transforms (safe clip) ──
    log_age = _safe_log(age)
    log_tc = _safe_log(total_chol)
    log_hdl = _safe_log(hdl)
    log_sbp = _safe_log(sbp)

    # ── Build feature dict (final order is imposed by feature_cols below) ──
    features = {
        "AGE": age,
        "SEX_CODE": sex_code,
        "RACE_CODE": race_code,
        "TOTAL_CHOL": total_chol,
        "HDL_UNIFIED": hdl,
        "SBP_AVG": sbp,
        "BP_MEDS_CURRENT": bp_meds_val,
        "DIABETES_PCE": diabetes_val,
        "CURRENT_SMOKER": smoking_val,
        "DBP_AVG": dbp,
        "FASTING_GLUCOSE": fasting_glucose,
        "HBA1C": hba1c,
        # Log transforms
        "LOG_AGE": log_age,
        "LOG_TOTAL_CHOL": log_tc,
        "LOG_HDL": log_hdl,
        "LOG_SBP": log_sbp,
        # PCE-like interactions
        "LOG_AGE_x_LOG_TOTAL_CHOL": log_age * log_tc,
        "LOG_AGE_x_LOG_HDL": log_age * log_hdl,
        "LOG_AGE_x_CURRENT_SMOKER": log_age * smoking_val,
        # BP treatment interactions: exactly one of these two is non-zero
        "LOG_SBP_x_BP_MEDS_CURRENT": log_sbp * bp_meds_val,
        "LOG_SBP_x_NO_BP_MEDS": log_sbp * (1 - bp_meds_val),
        # Clinical helpers
        "PULSE_PRESSURE": sbp - dbp,
        "TOTAL_CHOL_TO_HDL_RATIO": total_chol / max(hdl, 1e-6),  # clip avoids division by zero
        "AGE_x_SBP": age * sbp,
        "AGE_x_TOTAL_CHOL": age * total_chol,
        "AGE_x_HDL": age * hdl,
        # Age squared
        "LOG_AGE_SQUARED": log_age ** 2,
        # Subgroup indicators
        "IS_MALE": sex_code,
        "IS_BLACK": race_code,
        "IS_MALE_x_LOG_AGE": sex_code * log_age,
        "IS_BLACK_x_LOG_AGE": race_code * log_age,
    }

    # Fail with an explicit message if the saved feature list expects a column
    # this code does not produce. Otherwise pandas raises a less descriptive
    # KeyError during column selection.
    missing = [col for col in feature_cols if col not in features]
    if missing:
        raise KeyError(
            f"feature_cols.json lists features that build_features does not produce: {missing}"
        )

    df = pd.DataFrame([features])[feature_cols]
    return df


def predict_risk(age, sex, race, total_chol, hdl, sbp, dbp, bp_meds, diabetes, smoking, fasting_glucose, hba1c): """Main prediction function.""" # ── Validation ── if age < 40 or age > 79: return ( format_error("Age must be between 40 and 79 (valid range 
for the ASCVD calculator)."), None, ) if total_chol < 100 or total_chol > 400: return ( format_error("Total cholesterol should be between 100 and 400 mg/dL."), None, ) if hdl < 10 or hdl > 150: return ( format_error("HDL cholesterol should be between 10 and 150 mg/dL."), None, ) if sbp < 80 or sbp > 250: return ( format_error("Systolic blood pressure should be between 80 and 250 mmHg."), None, ) if dbp < 40 or dbp > 150: return ( format_error("Diastolic blood pressure should be between 40 and 150 mmHg."), None, ) # ── Build features + predict ── X = build_features(age, sex, race, total_chol, hdl, sbp, dbp, bp_meds, diabetes, smoking, fasting_glucose, hba1c) pred_class = int(model.predict(X)[0]) proba = model.predict_proba(X)[0] band = RISK_BANDS[pred_class] color = RISK_COLORS[band] desc = RISK_DESCRIPTIONS[band] # ── Format result ── result_html = f"""
{desc}
{msg}