Spaces:
Sleeping
Sleeping
Kasilanka Bhoopesh Siva Srikar
Complete Heart Attack Risk Prediction App - Ready for Deployment
08123aa
| """ | |
| Streamlit App for Heart Attack Risk Prediction | |
| Based on ensemble model (XGBoost + CatBoost + LightGBM) | |
| """ | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import json | |
| import os | |
| from pathlib import Path | |
| # Page configuration | |
| st.set_page_config( | |
| page_title="Predicting Heart Attack Risk: An Ensemble Modeling Approach", | |
| layout="wide", | |
| initial_sidebar_state="expanded" | |
| ) | |
| # Custom CSS for modern styling | |
| st.markdown(""" | |
| <style> | |
| /* Modern Design System */ | |
| :root { | |
| --primary: #3B82F6; | |
| --primary-dark: #2563EB; | |
| --secondary: #8B5CF6; | |
| --success: #10B981; | |
| --warning: #F59E0B; | |
| --danger: #EF4444; | |
| --bg-card: rgba(30, 41, 59, 0.4); | |
| --bg-card-hover: rgba(30, 41, 59, 0.6); | |
| --border: rgba(148, 163, 184, 0.1); | |
| --border-strong: rgba(148, 163, 184, 0.2); | |
| --text-primary: #F1F5F9; | |
| --text-secondary: #CBD5E1; | |
| --shadow-sm: 0 1px 3px 0 rgb(0 0 0 / 0.1); | |
| --shadow-md: 0 4px 6px -1px rgb(0 0 0 / 0.1); | |
| --shadow-lg: 0 10px 15px -3px rgb(0 0 0 / 0.2); | |
| --radius: 16px; | |
| --radius-sm: 12px; | |
| } | |
| /* Hide Streamlit branding */ | |
| #MainMenu {visibility: hidden;} | |
| footer {visibility: hidden;} | |
| /* Main container improvements */ | |
| .main .block-container { | |
| padding-top: 2rem; | |
| padding-bottom: 2rem; | |
| max-width: 1400px; | |
| } | |
| /* Header with gradient */ | |
| .main-header { | |
| font-size: 2.5rem; | |
| font-weight: 800; | |
| text-align: center; | |
| margin: 0 0 0.5rem; | |
| letter-spacing: -0.02em; | |
| background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%); | |
| -webkit-background-clip: text; | |
| background-clip: text; | |
| color: transparent; | |
| line-height: 1.2; | |
| } | |
| .subtitle { | |
| text-align: center; | |
| color: var(--text-secondary); | |
| font-size: 0.95rem; | |
| margin-bottom: 2rem; | |
| font-weight: 400; | |
| } | |
| /* Section divider */ | |
| .section-divider { | |
| height: 1px; | |
| background: linear-gradient(90deg, transparent, var(--border-strong), transparent); | |
| margin: 2rem 0; | |
| border: none; | |
| } | |
| /* Modern cards */ | |
| .info-card { | |
| padding: 1.5rem; | |
| border-radius: var(--radius-sm); | |
| background: var(--bg-card); | |
| border: 1px solid var(--border); | |
| backdrop-filter: blur(10px); | |
| transition: all 0.3s ease; | |
| box-shadow: var(--shadow-sm); | |
| } | |
| .info-card:hover { | |
| background: var(--bg-card-hover); | |
| border-color: var(--border-strong); | |
| box-shadow: var(--shadow-md); | |
| transform: translateY(-2px); | |
| } | |
| /* Metric cards */ | |
| div[data-testid="metric-container"] { | |
| background: var(--bg-card); | |
| padding: 1rem; | |
| border-radius: var(--radius-sm); | |
| border: 1px solid var(--border); | |
| box-shadow: var(--shadow-sm); | |
| } | |
| div[data-testid="metric-container"]:hover { | |
| background: var(--bg-card-hover); | |
| border-color: var(--border-strong); | |
| } | |
| /* Buttons */ | |
| .stButton > button { | |
| width: 100%; | |
| background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%); | |
| color: white; | |
| border: none; | |
| padding: 0.875rem 2rem; | |
| font-size: 1.05rem; | |
| font-weight: 600; | |
| border-radius: var(--radius-sm); | |
| transition: all 0.3s ease; | |
| box-shadow: var(--shadow-md); | |
| letter-spacing: 0.01em; | |
| } | |
| .stButton > button:hover { | |
| transform: translateY(-2px); | |
| box-shadow: var(--shadow-lg); | |
| background: linear-gradient(135deg, var(--primary-dark) 0%, var(--primary) 100%); | |
| } | |
| /* Input fields */ | |
| .stTextInput > div > div > input, | |
| .stNumberInput > div > div > input, | |
| .stSelectbox > div > div, | |
| .stRadio > div { | |
| background: var(--bg-card); | |
| border: 1px solid var(--border); | |
| border-radius: var(--radius-sm); | |
| color: var(--text-primary); | |
| } | |
| .stTextInput > div > div > input:focus, | |
| .stNumberInput > div > div > input:focus, | |
| .stSelectbox > div > div:focus-within { | |
| border-color: var(--primary); | |
| box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1); | |
| } | |
| /* Sidebar */ | |
| section[data-testid="stSidebar"] { | |
| background: linear-gradient(180deg, rgba(15, 23, 42, 0.95) 0%, rgba(30, 41, 59, 0.95) 100%); | |
| border-right: 1px solid var(--border); | |
| } | |
| section[data-testid="stSidebar"] .block-container { | |
| padding-top: 2rem; | |
| } | |
| /* Expander */ | |
| .streamlit-expanderHeader { | |
| background: var(--bg-card); | |
| border: 1px solid var(--border); | |
| border-radius: var(--radius-sm); | |
| font-weight: 600; | |
| color: var(--text-primary); | |
| } | |
| .streamlit-expanderHeader:hover { | |
| background: var(--bg-card-hover); | |
| border-color: var(--border-strong); | |
| } | |
| /* Progress bars */ | |
| .stProgress > div > div > div { | |
| background: linear-gradient(90deg, var(--primary), var(--secondary)); | |
| border-radius: 10px; | |
| } | |
| /* Tabs */ | |
| .stTabs [data-baseweb="tab-list"] { | |
| gap: 8px; | |
| background: transparent; | |
| } | |
| .stTabs [data-baseweb="tab"] { | |
| background: var(--bg-card); | |
| border: 1px solid var(--border); | |
| border-radius: var(--radius-sm); | |
| padding: 0.5rem 1.5rem; | |
| color: var(--text-secondary); | |
| } | |
| .stTabs [aria-selected="true"] { | |
| background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%); | |
| color: white; | |
| border-color: transparent; | |
| } | |
| /* Alerts */ | |
| .stAlert { | |
| border-radius: var(--radius-sm); | |
| border: 1px solid var(--border); | |
| backdrop-filter: blur(10px); | |
| } | |
| /* Success/Error states */ | |
| .risk-high { | |
| color: var(--danger); | |
| font-size: 1.5rem; | |
| font-weight: 700; | |
| } | |
| .risk-low { | |
| color: var(--success); | |
| font-size: 1.5rem; | |
| font-weight: 700; | |
| } | |
| /* Section headers */ | |
| h1, h2, h3 { | |
| color: var(--text-primary); | |
| font-weight: 700; | |
| letter-spacing: -0.01em; | |
| } | |
| h2 { | |
| font-size: 1.875rem; | |
| margin-top: 2rem; | |
| margin-bottom: 1rem; | |
| } | |
| h3 { | |
| font-size: 1.25rem; | |
| margin-top: 1.5rem; | |
| margin-bottom: 0.75rem; | |
| } | |
| /* Info boxes */ | |
| .stMarkdown p { | |
| color: var(--text-secondary); | |
| line-height: 1.6; | |
| } | |
| /* Radio buttons */ | |
| .stRadio > label { | |
| color: var(--text-primary); | |
| font-weight: 500; | |
| } | |
| /* Selectbox */ | |
| .stSelectbox > label { | |
| color: var(--text-primary); | |
| font-weight: 500; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Paths | |
| BASE_DIR = os.path.dirname(__file__) | |
| ASSETS_DIR = os.path.join(BASE_DIR, "model_assets") | |
| os.makedirs(ASSETS_DIR, exist_ok=True) | |
| def find_first_existing(names): | |
| for n in names: | |
| p = os.path.join(ASSETS_DIR, n) | |
| if os.path.exists(p): | |
| return p | |
| return None | |
| def load_performance_metrics(): | |
| """Load model and ensemble metrics from available CSVs. | |
| Returns: | |
| metrics_rows: list of dicts with keys: model, accuracy, recall, f1, roc_auc | |
| hybrid_rows: list of dicts with keys: version, accuracy, recall, f1, roc_auc | |
| """ | |
| metrics_rows = [] | |
| hybrid_rows = [] | |
| # Candidate files in order of preference | |
| candidate_model_metrics = [ | |
| os.path.join(BASE_DIR, "content", "models", "model_metrics_best.csv"), | |
| os.path.join(BASE_DIR, "model_assets", "model_metrics.csv"), | |
| os.path.join(BASE_DIR, "content", "models", "model_metrics.csv"), | |
| ] | |
| candidate_hybrid_metrics = [ | |
| os.path.join(BASE_DIR, "content", "models", "hybrid_metrics_best.csv"), | |
| os.path.join(BASE_DIR, "model_assets", "hybrid_metrics.csv"), | |
| os.path.join(BASE_DIR, "content", "models", "hybrid_metrics.csv"), | |
| ] | |
| # Load model metrics - prioritize optimized metrics | |
| candidate_model_metrics_priority = [ | |
| os.path.join(BASE_DIR, "content", "models", "model_metrics_optimized.csv"), | |
| os.path.join(BASE_DIR, "model_assets", "model_metrics_optimized.csv"), | |
| os.path.join(BASE_DIR, "content", "models", "model_metrics_best.csv"), | |
| ] + candidate_model_metrics | |
| for fp in candidate_model_metrics_priority: | |
| if os.path.exists(fp): | |
| try: | |
| df = pd.read_csv(fp) | |
| except Exception: | |
| try: | |
| df = pd.read_csv(fp, index_col=0) | |
| except Exception: | |
| continue | |
| cols = {c.lower(): c for c in df.columns} | |
| # Normalize rows | |
| for idx, row in df.iterrows(): | |
| mr = {} | |
| mr["model"] = str(row.get(cols.get("model"), idx)) | |
| for k in ["accuracy", "precision", "recall", "f1", "roc_auc"]: | |
| v = row.get(cols.get(k)) if cols.get(k) in row else None | |
| try: | |
| mr[k] = float(v) | |
| except Exception: | |
| mr[k] = None | |
| metrics_rows.append(mr) | |
| # Prefer first successful file then break | |
| if metrics_rows: | |
| break | |
| # Load hybrid/ensemble metrics - prioritize optimized metrics | |
| candidate_hybrid_metrics_priority = [ | |
| os.path.join(BASE_DIR, "content", "models", "hybrid_metrics_best.csv"), | |
| os.path.join(BASE_DIR, "model_assets", "hybrid_metrics.csv"), | |
| os.path.join(BASE_DIR, "content", "models", "hybrid_metrics.csv"), | |
| ] + candidate_hybrid_metrics | |
| for fp in candidate_hybrid_metrics_priority: | |
| if os.path.exists(fp): | |
| try: | |
| dfh = pd.read_csv(fp) | |
| except Exception: | |
| try: | |
| dfh = pd.read_csv(fp, index_col=0) | |
| except Exception: | |
| continue | |
| cols = {c.lower(): c for c in dfh.columns} | |
| for idx, row in dfh.iterrows(): | |
| hr = {} | |
| hr["version"] = str(row.get(cols.get("version", "version"), idx)) | |
| for k in ["accuracy", "precision", "recall", "f1", "roc_auc"]: | |
| v = row.get(cols.get(k)) if cols.get(k) in row else None | |
| try: | |
| hr[k] = float(v) | |
| except Exception: | |
| hr[k] = None | |
| hybrid_rows.append(hr) | |
| if hybrid_rows: | |
| break | |
| return metrics_rows, hybrid_rows | |
| def get_algo_metrics(metrics_rows, algo_name: str): | |
| """Pick metrics for a given algo ('XGBoost', 'CatBoost', 'LightGBM'). | |
| Uses heuristics to match model names in CSV. | |
| Returns best (highest accuracy) matching row or None. | |
| """ | |
| if not metrics_rows: | |
| return None | |
| name_hints = { | |
| "XGBoost": ["XGB", "XGBoost", "xgb"], | |
| "CatBoost": ["CAT", "CatBoost", "cat"], | |
| "LightGBM": ["LGBM", "LightGBM", "lgb"], | |
| "LogReg": ["LogReg", "logreg", "logistic"], | |
| "RandomForest": ["RF", "RandomForest", "random forest"], | |
| } | |
| hints = name_hints.get(algo_name, [algo_name]) | |
| best = None | |
| for row in metrics_rows: | |
| label = str(row.get("model", "")).upper() | |
| if any(hint.upper() in label for hint in hints): | |
| if best is None: | |
| best = row | |
| else: | |
| acc_best = best.get("accuracy") or -1 | |
| acc_new = row.get("accuracy") or -1 | |
| if acc_new > acc_best: | |
| best = row | |
| return best | |
| def get_ensemble_metrics(hybrid_rows, metrics_rows=None): | |
| """Return the preferred ensemble metrics row. | |
| Preference: 'Ensemble_optimized' from model_metrics -> 'Ensemble_best@0.5' -> 'Ensemble@0.5' -> first Ensemble row. | |
| """ | |
| # First, try to get Ensemble_optimized from model_metrics (most recent optimized) | |
| if metrics_rows: | |
| for row in metrics_rows: | |
| model_name = str(row.get("model", "")).upper() | |
| if "ENSEMBLE" in model_name and "OPTIMIZED" in model_name: | |
| return row | |
| # Then check hybrid_rows | |
| if not hybrid_rows: | |
| return None | |
| # Normalize | |
| rows = list(hybrid_rows) | |
| # First preference: Ensemble_best@0.5 | |
| for r in rows: | |
| ver = str(r.get("version", "")) | |
| if ver.lower() == "ensemble_best@0.5" or ("ensemble_best" in ver.lower() and "@0.5" in ver.lower()): | |
| return r | |
| # Second preference: Ensemble@0.5 | |
| for r in rows: | |
| ver = str(r.get("version", "")) | |
| if ver.lower() == "ensemble@0.5" or ("ensemble" in ver.lower() and "@0.5" in ver.lower()): | |
| return r | |
| # Any ensemble row | |
| for r in rows: | |
| ver = str(r.get("version", "")) | |
| if "ensemble" in ver.lower(): | |
| return r | |
| return None | |
| def load_models(): | |
| """Load models and preprocessor (cached for performance). Robust per-model loading.""" | |
| preprocessor = None | |
| try: | |
| preproc_path = find_first_existing(["preprocessor.joblib"]) | |
| if preproc_path: | |
| preprocessor = joblib.load(preproc_path) | |
| except Exception as e: | |
| st.warning(f"Preprocessor load skipped: {e}") | |
| models = {} | |
| # Resolve paths - prioritize optimized models | |
| xgb_path = find_first_existing([ | |
| "XGBoost_optimized.joblib", "XGB_spw.joblib", "XGBoost.joblib", "xgb_model.joblib", "xgb_full.joblib", "XGBoost_best_5cv.joblib" | |
| ]) | |
| cat_path = find_first_existing([ | |
| "CatBoost_optimized.joblib", "CAT_cw.joblib", "CatBoost.joblib", "catboost.joblib", "cat_model.joblib", "cat_full.joblib", "CatBoost_best_5cv.joblib" | |
| ]) | |
| lgb_path = find_first_existing([ | |
| "LightGBM_optimized.joblib", "LGBM_cw.joblib", "LightGBM.joblib", "lgb_model.joblib", "LightGBM_best_5cv.joblib" | |
| ]) | |
| # Load each model independently so one failure doesn't break others | |
| if xgb_path: | |
| try: | |
| models["XGBoost"] = joblib.load(xgb_path) | |
| except Exception as e: | |
| st.warning(f"XGBoost model failed to load from {os.path.basename(xgb_path)}: {e}") | |
| if cat_path: | |
| try: | |
| models["CatBoost"] = joblib.load(cat_path) | |
| except Exception as e: | |
| st.warning(f"CatBoost model failed to load from {os.path.basename(cat_path)}: {e}") | |
| if lgb_path: | |
| try: | |
| models["LightGBM"] = joblib.load(lgb_path) | |
| except Exception as e: | |
| st.warning(f"LightGBM model failed to load from {os.path.basename(lgb_path)}: {e}") | |
| # Do NOT restrict to CatBoost if preprocessor is missing; ensemble needs both. | |
| # Load metrics paths for display/selection (optional) | |
| metrics_paths = [] | |
| for mp in ["hybrid_metrics.csv", "model_metrics_summary.csv", "model_metrics.csv"]: | |
| p = find_first_existing([mp]) | |
| if p: | |
| metrics_paths.append(p) | |
| return preprocessor, models, metrics_paths | |
| def pick_best_model(models: dict, metrics_paths: list): | |
| """Pick best model based on highest accuracy then recall from available metrics CSVs.""" | |
| fallback_order = [ | |
| ("CatBoost", ["CAT", "Cat", "cat"]), | |
| ("XGBoost", ["XGB", "XGBoost", "xgb"]), | |
| ("LightGBM", ["LGBM", "LightGBM", "lgbm"]), | |
| ] | |
| best_label = None | |
| best_acc = -1.0 | |
| best_rec = -1.0 | |
| for mp in metrics_paths: | |
| try: | |
| dfm = pd.read_csv(mp) | |
| except Exception: | |
| try: | |
| dfm = pd.read_csv(mp, index_col=0) | |
| except Exception: | |
| continue | |
| cols = {c.lower(): c for c in dfm.columns} | |
| if "accuracy" in cols and "recall" in cols: | |
| acc_col = cols["accuracy"] | |
| rec_col = cols["recall"] | |
| if "model" in {c.lower() for c in dfm.columns}: | |
| name_col = [c for c in dfm.columns if c.lower() == "model"][0] | |
| iter_rows = dfm[[name_col, acc_col, rec_col]].itertuples(index=False, name=None) | |
| else: | |
| iter_rows = zip(dfm.index.astype(str).tolist(), dfm[acc_col].tolist(), dfm[rec_col].tolist()) | |
| for label, acc, rec in iter_rows: | |
| try: | |
| acc_f = float(acc) | |
| rec_f = float(rec) | |
| except Exception: | |
| continue | |
| if (acc_f > best_acc) or (np.isclose(acc_f, best_acc) and rec_f > best_rec): | |
| best_acc = acc_f | |
| best_rec = rec_f | |
| best_label = str(label) | |
| if best_label: | |
| label_u = best_label.upper() | |
| if "CAT" in label_u and "CatBoost" in models: | |
| return "CatBoost" | |
| if "XGB" in label_u and "XGBoost" in models: | |
| return "XGBoost" | |
| if ("LGBM" in label_u or "LGB" in label_u) and "LightGBM" in models: | |
| return "LightGBM" | |
| for key, hints in fallback_order: | |
| if key in models: | |
| return key | |
| return None | |
| # Load models | |
| preprocessor, models, metrics_paths = load_models() | |
| if not models: | |
| st.error("⚠️ No models found in `model_assets/`. Please add your trained model files.") | |
| st.stop() | |
| # Enforce Ensemble-only usage: require both XGBoost and CatBoost | |
| if not ("XGBoost" in models and "CatBoost" in models): | |
| st.error("⚠️ Ensemble requires both XGBoost and CatBoost models. Please ensure both artifacts are present in `model_assets/`.") | |
| st.stop() | |
| # Load ensemble configuration (weights and thresholds) | |
| ensemble_config = None | |
| ensemble_info_paths = [ | |
| os.path.join(BASE_DIR, "model_assets", "ensemble_info_optimized.json"), | |
| os.path.join(BASE_DIR, "content", "models", "ensemble_info_optimized.json"), | |
| ] | |
| for path in ensemble_info_paths: | |
| if os.path.exists(path): | |
| try: | |
| with open(path, 'r') as f: | |
| ensemble_config = json.load(f) | |
| break | |
| except Exception as e: | |
| continue | |
| # Default ensemble weights if config not found | |
| if ensemble_config: | |
| ensemble_weights_config = ensemble_config.get('weights', {}) | |
| default_xgb_weight = ensemble_weights_config.get('XGBoost', 0.5) | |
| default_cat_weight = ensemble_weights_config.get('CatBoost', 0.5) | |
| default_lgb_weight = ensemble_weights_config.get('LightGBM', 0.0) | |
| else: | |
| default_xgb_weight = 0.5 | |
| default_cat_weight = 0.5 | |
| default_lgb_weight = 0.0 | |
| # Main title | |
| st.markdown('<h1 class="main-header">Predicting Heart Attack Risk: An Ensemble Modeling Approach</h1>', unsafe_allow_html=True) | |
| st.markdown('<p class="subtitle">Advanced machine learning ensemble combining XGBoost, CatBoost, and LightGBM for accurate cardiovascular risk assessment</p>', unsafe_allow_html=True) | |
| st.markdown('<div class="section-divider"></div>', unsafe_allow_html=True) | |
| # Sidebar for model info | |
| with st.sidebar: | |
| st.header("📊 Ensemble") | |
| # Display ensemble weights | |
| if ensemble_config: | |
| weights = ensemble_config.get('weights', {}) | |
| xgb_w = weights.get('XGBoost', 0.5) * 100 | |
| cat_w = weights.get('CatBoost', 0.5) * 100 | |
| lgb_w = weights.get('LightGBM', 0.0) * 100 | |
| if lgb_w > 0: | |
| st.success(f"✅ Using Optimized Ensemble\nXGBoost: {xgb_w:.1f}% | CatBoost: {cat_w:.1f}% | LightGBM: {lgb_w:.1f}%") | |
| else: | |
| st.success(f"✅ Using Optimized Ensemble\nXGBoost: {xgb_w:.1f}% | CatBoost: {cat_w:.1f}%") | |
| else: | |
| st.success("✅ Using Ensemble (50% XGBoost + 50% CatBoost)") | |
| _model_rows, _hybrid_rows = load_performance_metrics() | |
| ens_row = get_ensemble_metrics(_hybrid_rows, _model_rows) | |
| acc_text = f"{ens_row['accuracy']*100:.2f}%" if ens_row and ens_row.get('accuracy') is not None else "n/a" | |
| rec_text = f"{ens_row['recall']*100:.2f}%" if ens_row and ens_row.get('recall') is not None else "n/a" | |
| cols_side = st.columns(2) | |
| with cols_side[0]: | |
| st.metric("Accuracy", acc_text) | |
| with cols_side[1]: | |
| st.metric("Recall", rec_text) | |
| if metrics_paths: | |
| st.markdown("**Performance Metrics:**") | |
| for mp in metrics_paths: | |
| try: | |
| dfm = pd.read_csv(mp, index_col=0) if mp.endswith('.csv') else pd.read_csv(mp) | |
| st.dataframe(dfm.head(10), use_container_width=True) | |
| except Exception: | |
| pass | |
| st.markdown("---") | |
| st.info(""" | |
| **Note:** This is a prediction tool, not a medical diagnosis. | |
| Always consult healthcare professionals for medical advice. | |
| """) | |
| # Input form with all features | |
| st.header("📝 Patient Information") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.subheader("Demographics") | |
| gender = st.selectbox("Gender", options=[1, 2], format_func=lambda x: "Male" if x == 1 else "Female") | |
| height = st.number_input("Height (cm)", min_value=100, max_value=220, value=170, step=1) | |
| weight = st.number_input("Weight (kg)", min_value=30.0, max_value=200.0, value=70.0, step=0.1) | |
| # Calculate BMI with category | |
| bmi = weight / ((height / 100) ** 2) if height > 0 else 0 | |
| if bmi < 18.5: | |
| bmi_status = "⚠️ Underweight" | |
| bmi_color = "inverse" | |
| elif bmi < 25: | |
| bmi_status = "✅ Normal" | |
| bmi_color = "normal" | |
| elif bmi < 30: | |
| bmi_status = "⚠️ Overweight" | |
| bmi_color = "normal" | |
| else: | |
| bmi_status = "🔴 Obese" | |
| bmi_color = "inverse" | |
| st.metric("BMI", f"{bmi:.2f}", delta=bmi_status, delta_color=bmi_color, | |
| help="Body Mass Index - Healthy range: 18.5-24.9") | |
| with col2: | |
| st.subheader("Blood Pressure") | |
| ap_hi = st.number_input("Systolic BP (mmHg)", min_value=80, max_value=250, value=120, step=1) | |
| ap_lo = st.number_input("Diastolic BP (mmHg)", min_value=40, max_value=150, value=80, step=1) | |
| # Calculate BP_diff and category | |
| bp_diff = ap_hi - ap_lo | |
| # BP Status | |
| if ap_hi < 120 and ap_lo < 80: | |
| bp_status = "✅ Normal" | |
| bp_color = "normal" | |
| elif ap_hi < 130 and ap_lo < 80: | |
| bp_status = "⚠️ Elevated" | |
| bp_color = "normal" | |
| elif ap_hi < 140 or ap_lo < 90: | |
| bp_status = "🔴 Stage 1" | |
| bp_color = "inverse" | |
| else: | |
| bp_status = "🚨 Stage 2" | |
| bp_color = "inverse" | |
| st.metric("Pulse Pressure", f"{bp_diff} mmHg", delta=bp_status, delta_color=bp_color, | |
| help="Normal BP: <120/80 mmHg") | |
| st.markdown("---") | |
| col3, col4 = st.columns(2) | |
| with col3: | |
| st.subheader("Medical History") | |
| cholesterol = st.selectbox("Cholesterol Level", options=[1, 2, 3], | |
| format_func=lambda x: {1: "Normal", 2: "Above Normal", 3: "Well Above Normal"}.get(x)) | |
| gluc = st.selectbox("Glucose Level", options=[1, 2, 3], | |
| format_func=lambda x: {1: "Normal", 2: "Above Normal", 3: "Well Above Normal"}.get(x)) | |
| smoke = st.radio("Smoking", options=[0, 1], format_func=lambda x: "No" if x == 0 else "Yes", horizontal=True) | |
| alco = st.radio("Alcohol Consumption", options=[0, 1], format_func=lambda x: "No" if x == 0 else "Yes", horizontal=True) | |
| with col4: | |
| st.subheader("Activity & Derived Features") | |
| active = st.radio("Physical Activity", options=[0, 1], format_func=lambda x: "No" if x == 0 else "Yes", horizontal=True) | |
| # Age in years (for display) | |
| age_years = st.number_input("Age (years)", min_value=20, max_value=100, value=50, step=1) | |
| age_days = age_years * 365 # Convert to days for model compatibility | |
| # Derived features | |
| systolic_pressure = ap_hi | |
| map_value = ap_lo + (bp_diff / 3) # Mean Arterial Pressure approximation | |
| pulse_pressure_ratio = bp_diff / ap_hi if ap_hi > 0 else 0 | |
| # Additional derived features | |
| st.markdown("---") | |
| st.subheader("Additional Health Metrics") | |
| col5, col6, col7 = st.columns(3) | |
| with col5: | |
| protein_level = st.number_input("Protein Level", min_value=0.0, max_value=200.0, value=14.0, step=0.1) | |
| with col6: | |
| ejection_fraction = st.number_input("Ejection Fraction (%)", min_value=0.0, max_value=100.0, value=60.0, step=0.1) | |
| with col7: | |
| # Calculate Lifestyle Score automatically | |
| lifestyle_score = 0 | |
| risk_factors = [] | |
| if smoke == 1: | |
| lifestyle_score += 1 | |
| risk_factors.append("Smoking") | |
| if alco == 1: | |
| lifestyle_score += 1 | |
| risk_factors.append("Alcohol") | |
| if active == 0: | |
| lifestyle_score += 1 | |
| risk_factors.append("Physical inactivity") | |
| if lifestyle_score == 0: | |
| score_label = "✅ Low Risk" | |
| delta_color = "normal" | |
| elif lifestyle_score == 1: | |
| score_label = "⚠️ Moderate Risk" | |
| delta_color = "normal" | |
| elif lifestyle_score == 2: | |
| score_label = "🔴 High Risk" | |
| delta_color = "inverse" | |
| else: | |
| score_label = "🚨 Very High Risk" | |
| delta_color = "inverse" | |
| st.metric( | |
| "Lifestyle Risk Score", | |
| f"{lifestyle_score}/3 - {score_label}", | |
| help=f"Auto-calculated from lifestyle factors. Risk factors: {', '.join(risk_factors) if risk_factors else 'None'}" | |
| ) | |
| if risk_factors: | |
| st.caption(f"⚠️ Risk factors: {', '.join(risk_factors)}") | |
| # Calculate additional derived features | |
| obesity_flag = 1 if bmi >= 30 else 0 | |
| hypertension_flag = 1 if ap_hi >= 140 or ap_lo >= 90 else 0 | |
| health_risk_score = lifestyle_score + obesity_flag + hypertension_flag | |
| smoker_alcoholic = 1 if (smoke == 1 or alco == 1) else 0 | |
| # Age group and BMI category | |
| if age_years < 30: | |
| age_group = "20-29" | |
| elif age_years < 40: | |
| age_group = "30-39" | |
| elif age_years < 50: | |
| age_group = "40-49" | |
| elif age_years < 60: | |
| age_group = "50-59" | |
| else: | |
| age_group = "60+" | |
| if bmi < 18.5: | |
| bmi_category = "Underweight" | |
| elif bmi < 25: | |
| bmi_category = "Normal" | |
| elif bmi < 30: | |
| bmi_category = "Overweight" | |
| else: | |
| bmi_category = "Obese" | |
| # BP Category | |
| if ap_hi < 120 and ap_lo < 80: | |
| bp_category = "Normal" | |
| elif ap_hi < 130 and ap_lo < 80: | |
| bp_category = "Elevated" | |
| elif ap_hi < 140 or ap_lo < 90: | |
| bp_category = "Stage 1" | |
| else: | |
| bp_category = "Stage 2" | |
| # Risk Level (Note: data uses "Moderate" not "Medium") | |
| if health_risk_score <= 2: | |
| risk_level = "Low" | |
| elif health_risk_score <= 4: | |
| risk_level = "Moderate" # Changed from "Medium" to match training data | |
| else: | |
| risk_level = "High" | |
| # Risk Age (derived) | |
| risk_age = age_years + (health_risk_score * 5) | |
| # Generate Reason based on risk factors | |
| reasons = [] | |
| if obesity_flag == 1: | |
| reasons.append("High BMI (>30)") | |
| if hypertension_flag == 1: | |
| reasons.append("High BP") | |
| if cholesterol == 3: | |
| reasons.append("High cholesterol") | |
| if gluc == 3: | |
| reasons.append("High glucose") | |
| if lifestyle_score > 0: | |
| if smoke == 1: | |
| reasons.append("Smoking") | |
| if alco == 1: | |
| reasons.append("Alcohol consumption") | |
| if active == 0: | |
| reasons.append("Physical inactivity") | |
| if not reasons: | |
| reasons.append("Healthy indicators") | |
| reason = ", ".join(reasons) | |
| # Create feature dictionary matching the dataset structure | |
| feature_dict = { | |
| 'age': age_days, | |
| 'gender': gender, | |
| 'height': height, | |
| 'weight': weight, | |
| 'ap_hi': ap_hi, | |
| 'ap_lo': ap_lo, | |
| 'cholesterol': cholesterol, | |
| 'gluc': gluc, | |
| 'smoke': smoke, | |
| 'alco': alco, | |
| 'active': active, | |
| 'BMI': bmi, | |
| 'BP_diff': bp_diff, | |
| 'Systolic_Pressure': systolic_pressure, | |
| 'age_years': age_years, | |
| 'Age_Group': age_group, | |
| 'Lifestyle_Score': lifestyle_score, | |
| 'Obesity_Flag': obesity_flag, | |
| 'Hypertension_Flag': hypertension_flag, | |
| 'Health_Risk_Score': health_risk_score, | |
| 'Reason': reason, | |
| 'Pulse_Pressure_Ratio': pulse_pressure_ratio, | |
| 'MAP': map_value, | |
| 'BMI_Category': bmi_category, | |
| 'Smoker_Alcoholic': smoker_alcoholic, | |
| 'BP_Category': bp_category, | |
| 'Risk_Age': risk_age, | |
| 'Risk_Level': risk_level, | |
| 'Protein_Level': protein_level, | |
| 'Ejection_Fraction': ejection_fraction | |
| } | |
| # Health Summary Card (before prediction) | |
| st.markdown("---") | |
| st.subheader("📊 Health Summary") | |
| summary_col1, summary_col2, summary_col3, summary_col4 = st.columns(4) | |
| with summary_col1: | |
| if obesity_flag == 1: | |
| st.error("🔴 Obesity Risk") | |
| else: | |
| st.success("✅ Healthy Weight") | |
| with summary_col2: | |
| if hypertension_flag == 1: | |
| st.error("🔴 Hypertension") | |
| else: | |
| st.success("✅ Normal BP") | |
| with summary_col3: | |
| if lifestyle_score >= 2: | |
| st.error(f"🔴 High Lifestyle Risk ({lifestyle_score}/3)") | |
| elif lifestyle_score == 1: | |
| st.warning(f"⚠️ Moderate Risk ({lifestyle_score}/3)") | |
| else: | |
| st.success("✅ Low Risk (0/3)") | |
| with summary_col4: | |
| if cholesterol == 3 or gluc == 3: | |
| st.error("🔴 Elevated Levels") | |
| elif cholesterol == 2 or gluc == 2: | |
| st.warning("⚠️ Above Normal") | |
| else: | |
| st.success("✅ Normal Levels") | |
| # Prediction button | |
| st.markdown("---") | |
| predict_button = st.button("🔮 Predict Heart Attack Risk", type="primary", use_container_width=True) | |
| if predict_button: | |
| try: | |
| # Create DataFrame matching EXACT training data structure (excluding id, cardio, Reason) | |
| feature_cols = ['age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo', 'cholesterol', 'gluc', | |
| 'smoke', 'alco', 'active', 'BMI', 'BP_diff', 'Systolic_Pressure', 'age_years', | |
| 'Age_Group', 'Lifestyle_Score', 'Obesity_Flag', 'Hypertension_Flag', 'Health_Risk_Score', | |
| 'Pulse_Pressure_Ratio', 'MAP', 'BMI_Category', 'Smoker_Alcoholic', 'BP_Category', | |
| 'Risk_Age', 'Risk_Level', 'Protein_Level', 'Ejection_Fraction'] | |
| # Build input row with exact feature order | |
| input_row = { | |
| 'age': age_days, | |
| 'gender': gender, | |
| 'height': height, | |
| 'weight': weight, | |
| 'ap_hi': ap_hi, | |
| 'ap_lo': ap_lo, | |
| 'cholesterol': cholesterol, | |
| 'gluc': gluc, | |
| 'smoke': smoke, | |
| 'alco': alco, | |
| 'active': active, | |
| 'BMI': bmi, | |
| 'BP_diff': bp_diff, | |
| 'Systolic_Pressure': systolic_pressure, | |
| 'age_years': age_years, | |
| 'Age_Group': age_group, | |
| 'Lifestyle_Score': lifestyle_score, | |
| 'Obesity_Flag': obesity_flag, | |
| 'Hypertension_Flag': hypertension_flag, | |
| 'Health_Risk_Score': health_risk_score, | |
| 'Pulse_Pressure_Ratio': pulse_pressure_ratio, | |
| 'MAP': map_value, | |
| 'BMI_Category': bmi_category, | |
| 'Smoker_Alcoholic': smoker_alcoholic, | |
| 'BP_Category': bp_category, | |
| 'Risk_Age': risk_age, | |
| 'Risk_Level': risk_level, | |
| 'Protein_Level': protein_level, | |
| 'Ejection_Fraction': ejection_fraction | |
| } | |
| # Create DataFrame with exact column order | |
| X_input = pd.DataFrame([input_row])[feature_cols] | |
| # The model expects numeric features - categorical columns were one-hot encoded during training | |
| # Load FULL dataset to get ALL possible categorical values (matching training) | |
| sample_csv = os.path.join(BASE_DIR, "content", "cardio_train_extended.csv") | |
| cat_cols = ['Age_Group', 'BMI_Category', 'BP_Category', 'Risk_Level'] | |
| # Get all categorical values from FULL dataset (not just sample) | |
| if os.path.exists(sample_csv): | |
| # Load full dataset to get ALL unique values (matching training) | |
| full_df = pd.read_csv(sample_csv) | |
| cat_values = {} | |
| for col in cat_cols: | |
| if col in full_df.columns: | |
| # Get all unique values and sort them (matching pandas get_dummies behavior) | |
| cat_values[col] = sorted(full_df[col].unique().tolist()) | |
| else: | |
| # Fallback to known values (matching actual data) | |
| cat_values = { | |
| 'Age_Group': ['20-29', '30-39', '40-49', '50-59', '60+'], | |
| 'BMI_Category': ['Normal', 'Obese', 'Overweight', 'Underweight'], # Sorted order from data | |
| 'BP_Category': ['Elevated', 'Normal', 'Stage 1', 'Stage 2'], # Sorted order from data | |
| 'Risk_Level': ['High', 'Low', 'Moderate'] # Note: "Moderate" not "Medium" | |
| } | |
| # Separate numeric and categorical columns | |
| numeric_cols = [col for col in X_input.columns if col not in cat_cols] | |
| X_numeric = X_input[numeric_cols].copy() | |
| # One-hot encode categorical columns with all possible categories in sorted order | |
| # This matches pandas get_dummies behavior during training | |
| X_cat_encoded_list = [] | |
| for col in cat_cols: | |
| if col in X_input.columns: | |
| # Create one-hot columns for all possible values in sorted order | |
| for val in cat_values.get(col, []): | |
| col_name = f"{col}_{val}" | |
| X_cat_encoded_list.append(pd.Series([1 if X_input[col].iloc[0] == val else 0], name=col_name)) | |
| if X_cat_encoded_list: | |
| X_cat_encoded = pd.concat(X_cat_encoded_list, axis=1) | |
| # Combine numeric and encoded categorical features | |
| X_processed = pd.concat([X_numeric, X_cat_encoded], axis=1) | |
| else: | |
| X_processed = X_numeric.copy() | |
| # Ensure all columns are numeric (float) | |
| X_processed = X_processed.astype(float) | |
| # Use ensemble model with optimized weights | |
| predictions = {} | |
| ensemble_probs = [] | |
| ensemble_weights = [] | |
| # Get ensemble weights from config or use defaults | |
| xgb_weight = default_xgb_weight if ensemble_config else 0.5 | |
| cat_weight = default_cat_weight if ensemble_config else 0.5 | |
| lgb_weight = default_lgb_weight if ensemble_config else 0.0 | |
| # Normalize weights to sum to 1.0 | |
| total_weight = xgb_weight + cat_weight + lgb_weight | |
| if total_weight > 0: | |
| xgb_weight = xgb_weight / total_weight | |
| cat_weight = cat_weight / total_weight | |
| lgb_weight = lgb_weight / total_weight | |
| # Try ensemble: XGBoost + CatBoost + LightGBM (if available) | |
| if "XGBoost" in models and "CatBoost" in models: | |
| try: | |
| # Predict with XGBoost | |
| xgb_model = models["XGBoost"] | |
| # Get expected features from XGBoost model | |
| if hasattr(xgb_model, 'feature_names_in_'): | |
| expected_features = list(xgb_model.feature_names_in_) | |
| elif hasattr(xgb_model, 'get_booster'): | |
| try: | |
| booster = xgb_model.get_booster() | |
| if hasattr(booster, 'feature_names') and booster.feature_names: | |
| expected_features = list(booster.feature_names) | |
| else: | |
| # Check n_features_in_ to create placeholder columns | |
| if hasattr(xgb_model, 'n_features_in_'): | |
| n_features = xgb_model.n_features_in_ | |
| expected_features = [f"f{i}" for i in range(n_features)] | |
| else: | |
| expected_features = None | |
| except: | |
| expected_features = None | |
| else: | |
| expected_features = None | |
| if expected_features: | |
| # Align features exactly as XGBoost expects | |
| X_aligned = pd.DataFrame(0.0, index=X_processed.index, columns=expected_features, dtype=float) | |
| # Match columns by name | |
| for col in X_processed.columns: | |
| if col in X_aligned.columns: | |
| X_aligned[col] = X_processed[col].values | |
| X_xgb = X_aligned[expected_features] # Ensure exact order | |
| else: | |
| X_xgb = X_processed | |
| if hasattr(xgb_model, 'predict_proba'): | |
| xgb_prob = float(xgb_model.predict_proba(X_xgb)[0, 1]) | |
| if xgb_weight > 0: | |
| ensemble_probs.append(xgb_prob) | |
| ensemble_weights.append(xgb_weight) | |
| predictions["XGBoost"] = xgb_prob | |
| except Exception as e: | |
| st.warning(f"⚠️ XGBoost prediction failed (using CatBoost only): {str(e)}") | |
| # Don't add to predictions, but continue with CatBoost | |
| # Predict with CatBoost | |
| if "CatBoost" in models: | |
| try: | |
| cat_model = models["CatBoost"] | |
| # CatBoost is very strict about feature order and names | |
| if hasattr(cat_model, 'feature_names_'): | |
| # CatBoost uses feature_names_ (with underscore) | |
| expected_features = list(cat_model.feature_names_) | |
| elif hasattr(cat_model, 'feature_names_in_'): | |
| expected_features = list(cat_model.feature_names_in_) | |
| else: | |
| expected_features = None | |
| if expected_features: | |
| # Create DataFrame with exact feature order and names expected by CatBoost | |
| X_aligned = pd.DataFrame(0.0, index=X_processed.index, columns=expected_features, dtype=float) | |
| # Match columns by name | |
| for col in X_processed.columns: | |
| if col in X_aligned.columns: | |
| X_aligned[col] = X_processed[col].values | |
| X_cat = X_aligned[expected_features] # Ensure exact order | |
| else: | |
| X_cat = X_processed | |
| if hasattr(cat_model, 'predict_proba'): | |
| cat_prob = float(cat_model.predict_proba(X_cat)[0, 1]) | |
| if cat_weight > 0: | |
| ensemble_probs.append(cat_prob) | |
| ensemble_weights.append(cat_weight) | |
| predictions["CatBoost"] = cat_prob | |
| except Exception as e: | |
| st.warning(f"CatBoost prediction failed: {e}") | |
| # Predict with LightGBM (if included in ensemble) | |
| if "LightGBM" in models and lgb_weight > 0: | |
| try: | |
| lgb_model = models["LightGBM"] | |
| # LightGBM is strict about feature order and names | |
| if hasattr(lgb_model, 'feature_name_'): | |
| # LightGBM uses feature_name_ (with underscore, singular) | |
| expected_features = list(lgb_model.feature_name_) | |
| elif hasattr(lgb_model, 'feature_names_in_'): | |
| expected_features = list(lgb_model.feature_names_in_) | |
| else: | |
| expected_features = None | |
| if expected_features: | |
| # Create DataFrame with exact feature order and names expected by LightGBM | |
| X_aligned = pd.DataFrame(0.0, index=X_processed.index, columns=expected_features, dtype=float) | |
| # Match columns by name | |
| for col in X_processed.columns: | |
| if col in X_aligned.columns: | |
| X_aligned[col] = X_processed[col].values | |
| X_lgb = X_aligned[expected_features] # Ensure exact order | |
| else: | |
| X_lgb = X_processed | |
| if hasattr(lgb_model, 'predict_proba'): | |
| lgb_prob = float(lgb_model.predict_proba(X_lgb)[0, 1]) | |
| ensemble_probs.append(lgb_prob) | |
| ensemble_weights.append(lgb_weight) | |
| predictions["LightGBM"] = lgb_prob | |
| except Exception as e: | |
| st.warning(f"LightGBM prediction failed: {e}") | |
| # Ensemble: require at least XGBoost and CatBoost probabilities | |
| if len(ensemble_probs) >= 2: | |
| # Normalize weights to sum to 1.0 | |
| total_weight = sum(ensemble_weights) | |
| if total_weight > 0: | |
| ensemble_weights = [w / total_weight for w in ensemble_weights] | |
| # Ensemble prediction (weighted average) | |
| ensemble_prob = np.average(ensemble_probs, weights=ensemble_weights) | |
| predictions["Ensemble"] = ensemble_prob | |
| else: | |
| st.error("Ensemble prediction requires at least XGBoost and CatBoost probabilities.") | |
| with st.expander("Debug Info"): | |
| st.write("XGBoost available:", "XGBoost" in models) | |
| st.write("CatBoost available:", "CatBoost" in models) | |
| st.write("LightGBM available:", "LightGBM" in models) | |
| st.write("Ensemble probs count:", len(ensemble_probs)) | |
| st.write("Ensemble weights:", ensemble_weights) | |
| st.stop() | |
| if not predictions: | |
| st.error("No models with predict_proba available.") | |
| st.stop() | |
| # Use ensemble prediction only | |
| if "Ensemble" in predictions: | |
| ensemble_prob = predictions["Ensemble"] | |
| else: | |
| st.error("Ensemble prediction missing.") | |
| st.stop() | |
| # Binary prediction | |
| prediction = 1 if ensemble_prob >= 0.5 else 0 | |
| risk_percentage = ensemble_prob * 100 | |
| # Display results | |
| st.markdown("---") | |
| st.header("🎯 Prediction Results") | |
| # Main result with visual indicator | |
| if prediction == 1: | |
| st.error(f"⚠️ **HIGH RISK DETECTED** - {risk_percentage:.1f}% probability of heart disease") | |
| else: | |
| st.success(f"✅ **LOW RISK** - {risk_percentage:.1f}% probability of heart disease") | |
| col_result1, col_result2, col_result3 = st.columns(3) | |
| with col_result1: | |
| st.metric("Risk Probability", f"{risk_percentage:.2f}%", | |
| delta=f"{'High' if risk_percentage >= 70 else 'Moderate' if risk_percentage >= 50 else 'Low'} Risk", | |
| delta_color="inverse" if risk_percentage >= 70 else "normal") | |
| with col_result2: | |
| if risk_percentage >= 70: | |
| risk_level_display = "🚨 Very High" | |
| elif risk_percentage >= 50: | |
| risk_level_display = "🔴 High" | |
| elif risk_percentage >= 30: | |
| risk_level_display = "⚠️ Moderate" | |
| else: | |
| risk_level_display = "✅ Low" | |
| st.metric("Risk Level", risk_level_display) | |
| with col_result3: | |
| st.metric("Prediction", "Heart Disease Detected" if prediction == 1 else "No Heart Disease", | |
| delta="Consult Doctor" if prediction == 1 else "Continue Monitoring", | |
| delta_color="inverse" if prediction == 1 else "normal") | |
| # Enhanced progress bar with color coding | |
| risk_bar_color = "#FF1744" if risk_percentage >= 70 else "#FF9800" if risk_percentage >= 50 else "#4CAF50" | |
| st.markdown(f""" | |
| <div style="background-color: #f0f0f0; border-radius: 5px; padding: 10px; margin: 10px 0;"> | |
| <div style="background-color: {risk_bar_color}; width: {risk_percentage}%; height: 30px; border-radius: 5px; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold;"> | |
| {risk_percentage:.1f}% | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Display Reason with better formatting | |
| if reason and reason != "Healthy indicators": | |
| # Check if only "Physical inactivity" is the risk factor (less severe) | |
| if reason == "Physical inactivity": | |
| st.info(f"**ℹ️ Lifestyle Note:** {reason} - Consider adding regular physical activity to reduce risk.") | |
| else: | |
| st.warning(f"**⚠️ Key Risk Factors Identified:** {reason}") | |
| else: | |
| st.success(f"**✅ Health Status:** {reason}") | |
| # Detailed breakdown with visual bars | |
| with st.expander("📊 Model Details & Breakdown"): | |
| # Load accuracy/recall metrics for display under each model | |
| _model_rows_all, _hybrid_rows_all = load_performance_metrics() | |
| xgb_m_all = get_algo_metrics(_model_rows_all, "XGBoost") | |
| cat_m_all = get_algo_metrics(_model_rows_all, "CatBoost") | |
| lgb_m_all = get_algo_metrics(_model_rows_all, "LightGBM") | |
| # Get optimized ensemble metrics | |
| ens_opt_all = None | |
| for row in _model_rows_all or []: | |
| model_name = str(row.get("model", "")).upper() | |
| if "ENSEMBLE" in model_name and "OPTIMIZED" in model_name: | |
| ens_opt_all = row | |
| break | |
| # Explicit ensemble header with models and weights | |
| if ensemble_config: | |
| weights = ensemble_config.get('weights', {}) | |
| xgb_w = weights.get('XGBoost', 0.5) * 100 | |
| cat_w = weights.get('CatBoost', 0.5) * 100 | |
| lgb_w = weights.get('LightGBM', 0.0) * 100 | |
| if lgb_w > 0: | |
| header_text = f"Ensemble uses: XGBoost ({xgb_w:.1f}%) + CatBoost ({cat_w:.1f}%) + LightGBM ({lgb_w:.1f}%)" | |
| else: | |
| header_text = f"Ensemble uses: XGBoost ({xgb_w:.1f}%) + CatBoost ({cat_w:.1f}%)" | |
| else: | |
| header_text = "Ensemble uses: XGBoost + CatBoost" | |
| if ens_opt_all and ens_opt_all.get("accuracy") is not None: | |
| st.markdown(f"**{header_text}** · Accuracy: {ens_opt_all['accuracy']*100:.2f}% | Recall: {ens_opt_all['recall']*100:.2f}%") | |
| else: | |
| st.markdown(f"**{header_text}**") | |
| # Helper function to create risk bar with percentage inside | |
| def create_risk_bar(risk_pct, model_name): | |
| # Use teal/green color for low risk, orange for moderate, red for high | |
| if risk_pct >= 50: | |
| color = '#EF4444' # Red | |
| elif risk_pct >= 30: | |
| color = '#F59E0B' # Orange | |
| else: | |
| color = '#14B8A6' # Teal/Green | |
| # Ensure bar width doesn't exceed 100% | |
| bar_width = min(risk_pct, 100) | |
| return f""" | |
| <div style="background: rgba(148, 163, 184, 0.15); border-radius: 8px; height: 36px; width: 100%; position: relative; overflow: hidden; border: 1px solid rgba(148, 163, 184, 0.3); margin: 8px 0;"> | |
| <div style="background: {color}; width: {bar_width}%; height: 100%; border-radius: 8px; display: flex; align-items: center; justify-content: flex-start; padding-left: 8px; color: white; font-weight: 600; font-size: 0.85rem; transition: width 0.3s ease;"> | |
| {risk_pct:.2f}% | |
| </div> | |
| </div> | |
| """ | |
| # Display all models horizontally on the same line (4 columns) | |
| models_to_show = [] | |
| # Collect all available models in order | |
| if "XGBoost" in predictions: | |
| models_to_show.append(("XGBoost Model", "XGBoost")) | |
| if "CatBoost" in predictions: | |
| models_to_show.append(("CatBoost Model", "CatBoost")) | |
| if "LightGBM" in predictions: | |
| models_to_show.append(("LightGBM Model", "LightGBM")) | |
| if "Ensemble" in predictions: | |
| models_to_show.append(("🎯 Ensemble (Final)", "Ensemble")) | |
| # Create columns for all models - equal width | |
| if models_to_show: | |
| num_cols = len(models_to_show) | |
| model_cols = st.columns(num_cols) | |
| for idx, (display_name, model_key) in enumerate(models_to_show): | |
| with model_cols[idx]: | |
| # Model title | |
| st.markdown(f"**{display_name}**", unsafe_allow_html=True) | |
| # Calculate risk percentage | |
| risk_pct = float(predictions[model_key]) * 100 | |
| # Display progress bar | |
| st.markdown(create_risk_bar(risk_pct, model_key), unsafe_allow_html=True) | |
| # Risk percentage below bar | |
| st.markdown(f"<div style='text-align: center; margin-top: -8px; font-size: 0.85rem; color: #666;'>{risk_pct:.2f}% risk</div>", unsafe_allow_html=True) | |
| # Show ensemble info | |
| if "Ensemble" in predictions: | |
| if ensemble_config: | |
| weights = ensemble_config.get('weights', {}) | |
| xgb_w = weights.get('XGBoost', 0.5) * 100 | |
| cat_w = weights.get('CatBoost', 0.5) * 100 | |
| lgb_w = weights.get('LightGBM', 0.0) * 100 | |
| if lgb_w > 0: | |
| st.info(f"💡 **Ensemble Method**: Weighted average (XGBoost: {xgb_w:.1f}%, CatBoost: {cat_w:.1f}%, LightGBM: {lgb_w:.1f}%). Final decision uses the Ensemble output.") | |
| else: | |
| st.info(f"💡 **Ensemble Method**: Weighted average (XGBoost: {xgb_w:.1f}%, CatBoost: {cat_w:.1f}%). Final decision uses the Ensemble output.") | |
| else: | |
| st.info("💡 **Ensemble Method**: Weighted average (50% XGBoost + 50% CatBoost). Final decision uses the Ensemble output.") | |
| # Metrics breakdown: show per-model accuracy and averaged accuracy (concise) | |
| st.markdown("---") | |
| st.subheader("Ensemble Metrics") | |
| ens_row_bd = get_ensemble_metrics(_hybrid_rows_all, _model_rows_all) | |
| acc_bd = f"{ens_row_bd['accuracy']*100:.2f}%" if ens_row_bd and ens_row_bd.get('accuracy') is not None else "n/a" | |
| rec_bd = f"{ens_row_bd['recall']*100:.2f}%" if ens_row_bd and ens_row_bd.get('recall') is not None else "n/a" | |
| cols_acc = st.columns(2) | |
| with cols_acc[0]: | |
| st.metric("Accuracy", acc_bd) | |
| with cols_acc[1]: | |
| st.metric("Recall", rec_bd) | |
| # Recommendations | |
| st.markdown("---") | |
| if prediction == 1 or risk_percentage > 70: | |
| st.warning("⚠️ **High Risk Detected!** Please consult with a healthcare professional immediately.") | |
| st.info(""" | |
| **Recommendations:** | |
| - Schedule an appointment with a cardiologist | |
| - Monitor blood pressure regularly | |
| - Maintain a healthy diet and exercise routine | |
| - Avoid smoking and limit alcohol consumption | |
| - Follow up with regular health checkups | |
| """) | |
| elif risk_percentage > 50: | |
| st.warning("⚠️ **Moderate Risk** - Consider consulting a healthcare professional.") | |
| else: | |
| st.success("✅ **Low Risk** - Continue maintaining a healthy lifestyle!") | |
| except Exception as e: | |
| st.error(f"Error making prediction: {str(e)}") | |
| with st.expander("Error Details"): | |
| st.exception(e) | |
| # Footer | |
| st.markdown("---") | |
| st.markdown(""" | |
| <div style='text-align: center; color: #666; padding: 2rem;'> | |
| <p>⚠️ <strong>Disclaimer:</strong> This tool is for educational purposes only and should not be used as a substitute for professional medical advice, diagnosis, or treatment.</p> | |
| <p>Always seek the advice of qualified health providers with any questions you may have regarding a medical condition.</p> | |
| </div> | |
| """, unsafe_allow_html=True) | |