wjnwjn59 committed on
Commit
b9ff16f
·
1 Parent(s): 47faf49

fix box not found

Browse files
Files changed (1) hide show
  1. app.py +114 -103
app.py CHANGED
@@ -4,7 +4,7 @@ import plotly.graph_objects as go
4
  import pandas as pd
5
 
6
  from src.heart_disease_core import (
7
- CLEVELAND_FEATURES_ORDER, TARGET_COL, CATEGORICAL_CHOICES,
8
  load_cleveland_dataframe, fit_all_models, predict_all, example_patient
9
  )
10
 
@@ -18,37 +18,57 @@ STATE = {
18
  "metrics": None,
19
  }
20
 
21
- def _ensure_models(df: pd.DataFrame):
22
- if STATE["models"] is None:
23
- models, metrics = fit_all_models(df)
24
- STATE["models"] = models
25
- STATE["metrics"] = metrics
26
-
27
- def load_dataset(file):
28
- try:
29
- if file is None:
30
- return gr.Markdown.update(value="❌ Please upload a Cleveland-format dataset (CSV/XLSX)."), gr.DataFrame.update(value=pd.DataFrame()), gr.Markdown.update(visible=False)
31
- if file.name.endswith(".csv"):
32
- df = pd.read_csv(file.name)
33
- else:
34
- df = pd.read_excel(file.name)
35
- df = load_cleveland_dataframe(uploaded_df=df)
36
- STATE["df"] = df
37
- STATE["models"] = None # reset, will refit lazily
38
- STATE["metrics"] = None
39
- head = df.head(8)
40
- return gr.Markdown.update(value="βœ… Dataset loaded successfully."), gr.DataFrame.update(value=head, interactive=False), gr.Markdown.update(visible=False)
41
- except Exception as e:
42
- return gr.Markdown.update(value=f"❌ Error: {e}"), gr.DataFrame.update(value=pd.DataFrame()), gr.Markdown.update(visible=False)
43
-
44
- def fill_example(idx):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ex = example_patient(idx)
46
  return [ex[c] for c in CLEVELAND_FEATURES_ORDER]
47
 
 
48
  def _bar_for_models(results: dict):
49
  names = list(results.keys())
50
  probs = [results[n]["prob_1"] for n in names]
51
- labels = ["Disease" if results[n]["label"] == 1 else "No disease" for n in names]
52
 
53
  fig = go.Figure()
54
  fig.add_bar(x=names, y=probs, text=[f"{p:.2f}" for p in probs], textposition="auto")
@@ -61,50 +81,48 @@ def _bar_for_models(results: dict):
61
  height=420,
62
  margin=dict(l=30, r=20, t=60, b=40)
63
  )
64
- # color emphasis for ensemble bar (last)
65
  if len(names) >= 1:
66
- fig.data[0].marker.color = ["#9BB8D3"] * (len(names) - 1) + [APP_ACCENT]
67
- return fig, labels
 
 
 
 
 
 
 
68
 
69
  def run_predict(*vals):
70
- # Ensure dataset
71
- if STATE["df"] is None:
72
  return (
73
- gr.Markdown.update(value="❌ No dataset yet. Please upload a Cleveland-format dataset."),
74
  gr.Plot.update(None),
75
  gr.Markdown.update(visible=False),
76
  gr.DataFrame.update(visible=False)
77
  )
78
 
79
- # Build input row as dict with strict order
80
  input_dict = {col: vals[i] for i, col in enumerate(CLEVELAND_FEATURES_ORDER)}
81
-
82
- # Fit models lazily
83
- _ensure_models(STATE["df"])
84
-
85
- # Predict
86
  results = predict_all(STATE["models"], input_dict)
87
 
88
- # Compose readable summary and plot
89
- pred_table = []
90
- final_label = results["Ensemble (Soft Voting)"]["label"]
91
- final_prob = results["Ensemble (Soft Voting)"]["prob_1"]
92
  title_md = (
93
  f"### πŸ«€ Cleveland Heart Disease Diagnosis\n"
94
- f"**Ensemble Prediction**: **{'Positive' if final_label == 1 else 'Negative'}** \n"
95
- f"**Confidence (P=1)**: `{final_prob:.3f}`"
96
  )
97
 
 
98
  for name, r in results.items():
99
- pred_table.append({
100
  "Model": name,
101
  "Predicted label": "Positive" if r["label"] == 1 else "Negative",
102
  "P(No disease)": round(r["prob_0"], 3),
103
  "P(Heart disease)": round(r["prob_1"], 3),
104
  })
105
- table_df = pd.DataFrame(pred_table)
106
 
107
- fig, labels = _bar_for_models(results)
108
 
109
  return (
110
  gr.Markdown.update(value=title_md),
@@ -113,84 +131,78 @@ def run_predict(*vals):
113
  gr.DataFrame.update(value=table_df, visible=True, interactive=False)
114
  )
115
 
 
116
  # -----------------------------
117
- # UI
118
  # -----------------------------
119
  with gr.Blocks(theme="soft", css=f"""
120
  :root {{
121
  --primary-600: {APP_PRIMARY};
122
  }}
123
  .gradio-container {{ background: {APP_BG}; }}
124
- .footer-note a {{ color: {APP_PRIMARY}; }}
125
  h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
126
  """) as demo:
127
  gr.Markdown("# πŸ«€ Cleveland Heart Disease Diagnosis (Ensemble Demo)")
128
 
129
  with gr.Row(equal_height=False):
130
- # LEFT: inputs
131
  with gr.Column(scale=45):
132
- with gr.Box():
133
- gr.Markdown("### πŸ“ Load Dataset")
134
- info_md = gr.Markdown("Upload a CSV/XLSX in **Cleveland** format (13 features + `target`).")
135
- file_u = gr.File(file_count="single", file_types=[".csv", ".xlsx", ".xls"], label="Upload Cleveland Dataset")
136
- preview = gr.DataFrame(label="Data Preview (first rows)", interactive=False)
137
- metrics_box = gr.Markdown(visible=False)
138
-
139
- with gr.Box():
140
- gr.Markdown("### ✍️ Enter Patient Features")
141
- with gr.Row():
142
- age = gr.Number(label="age (years)", value=58)
143
- sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0,1], value=1)
144
- cp = gr.Dropdown(label="cp (chest pain type 0..3)", choices=[0,1,2,3], value=2)
145
- trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
146
-
147
- with gr.Row():
148
- chol = gr.Number(label="chol (serum cholestrol mg/dl)", value=250)
149
- fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0,1], value=0)
150
- restecg = gr.Dropdown(label="restecg (0..2)", choices=[0,1,2], value=1)
151
- thalach = gr.Number(label="thalach (max heart rate)", value=150)
152
-
153
- with gr.Row():
154
- exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0,1], value=0)
155
- oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
156
- slope = gr.Dropdown(label="slope (0..2)", choices=[0,1,2], value=1)
157
- ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0,1,2,3], value=0)
158
-
159
- thal = gr.Dropdown(label="thal (1=normal,2=fixed,3=reversible)", choices=[1,2,3], value=2)
160
-
161
- with gr.Row():
162
- ex_selector = gr.Dropdown(
163
- label="Fill Example",
164
- choices=["Example 1 (likely negative)", "Example 2 (borderline)", "Example 3 (likely positive)"],
165
- value="Example 2 (borderline)"
166
- )
167
- fill_btn = gr.Button("πŸ§ͺ Use Example", variant="secondary")
168
- predict_btn = gr.Button("πŸ” Predict", variant="primary")
169
 
170
  # RIGHT: outputs
171
  with gr.Column(scale=55):
172
- with gr.Box():
173
- title_out = gr.Markdown("### Ensemble Prediction will appear here.")
174
- bar_out = gr.Plot(label="Model Confidence")
175
- sub_md = gr.Markdown(visible=False)
176
- table_out = gr.DataFrame(visible=False)
177
 
178
  with gr.Accordion("ℹ️ Notes", open=False):
179
  gr.Markdown(
180
- "- This demo **fits models** on your uploaded dataset (80/20 split) the first time you predict.\n"
181
- "- **Target** is automatically binarized (0 = no disease, >0 = disease).\n"
182
- "- Ensemble is **soft voting** over Decision Tree, k-NN, and Naive Bayes.\n"
183
- "- This is **for demo/education**; not medical advice."
184
  )
185
 
186
- # Events
187
- file_u.upload(fn=load_dataset, inputs=[file_u], outputs=[info_md, preview, metrics_box])
188
-
189
- def _example_index(choice: str):
190
- return {"Example 1 (likely negative)": 0, "Example 2 (borderline)": 1, "Example 3 (likely positive)": 2}[choice]
191
 
192
  fill_btn.click(
193
- fn=lambda choice: tuple(fill_example(_example_index(choice))),
194
  inputs=[ex_selector],
195
  outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
196
  )
@@ -202,5 +214,4 @@ h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
202
  )
203
 
204
  if __name__ == "__main__":
205
- # Optional: allow GraphViz logos etc. from static if you keep them
206
  demo.launch()
 
4
  import pandas as pd
5
 
6
  from src.heart_disease_core import (
7
+ CLEVELAND_FEATURES_ORDER, TARGET_COL,
8
  load_cleveland_dataframe, fit_all_models, predict_all, example_patient
9
  )
10
 
 
18
  "metrics": None,
19
  }
20
 
21
+ DATA_PATH = "data/cleveland.csv"
22
+
23
+
24
+ # -----------------------------
25
+ # Startup / init
26
+ # -----------------------------
27
def init_page():
    """Load the Cleveland dataset from disk, fit the models once, and report status.

    Reads the CSV at ``DATA_PATH``, passes it through
    ``load_cleveland_dataframe``, fits every model with ``fit_all_models`` and
    caches the dataframe, models and metrics in the module-level ``STATE``
    dict so later predictions can reuse them.

    Returns:
        A 3-tuple of Gradio updates: (status message, preview of the first
        8 rows, metrics table). When the dataset file is missing, the status
        carries an error message and both tables are emptied; ``STATE`` is
        left untouched in that case.
    """
    # Imported locally so this function does not rely on a module-level
    # `import os` being present.
    import os

    if not os.path.exists(DATA_PATH):
        msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
        # gr.update(...) is used instead of gr.<Component>.update(...): the
        # per-component class methods were removed in the same Gradio major
        # release that removed gr.Box, while gr.update works on both sides.
        return (
            gr.update(value=msg),
            gr.update(value=pd.DataFrame()),
            gr.update(value=pd.DataFrame()),
        )

    df = pd.read_csv(DATA_PATH)
    # Per the original inline note, this helper cleans the frame and
    # binarizes the target column.
    df = load_cleveland_dataframe(uploaded_df=df)

    # Publish to STATE only after fitting succeeded, so a failure never
    # leaves a half-initialised cache behind.
    models, metrics = fit_all_models(df)
    STATE["df"] = df
    STATE["models"] = models
    STATE["metrics"] = metrics

    head = df.head(8)
    # Build the message from DATA_PATH so it cannot drift from the path
    # that was actually read.
    msg = (
        f"✅ **Cleveland dataset loaded** from `{DATA_PATH}` "
        "and models trained (80/20 split)."
    )
    return (
        gr.update(value=msg),
        gr.update(value=head, interactive=False),
        gr.update(value=metrics, interactive=False),
    )
54
+
55
+
56
+ # -----------------------------
57
+ # Helpers
58
+ # -----------------------------
59
def fill_example(idx_text: str):
    """Return the feature values of a canned example patient.

    Args:
        idx_text: One of the three "Example N (...)" labels; any other value
            raises ``KeyError`` from the label lookup.

    Returns:
        A list of the example's values, one per name in
        ``CLEVELAND_FEATURES_ORDER`` and in that order.
    """
    # Map the human-readable label to the index understood by example_patient.
    example_index_by_label = {
        "Example 1 (likely negative)": 0,
        "Example 2 (borderline)": 1,
        "Example 3 (likely positive)": 2,
    }
    patient = example_patient(example_index_by_label[idx_text])

    values = []
    for feature in CLEVELAND_FEATURES_ORDER:
        values.append(patient[feature])
    return values
67
 
68
+
69
  def _bar_for_models(results: dict):
70
  names = list(results.keys())
71
  probs = [results[n]["prob_1"] for n in names]
 
72
 
73
  fig = go.Figure()
74
  fig.add_bar(x=names, y=probs, text=[f"{p:.2f}" for p in probs], textposition="auto")
 
81
  height=420,
82
  margin=dict(l=30, r=20, t=60, b=40)
83
  )
84
+ # Emphasize ensemble bar (assumes last entry named "Ensemble (Soft Voting)")
85
  if len(names) >= 1:
86
+ colors = ["#9BB8D3"] * len(names)
87
+ try:
88
+ idx = names.index("Ensemble (Soft Voting)")
89
+ colors[idx] = APP_ACCENT
90
+ except ValueError:
91
+ colors[-1] = APP_ACCENT
92
+ fig.data[0].marker.color = colors
93
+ return fig
94
+
95
 
96
  def run_predict(*vals):
97
+ if STATE["df"] is None or STATE["models"] is None:
 
98
  return (
99
+ gr.Markdown.update(value="❌ Models not initialized. Reload the app."),
100
  gr.Plot.update(None),
101
  gr.Markdown.update(visible=False),
102
  gr.DataFrame.update(visible=False)
103
  )
104
 
 
105
  input_dict = {col: vals[i] for i, col in enumerate(CLEVELAND_FEATURES_ORDER)}
 
 
 
 
 
106
  results = predict_all(STATE["models"], input_dict)
107
 
108
+ final = results["Ensemble (Soft Voting)"]
 
 
 
109
  title_md = (
110
  f"### πŸ«€ Cleveland Heart Disease Diagnosis\n"
111
+ f"**Ensemble Prediction**: **{'Positive' if final['label'] == 1 else 'Negative'}** \n"
112
+ f"**Confidence (P=1)**: `{final['prob_1']:.3f}`"
113
  )
114
 
115
+ rows = []
116
  for name, r in results.items():
117
+ rows.append({
118
  "Model": name,
119
  "Predicted label": "Positive" if r["label"] == 1 else "Negative",
120
  "P(No disease)": round(r["prob_0"], 3),
121
  "P(Heart disease)": round(r["prob_1"], 3),
122
  })
123
+ table_df = pd.DataFrame(rows)
124
 
125
+ fig = _bar_for_models(results)
126
 
127
  return (
128
  gr.Markdown.update(value=title_md),
 
131
  gr.DataFrame.update(value=table_df, visible=True, interactive=False)
132
  )
133
 
134
+
135
  # -----------------------------
136
+ # UI (no gr.Box to avoid your error)
137
  # -----------------------------
138
  with gr.Blocks(theme="soft", css=f"""
139
  :root {{
140
  --primary-600: {APP_PRIMARY};
141
  }}
142
  .gradio-container {{ background: {APP_BG}; }}
 
143
  h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
144
  """) as demo:
145
  gr.Markdown("# πŸ«€ Cleveland Heart Disease Diagnosis (Ensemble Demo)")
146
 
147
  with gr.Row(equal_height=False):
148
+ # LEFT: data preview + inputs
149
  with gr.Column(scale=45):
150
+ gr.Markdown("### πŸ“ Dataset & Model Status")
151
+ status_md = gr.Markdown("Loading dataset and training models...")
152
+ preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
153
+ metrics_df = gr.DataFrame(label="Validation ROC-AUC (80/20 split)", interactive=False)
154
+
155
+ gr.Markdown("### ✍️ Enter Patient Features")
156
+ with gr.Row():
157
+ age = gr.Number(label="age (years)", value=58)
158
+ sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0, 1], value=1)
159
+ cp = gr.Dropdown(label="cp (chest pain type 0..3)", choices=[0, 1, 2, 3], value=2)
160
+ trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
161
+
162
+ with gr.Row():
163
+ chol = gr.Number(label="chol (serum cholesterol mg/dl)", value=250)
164
+ fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0, 1], value=0)
165
+ restecg = gr.Dropdown(label="restecg (0..2)", choices=[0, 1, 2], value=1)
166
+ thalach = gr.Number(label="thalach (max heart rate)", value=150)
167
+
168
+ with gr.Row():
169
+ exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0, 1], value=0)
170
+ oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
171
+ slope = gr.Dropdown(label="slope (0..2)", choices=[0, 1, 2], value=1)
172
+ ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0, 1, 2, 3], value=0)
173
+
174
+ thal = gr.Dropdown(label="thal (1=normal, 2=fixed, 3=reversible)", choices=[1, 2, 3], value=2)
175
+
176
+ with gr.Row():
177
+ ex_selector = gr.Dropdown(
178
+ label="Fill Example",
179
+ choices=["Example 1 (likely negative)", "Example 2 (borderline)", "Example 3 (likely positive)"],
180
+ value="Example 2 (borderline)"
181
+ )
182
+ fill_btn = gr.Button("πŸ§ͺ Use Example")
183
+ predict_btn = gr.Button("πŸ” Predict", variant="primary")
 
 
 
184
 
185
  # RIGHT: outputs
186
  with gr.Column(scale=55):
187
+ gr.Markdown("### πŸ“ˆ Predictions")
188
+ title_out = gr.Markdown("Ensemble Prediction will appear here.")
189
+ bar_out = gr.Plot(label="Model Confidence")
190
+ sub_md = gr.Markdown(visible=False)
191
+ table_out = gr.DataFrame(visible=False)
192
 
193
  with gr.Accordion("ℹ️ Notes", open=False):
194
  gr.Markdown(
195
+ "- Models are trained once at launch on `data/cleveland.csv` (80/20 split).\n"
196
+ "- `target` is binarized automatically (0 = no disease, >0 = disease).\n"
197
+ "- Ensemble uses **soft voting** over Decision Tree, k-NN, and Naive Bayes.\n"
198
+ "- Educational demo only; **not medical advice**."
199
  )
200
 
201
+ # Bind events
202
+ demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
 
 
 
203
 
204
  fill_btn.click(
205
+ fn=fill_example,
206
  inputs=[ex_selector],
207
  outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
208
  )
 
214
  )
215
 
216
  if __name__ == "__main__":
 
217
  demo.launch()