wjnwjn59 committed on
Commit
b99dcff
·
1 Parent(s): a9c2222

update dataset

Browse files
Files changed (1) hide show
  1. app.py +27 -6
app.py CHANGED
@@ -42,7 +42,7 @@ force_light_theme_js = """
42
  }
43
  """
44
 
45
- def init_page():
46
  """Load dataset, train models, and return status, preview, metrics."""
47
  if not os.path.exists(DATA_PATH):
48
  msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
@@ -50,13 +50,15 @@ def init_page():
50
 
51
  df = load_cleveland_dataframe(file_path=DATA_PATH)
52
 
53
- models, metrics = fit_all_models(df)
 
 
54
  STATE["df"] = df
55
  STATE["models"] = models
56
  STATE["metrics"] = metrics
57
 
58
  head = df.head(8)
59
- msg = "✅ Cleveland dataset loaded from `data/cleveland.csv` and models trained (80/20 split)."
60
  return msg, head, metrics
61
 
62
 
@@ -188,9 +190,20 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
188
  # LEFT: data preview + inputs
189
  with gr.Column(scale=45):
190
  with gr.Accordion("📁 Dataset & Model Status", open=True):
 
 
 
 
 
 
 
 
 
 
 
191
  status_md = gr.Markdown("Loading dataset and training models...")
192
  preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
193
- metrics_df = gr.DataFrame(label="Validation Metrics (80/20 split)", interactive=False)
194
 
195
  with gr.Accordion("✍️ Enter Patient Features", open=True):
196
  with gr.Row():
@@ -244,8 +257,9 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
244
  gr.Markdown("""
245
  ## 📋 **Notes**
246
 
247
- - **Models are trained once at launch** on `data/cleveland.csv` (80/20 split).
248
  - **Target is binarized automatically** (0 = no disease, >0 = disease).
 
249
  - **Seven optimized models are compared**: Decision Tree, k-NN, Naive Bayes, Random Forest, AdaBoost, Gradient Boosting, and XGBoost.
250
  - **Hyperparameters are optimized** for heart disease prediction tasks using best practices.
251
  - **Ensemble uses weighted soft voting** with optimized weights based on model performance.
@@ -276,7 +290,14 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
276
  vlai_template.create_footer()
277
 
278
  # Bind events
279
- demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
 
 
 
 
 
 
 
280
 
281
  # Auto-fill when example is selected
282
  ex_selector.change(
 
42
  }
43
  """
44
 
45
+ def init_page(train_split):
46
  """Load dataset, train models, and return status, preview, metrics."""
47
  if not os.path.exists(DATA_PATH):
48
  msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
 
50
 
51
  df = load_cleveland_dataframe(file_path=DATA_PATH)
52
 
53
+ # Convert train_split percentage to test_size for sklearn
54
+ test_size = (100 - train_split) / 100
55
+ models, metrics = fit_all_models(df, test_size=test_size)
56
  STATE["df"] = df
57
  STATE["models"] = models
58
  STATE["metrics"] = metrics
59
 
60
  head = df.head(8)
61
+ msg = f"✅ Cleveland dataset loaded from `data/cleveland.csv` and models trained ({train_split}/{100-train_split} split)."
62
  return msg, head, metrics
63
 
64
 
 
190
  # LEFT: data preview + inputs
191
  with gr.Column(scale=45):
192
  with gr.Accordion("📁 Dataset & Model Status", open=True):
193
+ with gr.Row():
194
+ train_split = gr.Slider(
195
+ minimum=60,
196
+ maximum=90,
197
+ value=80,
198
+ step=5,
199
+ label="Training Split (%)",
200
+ info="Percentage of data used for training (remaining for validation)"
201
+ )
202
+ retrain_btn = gr.Button("🔄 Retrain Models", variant="secondary")
203
+
204
  status_md = gr.Markdown("Loading dataset and training models...")
205
  preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
206
+ metrics_df = gr.DataFrame(label="Validation Metrics", interactive=False)
207
 
208
  with gr.Accordion("✍️ Enter Patient Features", open=True):
209
  with gr.Row():
 
257
  gr.Markdown("""
258
  ## 📋 **Notes**
259
 
260
+ - **Models are trained at launch** on `data/cleveland.csv` with customizable train/validation split (default 80/20).
261
  - **Target is binarized automatically** (0 = no disease, >0 = disease).
262
+ - **Retrain functionality**: Adjust the split ratio and click "🔄 Retrain Models" to see how data size affects performance.
263
  - **Seven optimized models are compared**: Decision Tree, k-NN, Naive Bayes, Random Forest, AdaBoost, Gradient Boosting, and XGBoost.
264
  - **Hyperparameters are optimized** for heart disease prediction tasks using best practices.
265
  - **Ensemble uses weighted soft voting** with optimized weights based on model performance.
 
290
  vlai_template.create_footer()
291
 
292
  # Bind events
293
+ demo.load(fn=init_page, inputs=[train_split], outputs=[status_md, preview, metrics_df])
294
+
295
+ # Retrain models with the current slider value when the button is clicked
296
+ retrain_btn.click(
297
+ fn=init_page,
298
+ inputs=[train_split],
299
+ outputs=[status_md, preview, metrics_df]
300
+ )
301
 
302
  # Auto-fill when example is selected
303
  ex_selector.change(