wjnwjn59 commited on
Commit
47faf49
·
1 Parent(s): 655cae9

first init

Browse files
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ __MACOSX/
3
+
4
+ .DS_Store
5
+ *.csv
README.md CHANGED
@@ -1,12 +1,72 @@
1
  ---
2
- title: AIO2025M03 HEART DISEASE PREDICTION
3
- emoji: 🏃
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AIO2025M03 DEMO Decision Tree
3
+ emoji: 🌳
4
+ colorFrom: green
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # 🌳 Decision Tree Interactive Demo
13
+
14
+ An interactive web application demonstrating Decision Tree algorithms with real-time visualization and educational features.
15
+
16
+ ## ✨ Features
17
+
18
+ - **📊 Multiple Datasets**: 4 built-in datasets (Iris, Wine, Breast Cancer, Diabetes)
19
+ - **🎮 Interactive Interface**: Real-time parameter adjustment and prediction
20
+ - **🌳 Tree Visualization**: Interactive decision tree structure with zoom capabilities
21
+ - **📊 Feature Importance**: Visual representation of feature importance scores
22
+ - **🎛️ Flexible Parameters**: Adjustable max depth, split criteria, and leaf constraints
23
+ - **📱 Responsive Design**: Works on desktop and mobile devices
24
+
25
+ ## 🚀 Quick Start
26
+
27
+ ### Local Installation
28
+ ```bash
29
+ git clone <repository-url>
30
+ cd AIO2025M03_DEMO_DECISION_TREE
31
+ pip install -r requirements.txt
32
+ python app.py
33
+ ```
34
+
35
+ ### Usage
36
+ 1. **Select Dataset**: Choose from pre-loaded datasets or upload your own CSV/Excel file
37
+ 2. **Configure Target**: Select target column and problem type (classification/regression)
38
+ 3. **Set Parameters**: Adjust max depth, split criteria, and leaf constraints
39
+ 4. **Input New Point**: Enter feature values for prediction
40
+ 5. **Run Prediction**: Get results with interactive tree visualization
41
+
42
+ ## 🧠 Technical Highlights
43
+
44
+ - **Tree Structure**: Interactive visualization of decision tree nodes and splits
45
+ - **Feature Importance**: Automatic calculation and visualization of feature importance scores
46
+ - **Auto-Detection**: Automatically determines classification vs regression problems
47
+ - **Error Handling**: Robust validation and user-friendly error messages
48
+
49
+ ## 📋 Requirements
50
+
51
+ - Python 3.8+
52
+ - Gradio 5.38+
53
+ - Scikit-learn
54
+ - Pandas
55
+ - NumPy
56
+ - Plotly
57
+
58
+ ## 🎓 Educational Value
59
+
60
+ Perfect for:
61
+ - Understanding Decision Tree algorithm mechanics
62
+ - Learning about tree-based splitting criteria
63
+ - Exploring feature importance and tree pruning
64
+ - Comparing classification vs regression approaches
65
+
66
+ ## 📄 License
67
+
68
+ Educational use for AIO2025 course materials.
69
+
70
+ ---
71
+
72
+ **Live Demo**: [Decision Tree Demo](https://huggingface.co/spaces/VLAI-AIVN/AIO2025M03_DEMO_DECISION_TREE)
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py — Gradio front-end for the Cleveland heart-disease ensemble demo.
import os
import gradio as gr
import plotly.graph_objects as go
import pandas as pd

from src.heart_disease_core import (
    CLEVELAND_FEATURES_ORDER, TARGET_COL, CATEGORICAL_CHOICES,
    load_cleveland_dataframe, fit_all_models, predict_all, example_patient
)

# Theme colors reused by the custom CSS and the confidence plot.
APP_PRIMARY = "#0F6CBD"  # medical calm blue
APP_ACCENT = "#C4314B"   # medical alert red
APP_BG = "#F7FAFC"

# Module-level cache: uploaded dataset plus lazily fitted models/metrics.
# NOTE(review): this is shared across every session of the app — acceptable
# for a single-user demo, but concurrent users would share one dataset.
STATE = {
    "df": None,       # cleaned Cleveland DataFrame (set by load_dataset)
    "models": None,   # dict of fitted sklearn Pipelines (None until first predict)
    "metrics": None,  # holdout-metrics DataFrame from fit_all_models
}
20
+
21
def _ensure_models(df: pd.DataFrame):
    """Fit and cache the model suite on first use; no-op on later calls."""
    if STATE["models"] is not None:
        return
    fitted, holdout_metrics = fit_all_models(df)
    STATE["models"] = fitted
    STATE["metrics"] = holdout_metrics
26
+
27
def load_dataset(file):
    """Load an uploaded Cleveland-format CSV/XLSX into module STATE.

    Returns a 3-tuple of Gradio updates for (status markdown, preview
    dataframe, metrics markdown).

    Fix: the original returned `gr.Markdown.update(...)` /
    `gr.DataFrame.update(...)` — those classmethods were removed in Gradio
    4.0 and raise AttributeError under the gradio 5.x pinned in
    requirements.txt. `gr.update(...)` is the supported replacement.
    """
    try:
        if file is None:
            return (
                gr.update(value="❌ Please upload a Cleveland-format dataset (CSV/XLSX)."),
                gr.update(value=pd.DataFrame()),
                gr.update(visible=False),
            )
        # Case-insensitive extension check; anything not .csv is read as Excel.
        if file.name.lower().endswith(".csv"):
            df = pd.read_csv(file.name)
        else:
            df = pd.read_excel(file.name)
        df = load_cleveland_dataframe(uploaded_df=df)
        STATE["df"] = df
        STATE["models"] = None  # reset, will refit lazily on next predict
        STATE["metrics"] = None
        head = df.head(8)
        return (
            gr.update(value="✅ Dataset loaded successfully."),
            gr.update(value=head, interactive=False),
            gr.update(visible=False),
        )
    except Exception as e:
        # Surface validation/parsing problems to the user instead of crashing.
        return (
            gr.update(value=f"❌ Error: {e}"),
            gr.update(value=pd.DataFrame()),
            gr.update(visible=False),
        )
43
+
44
def fill_example(idx):
    """Return the chosen example patient's values in strict feature order."""
    patient = example_patient(idx)
    ordered = []
    for feature in CLEVELAND_FEATURES_ORDER:
        ordered.append(patient[feature])
    return ordered
47
+
48
def _bar_for_models(results: dict):
    """Build a bar chart of P(disease) per model plus human-readable labels.

    Returns (plotly Figure, list of "Disease"/"No disease" strings).
    """
    names = list(results)
    probs = [results[name]["prob_1"] for name in names]
    labels = []
    for name in names:
        labels.append("Disease" if results[name]["label"] == 1 else "No disease")

    fig = go.Figure()
    fig.add_bar(x=names, y=probs, text=[f"{p:.2f}" for p in probs], textposition="auto")
    fig.update_layout(
        title="Model Confidence (P[Heart Disease = 1])",
        yaxis_title="Probability",
        xaxis_title="Model",
        yaxis=dict(range=[0, 1]),
        plot_bgcolor="white",
        height=420,
        margin=dict(l=30, r=20, t=60, b=40),
    )
    # Emphasize the ensemble bar (last entry) with the accent color.
    if names:
        fig.data[0].marker.color = ["#9BB8D3"] * (len(names) - 1) + [APP_ACCENT]
    return fig, labels
68
+
69
def run_predict(*vals):
    """Predict heart disease for the entered patient with every model.

    *vals* arrive in CLEVELAND_FEATURES_ORDER (wired in the UI). Returns a
    4-tuple of Gradio updates for (title markdown, confidence plot, table
    caption, per-model table).

    Fix: the original used `gr.Markdown.update` / `gr.Plot.update` /
    `gr.DataFrame.update` — removed in Gradio 4.0, AttributeError under the
    pinned gradio 5.x. `gr.update(...)` is the supported replacement.
    """
    # Ensure a dataset has been uploaded first.
    if STATE["df"] is None:
        return (
            gr.update(value="❌ No dataset yet. Please upload a Cleveland-format dataset."),
            gr.update(value=None),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Build the input row as a dict in strict feature order.
    input_dict = {col: vals[i] for i, col in enumerate(CLEVELAND_FEATURES_ORDER)}

    # Fit models lazily on first prediction.
    _ensure_models(STATE["df"])

    results = predict_all(STATE["models"], input_dict)

    # Compose the headline summary from the ensemble result.
    final_label = results["Ensemble (Soft Voting)"]["label"]
    final_prob = results["Ensemble (Soft Voting)"]["prob_1"]
    title_md = (
        f"### 🫀 Cleveland Heart Disease Diagnosis\n"
        f"**Ensemble Prediction**: **{'Positive' if final_label == 1 else 'Negative'}** \n"
        f"**Confidence (P=1)**: `{final_prob:.3f}`"
    )

    # One table row per model.
    pred_table = []
    for name, r in results.items():
        pred_table.append({
            "Model": name,
            "Predicted label": "Positive" if r["label"] == 1 else "Negative",
            "P(No disease)": round(r["prob_0"], 3),
            "P(Heart disease)": round(r["prob_1"], 3),
        })
    table_df = pd.DataFrame(pred_table)

    fig, labels = _bar_for_models(results)

    return (
        gr.update(value=title_md),
        gr.update(value=fig),
        gr.update(value="**Per-Model Predictions**", visible=True),
        gr.update(value=table_df, visible=True, interactive=False),
    )
115
+
116
+ # -----------------------------
117
+ # UI
118
+ # -----------------------------
119
+ with gr.Blocks(theme="soft", css=f"""
120
+ :root {{
121
+ --primary-600: {APP_PRIMARY};
122
+ }}
123
+ .gradio-container {{ background: {APP_BG}; }}
124
+ .footer-note a {{ color: {APP_PRIMARY}; }}
125
+ h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
126
+ """) as demo:
127
+ gr.Markdown("# 🫀 Cleveland Heart Disease Diagnosis (Ensemble Demo)")
128
+
129
+ with gr.Row(equal_height=False):
130
+ # LEFT: inputs
131
+ with gr.Column(scale=45):
132
+ with gr.Box():
133
+ gr.Markdown("### 📁 Load Dataset")
134
+ info_md = gr.Markdown("Upload a CSV/XLSX in **Cleveland** format (13 features + `target`).")
135
+ file_u = gr.File(file_count="single", file_types=[".csv", ".xlsx", ".xls"], label="Upload Cleveland Dataset")
136
+ preview = gr.DataFrame(label="Data Preview (first rows)", interactive=False)
137
+ metrics_box = gr.Markdown(visible=False)
138
+
139
+ with gr.Box():
140
+ gr.Markdown("### ✍️ Enter Patient Features")
141
+ with gr.Row():
142
+ age = gr.Number(label="age (years)", value=58)
143
+ sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0,1], value=1)
144
+ cp = gr.Dropdown(label="cp (chest pain type 0..3)", choices=[0,1,2,3], value=2)
145
+ trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
146
+
147
+ with gr.Row():
148
+ chol = gr.Number(label="chol (serum cholestrol mg/dl)", value=250)
149
+ fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0,1], value=0)
150
+ restecg = gr.Dropdown(label="restecg (0..2)", choices=[0,1,2], value=1)
151
+ thalach = gr.Number(label="thalach (max heart rate)", value=150)
152
+
153
+ with gr.Row():
154
+ exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0,1], value=0)
155
+ oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
156
+ slope = gr.Dropdown(label="slope (0..2)", choices=[0,1,2], value=1)
157
+ ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0,1,2,3], value=0)
158
+
159
+ thal = gr.Dropdown(label="thal (1=normal,2=fixed,3=reversible)", choices=[1,2,3], value=2)
160
+
161
+ with gr.Row():
162
+ ex_selector = gr.Dropdown(
163
+ label="Fill Example",
164
+ choices=["Example 1 (likely negative)", "Example 2 (borderline)", "Example 3 (likely positive)"],
165
+ value="Example 2 (borderline)"
166
+ )
167
+ fill_btn = gr.Button("🧪 Use Example", variant="secondary")
168
+ predict_btn = gr.Button("🔍 Predict", variant="primary")
169
+
170
+ # RIGHT: outputs
171
+ with gr.Column(scale=55):
172
+ with gr.Box():
173
+ title_out = gr.Markdown("### Ensemble Prediction will appear here.")
174
+ bar_out = gr.Plot(label="Model Confidence")
175
+ sub_md = gr.Markdown(visible=False)
176
+ table_out = gr.DataFrame(visible=False)
177
+
178
+ with gr.Accordion("ℹ️ Notes", open=False):
179
+ gr.Markdown(
180
+ "- This demo **fits models** on your uploaded dataset (80/20 split) the first time you predict.\n"
181
+ "- **Target** is automatically binarized (0 = no disease, >0 = disease).\n"
182
+ "- Ensemble is **soft voting** over Decision Tree, k-NN, and Naive Bayes.\n"
183
+ "- This is **for demo/education**; not medical advice."
184
+ )
185
+
186
+ # Events
187
+ file_u.upload(fn=load_dataset, inputs=[file_u], outputs=[info_md, preview, metrics_box])
188
+
189
+ def _example_index(choice: str):
190
+ return {"Example 1 (likely negative)": 0, "Example 2 (borderline)": 1, "Example 3 (likely positive)": 2}[choice]
191
+
192
+ fill_btn.click(
193
+ fn=lambda choice: tuple(fill_example(_example_index(choice))),
194
+ inputs=[ex_selector],
195
+ outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
196
+ )
197
+
198
+ predict_btn.click(
199
+ fn=run_predict,
200
+ inputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal],
201
+ outputs=[title_out, bar_out, sub_md, table_out]
202
+ )
203
+
204
+ if __name__ == "__main__":
205
+ # Optional: allow GraphViz logos etc. from static if you keep them
206
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ graphviz
2
+ fonts-liberation
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==5.38.0
2
+ pandas>=1.5.0
3
+ scikit-learn>=1.3.0
4
+ numpy>=1.24.0
5
+ dtreeviz>=2.2.2
6
+ graphviz>=0.20.3
7
+ plotly>=5.15.0
8
+ supertree>=0.5.5
src/__init__.py ADDED
File without changes
src/heart_disease_core.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/heart_disease_core.py
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ from typing import Dict, Tuple, Optional, List
6
+
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.preprocessing import OneHotEncoder
9
+ from sklearn.compose import ColumnTransformer
10
+ from sklearn.pipeline import Pipeline
11
+ from sklearn.impute import SimpleImputer
12
+ from sklearn.metrics import roc_auc_score
13
+ from sklearn.tree import DecisionTreeClassifier
14
+ from sklearn.neighbors import KNeighborsClassifier
15
+ from sklearn.naive_bayes import GaussianNB
16
+ from sklearn.ensemble import VotingClassifier
17
+
18
+
19
# Canonical Cleveland feature order — the UI and prediction code index by it,
# so this ordering must not change.
CLEVELAND_FEATURES_ORDER: List[str] = [
    "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
    "thalach", "exang", "oldpeak", "slope", "ca", "thal",
]

# Label column; raw UCI values 0..4 are binarized to 0 (no disease) / 1 (disease).
TARGET_COL = "target"

# Legal encoded values for each categorical feature (Cleveland conventions).
CATEGORICAL_CHOICES = {
    "sex": [0, 1],            # 0: female, 1: male
    "cp": [0, 1, 2, 3],       # chest pain type
    "fbs": [0, 1],            # fasting blood sugar > 120 mg/dl (1 true, 0 false)
    "restecg": [0, 1, 2],     # resting ECG results
    "exang": [0, 1],          # exercise-induced angina
    "slope": [0, 1, 2],       # slope of ST
    "ca": [0, 1, 2, 3],       # number of major vessels (0-3) colored by fluoroscopy
    "thal": [1, 2, 3],        # 1: normal, 2: fixed defect, 3: reversible defect
                              # (raw UCI data also uses 3/6/7 variants; we standardize)
}

# Split of the feature set by preprocessing treatment (see build_preprocessor).
NUMERIC_COLS = ["age", "trestbps", "chol", "thalach", "oldpeak"]
CATEGORICAL_COLS = ["sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal"]
38
+
39
def _coerce_and_clean(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize a raw Cleveland frame.

    Canonicalizes column names (any casing), replaces the UCI '?' missing
    marker with NaN, coerces all known columns to numeric, and binarizes
    the target (0 = no disease, >0 = disease).
    """
    out = df.copy()
    wanted = CLEVELAND_FEATURES_ORDER + [TARGET_COL]
    lower_to_actual = {c.lower(): c for c in out.columns}

    for col in wanted:
        # Adopt an any-cased variant of the canonical name, if present.
        if col not in out.columns and col in lower_to_actual:
            out[col] = out.pop(lower_to_actual[col])
        # '?' marks missing values in the UCI files; coerce to numeric.
        if col in out.columns:
            out[col] = pd.to_numeric(out[col].replace("?", np.nan), errors="coerce")

    # UCI encodes disease severity 0..4; collapse to presence/absence.
    if TARGET_COL in out.columns:
        out[TARGET_COL] = (out[TARGET_COL] > 0).astype(int)

    return out
58
+
59
def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """
    Load the Cleveland Heart Disease dataset.
    Priority: uploaded_df > file_path > raise.
    Expect columns CLEVELAND_FEATURES_ORDER + TARGET_COL.
    """
    required = CLEVELAND_FEATURES_ORDER + [TARGET_COL]

    def _clean_and_validate(raw: pd.DataFrame, source: str) -> pd.DataFrame:
        # Shared clean + schema check for both input paths.
        cleaned = _coerce_and_clean(raw)
        missing = [c for c in required if c not in cleaned.columns]
        if missing:
            raise ValueError(f"{source} missing required columns: {missing}")
        return cleaned

    if uploaded_df is not None:
        return _clean_and_validate(uploaded_df, "Uploaded data")

    if file_path is not None and os.path.exists(file_path):
        reader = pd.read_csv if file_path.endswith(".csv") else pd.read_excel
        return _clean_and_validate(reader(file_path), "File")

    raise FileNotFoundError(
        "No dataset found. Please upload a CSV/XLSX with columns: "
        f"{CLEVELAND_FEATURES_ORDER + [TARGET_COL]}"
    )
87
+
88
+ # -----------------------------
89
+ # Preprocess & Modeling
90
+ # -----------------------------
91
def build_preprocessor() -> ColumnTransformer:
    """Build the shared feature preprocessor.

    - Numeric: impute median.
    - Categorical: impute most_frequent + one-hot (dense output).

    Fix: OneHotEncoder defaults to sparse output, so the ColumnTransformer
    could emit a scipy sparse matrix — GaussianNB (used in build_models)
    rejects sparse input with a TypeError. `sparse_output=False` forces a
    dense array (available since scikit-learn 1.2; requirements pin >=1.3).
    """
    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
    ])

    categorical_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("ohe", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, NUMERIC_COLS),
            ("cat", categorical_pipe, CATEGORICAL_COLS)
        ],
        remainder="drop"
    )
    return preprocessor
113
+
114
def build_models() -> Dict[str, Pipeline]:
    """Create one sklearn Pipeline per model (plus the soft-voting ensemble).

    Fix: the original constructed a single preprocessor instance and put the
    SAME object inside all four Pipelines. sklearn Pipelines do not clone
    their steps on fit, so fitting each pipeline silently re-fitted the
    transformer shared by the others (harmless only because all were fitted
    on identical data). Each pipeline now gets its own preprocessor.
    """
    def _tree() -> DecisionTreeClassifier:
        # One place for the tree hyperparameters, reused by the ensemble.
        return DecisionTreeClassifier(
            random_state=42,
            max_depth=5,
            min_samples_split=2,
            min_samples_leaf=1,
            criterion="gini"
        )

    dt = Pipeline(steps=[
        ("prep", build_preprocessor()),
        ("clf", _tree())
    ])

    knn = Pipeline(steps=[
        ("prep", build_preprocessor()),
        ("clf", KNeighborsClassifier(n_neighbors=5))
    ])

    nb = Pipeline(steps=[
        ("prep", build_preprocessor()),
        ("clf", GaussianNB())
    ])

    # Ensemble: one Pipeline whose classifier is a soft VotingClassifier over
    # fresh copies of the three base estimators.
    ensemble = Pipeline(steps=[
        ("prep", build_preprocessor()),
        ("clf", VotingClassifier(
            estimators=[
                ("dt", _tree()),
                ("knn", KNeighborsClassifier(n_neighbors=5)),
                ("nb", GaussianNB()),
            ],
            voting="soft",
            weights=None  # uniform weights; can tweak later
        ))
    ])

    return {"Decision Tree": dt, "k-NN": knn, "Naive Bayes": nb, "Ensemble (Soft Voting)": ensemble}
157
+
158
def fit_all_models(df: pd.DataFrame, test_size: float = 0.2, random_state: int = 42) -> Tuple[Dict[str, Pipeline], pd.DataFrame]:
    """Fit every model on a stratified train split and score it on the holdout.

    Returns (fitted models dict, metrics DataFrame sorted by ROC-AUC desc).
    """
    features = df[CLEVELAND_FEATURES_ORDER]
    target = df[TARGET_COL].astype(int)

    X_tr, X_te, y_tr, y_te = train_test_split(
        features, target, test_size=test_size, random_state=random_state, stratify=target
    )

    models = build_models()
    rows = []

    for name, pipe in models.items():
        pipe.fit(X_tr, y_tr)
        # Prefer probability scores for AUC; fall back to hard predictions.
        if hasattr(pipe, "predict_proba"):
            scores = pipe.predict_proba(X_te)[:, 1]
        else:
            scores = pipe.predict(X_te)
        auc = roc_auc_score(y_te, scores)
        rows.append({"model": name, "ROC-AUC": round(float(auc), 4)})

    metrics_df = pd.DataFrame(rows).sort_values("ROC-AUC", ascending=False, ignore_index=True)
    return models, metrics_df
185
+
186
def predict_all(models: Dict[str, Pipeline], input_dict: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """Predict probability of heart disease for each model.

    Returns: {model_name: {"prob_1": float, "prob_0": float, "label": int}}

    Fix: the original assumed `predict_proba` columns are ordered [0, 1] and
    took `argmax` as the label. Per the sklearn contract the columns follow
    `classes_`; map through it so the result stays correct regardless of the
    class ordering seen at fit time.
    """
    # Build a one-row frame in strict feature order.
    row = [[input_dict[c] for c in CLEVELAND_FEATURES_ORDER]]
    X_new = pd.DataFrame(row, columns=CLEVELAND_FEATURES_ORDER)

    out = {}
    for name, pipe in models.items():
        if hasattr(pipe, "predict_proba"):
            proba = pipe.predict_proba(X_new)[0]
            # Columns of predict_proba follow classes_; default to [0, 1]
            # for duck-typed models that don't expose it.
            classes = list(getattr(pipe, "classes_", [0, 1]))
            out[name] = {
                "prob_0": float(proba[classes.index(0)]),
                "prob_1": float(proba[classes.index(1)]),
                "label": int(classes[int(np.argmax(proba))])
            }
        else:
            # Fallback for models without probabilities (unlikely here).
            label = int(pipe.predict(X_new)[0])
            out[name] = {"prob_0": 1.0 - label, "prob_1": float(label), "label": label}
    return out
209
+
210
+
211
def example_patient(index: int = 0) -> Dict[str, float]:
    """Return one of three canned Cleveland-style patients.

    Index 0 is a likely-negative case, 1 borderline, 2 likely positive;
    out-of-range indices are clamped into [0, 2].
    """
    likely_negative = dict(age=45, sex=0, cp=0, trestbps=120, chol=230, fbs=0, restecg=1,
                           thalach=168, exang=0, oldpeak=0.0, slope=2, ca=0, thal=2)
    borderline = dict(age=58, sex=1, cp=2, trestbps=138, chol=250, fbs=0, restecg=0,
                      thalach=150, exang=0, oldpeak=1.0, slope=1, ca=1, thal=2)
    likely_positive = dict(age=63, sex=1, cp=3, trestbps=145, chol=320, fbs=1, restecg=2,
                           thalach=130, exang=1, oldpeak=2.8, slope=0, ca=2, thal=3)

    catalog = [likely_negative, borderline, likely_positive]
    clamped = min(max(index, 0), len(catalog) - 1)
    return catalog[clamped]
static/aivn_logo.png ADDED
static/vlai_logo.png ADDED
vlai_template.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# vlai_template.py — shared AIVN/VLAI page chrome (header, footer, CSS).
import os, base64
import gradio as gr


# Branding constants rendered in the page header HTML.
# NOTE(review): PROJECT_NAME says "Decision Tree Demo" while app.py implements
# heart-disease prediction — confirm which branding is intended.
PROJECT_NAME = "Decision Tree Demo"
AIO_YEAR = "2025"
AIO_MODULE = "03"
# END
9
+
10
+
11
def image_to_base64(image_path: str):
    """Return the base64-encoded bytes of *image_path*.

    Relative paths are resolved against this module's directory, so the
    result is independent of the process working directory.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    resolved = os.path.join(base_dir, image_path)
    with open(resolved, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
17
+
18
def create_header():
    """Render the two-column page header: AIVN logo plus project title."""
    with gr.Row():
        with gr.Column(scale=2):
            # Inline the logo as a data URI so no static file route is needed.
            logo_base64 = image_to_base64("static/aivn_logo.png")
            gr.HTML(
                f"""<img src="data:image/png;base64,{logo_base64}"
                alt="Logo"
                style="height:120px;width:auto;margin:0 auto;margin-bottom:16px; display:block;">"""
            )
        with gr.Column(scale=2):
            # Title block: project name + course/module subtitle.
            gr.HTML(f"""
            <div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
                <div>
                    <h1 style="margin-bottom:0; color: #2E7D32; font-size: 2.5em; font-weight: bold;"> {PROJECT_NAME} </h1>
                    <h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
                </div>
            </div>
            """)
36
+
37
def create_footer():
    """Render the fixed bottom credit bar (VLAI logo inlined as base64)."""
    logo_base64_vlai = image_to_base64("static/vlai_logo.png")
    # The <style> part is a plain string (its CSS braces would collide with
    # f-string format fields); only the body that embeds the logo is an f-string.
    footer_html = """
    <style>
    .sticky-footer{position:fixed;bottom:0px;left:0;width:100%;background:#E8F5E8;
    padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
    .content-wrap{padding-bottom:60px;}
    </style>""" + f"""
    <div class="sticky-footer">
        <div style="text-align:center;font-size:18px; color: #888">
            Created by
            <a href="https://vlai.work" target="_blank" style="color:#465C88;text-decoration:none;font-weight:bold; display:inline-flex; align-items:center;"> VLAI
                <img src="data:image/png;base64,{logo_base64_vlai}" alt="Logo" style="height:20px; width:auto;">
            </a> from <a href="https://aivietnam.edu.vn/" target="_blank" style="color:#355724;text-decoration:none;font-weight:bold">AI VIET NAM</a>
        </div>
    </div>
    """
    return gr.HTML(footer_html)
55
+
56
# Global CSS for the host app: animated gradient page background plus spacing
# overrides that tighten Gradio's default component margins. The .sticky-footer
# rules duplicate those injected by create_footer's inline <style> block.
custom_css = """

.gradio-container {
    min-height: 100vh !important;
    width: 100vw !important;
    margin: 0 !important;
    padding: 0px !important;
    background: linear-gradient(135deg, #E8F5E8 0%, #D4E6D4 50%, #A8D8A8 100%);
    background-size: 600% 600%;
    animation: gradientBG 7s ease infinite;
}

@keyframes gradientBG {
    0% {background-position: 0% 50%;}
    50% {background-position: 100% 50%;}
    100% {background-position: 0% 50%;}
}

/* Minimize spacing and padding */
.content-wrap {
    padding: 2px !important;
    margin: 0 !important;
}

/* Reduce component spacing */
.gr-row {
    gap: 5px !important;
    margin: 2px 0 !important;
}

.gr-column {
    gap: 4px !important;
    padding: 4px !important;
}

/* Accordion optimization */
.gr-accordion {
    margin: 4px 0 !important;
}

.gr-accordion .gr-accordion-content {
    padding: 2px !important;
}

/* Form elements spacing */
.gr-form {
    gap: 2px !important;
}

/* Button styling */
.gr-button {
    margin: 2px 0 !important;
}

/* DataFrame optimization */
.gr-dataframe {
    margin: 4px 0 !important;
}

/* Remove horizontal scroll from data preview */
.gr-dataframe .wrap {
    overflow-x: auto !important;
    max-width: 100% !important;
}

/* Plot optimization */
.gr-plot {
    margin: 4px 0 !important;
}

/* Reduce markdown margins */
.gr-markdown {
    margin: 2px 0 !important;
}

/* Footer positioning */
.sticky-footer {
    position: fixed;
    bottom: 0px;
    left: 0;
    width: 100%;
    background: #E8F5E8;
    padding: 6px !important;
    box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
    z-index: 1000;
}
"""