Marcel0123 commited on
Commit
460491a
Β·
verified Β·
1 Parent(s): f8f46b8

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +6 -59
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  from sklearn import datasets
6
  from sklearn.utils import shuffle
7
- import csv, os, tempfile
8
 
9
  EXPLAIN_MD = """
10
  ### Wat testen we hier?
@@ -79,8 +78,6 @@ def sgd_train_generator(lr, epochs, batch_size, seed, split_seed):
79
 
80
  rng = np.random.RandomState(int(seed))
81
 
82
- state_for_download = None
83
-
84
  for epoch in range(1, int(epochs) + 1):
85
  # shuffle train set
86
  x_tr, y_tr = shuffle(x_tr, y_tr, random_state=rng)
@@ -147,51 +144,9 @@ def sgd_train_generator(lr, epochs, batch_size, seed, split_seed):
147
  f"{CONCLUSION_MD}"
148
  )
149
 
150
- # Bewaar state voor download (laatste epoch)
151
- state_for_download = {
152
- "x_test": x_te,
153
- "y_test": y_te,
154
- "y_pred": y_te_pred,
155
- "w": w,
156
- "b": b,
157
- "mse_train": mse_tr,
158
- "mse_test": mse_te,
159
- "r2_test": r2_te,
160
- }
161
-
162
- yield fig_main, fig_loss, summary, state_for_download
163
-
164
- def prepare_download(state):
165
- if not state:
166
- return None
167
- # Schrijf CSV met test set en voorspellingen
168
- fd, path = tempfile.mkstemp(suffix="_diabetes_bmi_results.csv")
169
- os.close(fd)
170
- with open(path, "w", newline="", encoding="utf-8") as f:
171
- writer = csv.writer(f)
172
- writer.writerow(["bmi_normalized", "y_true", "y_pred", "residual"])
173
- for x, yt, yp in zip(state["x_test"], state["y_test"], state["y_pred"]):
174
- writer.writerow([float(x), float(yt), float(yp), float(yt-yp)])
175
- # Voeg onderaan een lege regel + metrics toe
176
- writer.writerow([])
177
- writer.writerow(["w", state["w"]])
178
- writer.writerow(["b", state["b"]])
179
- writer.writerow(["mse_train", state["mse_train"]])
180
- writer.writerow(["mse_test", state["mse_test"]])
181
- writer.writerow(["r2_test", state["r2_test"]])
182
- return path
183
 
184
  with gr.Blocks(title="Diabetes: BMI β†’ Progressiescore (Live Regressie)") as demo:
185
- # Custom CSS to color the buttons (using elem_id selectors)
186
- gr.HTML("""
187
- <style>
188
- #train-btn button { background:#2563eb; color:white; border:none; }
189
- #train-btn button:hover { filter: brightness(0.95); }
190
- #download-btn button { background:#059669; color:white; border:none; }
191
- #download-btn button:hover { filter: brightness(0.95); }
192
- </style>
193
- """)
194
-
195
  gr.Markdown("# Diabetes: BMI β†’ Progressiescore (Live Lineaire Regressie)")
196
  gr.Markdown(EXPLAIN_MD)
197
 
@@ -202,35 +157,27 @@ with gr.Blocks(title="Diabetes: BMI β†’ Progressiescore (Live Regressie)") as de
202
  batch = gr.Slider(8, 256, value=64, step=1, label="Batchgrootte")
203
  seed = gr.Slider(0, 9999, value=42, step=1, label="Training seed")
204
  split_seed = gr.Slider(0, 9999, value=7, step=1, label="Train/test split seed")
205
- train_btn = gr.Button("Train live", elem_id="train-btn", variant="primary")
206
- download_btn = gr.DownloadButton(
207
- label="Download resultaten (CSV)", elem_id="download-btn", file_name="diabetes_bmi_results.csv"
208
- )
209
- # Story direct onder de knoppen
210
  gr.Markdown(STORY_MD)
211
  with gr.Column(scale=2):
212
  plot_main = gr.Plot(label="Data (train/test) & regressielijn (live)")
213
  plot_loss = gr.Plot(label="Loss-curve (MSE per epoch) β€” train vs test")
214
  results = gr.Markdown()
215
 
216
- results_state = gr.State()
217
-
218
- # Training starten via knop (streaming)
219
  train_btn.click(
220
  fn=sgd_train_generator,
221
  inputs=[lr, epochs, batch, seed, split_seed],
222
- outputs=[plot_main, plot_loss, results, results_state]
223
  )
224
 
225
  # Auto-train bij laden met default-waarden
226
  demo.load(
227
  fn=sgd_train_generator,
228
  inputs=[lr, epochs, batch, seed, split_seed],
229
- outputs=[plot_main, plot_loss, results, results_state]
230
  )
231
 
232
- # Download-button: maak CSV vanuit state
233
- download_btn.click(fn=prepare_download, inputs=[results_state], outputs=download_btn)
234
-
235
  if __name__ == "__main__":
236
  demo.launch()
 
4
  import matplotlib.pyplot as plt
5
  from sklearn import datasets
6
  from sklearn.utils import shuffle
 
7
 
8
  EXPLAIN_MD = """
9
  ### Wat testen we hier?
 
78
 
79
  rng = np.random.RandomState(int(seed))
80
 
 
 
81
  for epoch in range(1, int(epochs) + 1):
82
  # shuffle train set
83
  x_tr, y_tr = shuffle(x_tr, y_tr, random_state=rng)
 
144
  f"{CONCLUSION_MD}"
145
  )
146
 
147
+ yield fig_main, fig_loss, summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  with gr.Blocks(title="Diabetes: BMI β†’ Progressiescore (Live Regressie)") as demo:
 
 
 
 
 
 
 
 
 
 
150
  gr.Markdown("# Diabetes: BMI β†’ Progressiescore (Live Lineaire Regressie)")
151
  gr.Markdown(EXPLAIN_MD)
152
 
 
157
  batch = gr.Slider(8, 256, value=64, step=1, label="Batchgrootte")
158
  seed = gr.Slider(0, 9999, value=42, step=1, label="Training seed")
159
  split_seed = gr.Slider(0, 9999, value=7, step=1, label="Train/test split seed")
160
+ train_btn = gr.Button("Train live")
161
+ # Story direct onder de knop
 
 
 
162
  gr.Markdown(STORY_MD)
163
  with gr.Column(scale=2):
164
  plot_main = gr.Plot(label="Data (train/test) & regressielijn (live)")
165
  plot_loss = gr.Plot(label="Loss-curve (MSE per epoch) β€” train vs test")
166
  results = gr.Markdown()
167
 
168
+ # Training starten via knop
 
 
169
  train_btn.click(
170
  fn=sgd_train_generator,
171
  inputs=[lr, epochs, batch, seed, split_seed],
172
+ outputs=[plot_main, plot_loss, results]
173
  )
174
 
175
  # Auto-train bij laden met default-waarden
176
  demo.load(
177
  fn=sgd_train_generator,
178
  inputs=[lr, epochs, batch, seed, split_seed],
179
+ outputs=[plot_main, plot_loss, results]
180
  )
181
 
 
 
 
182
  if __name__ == "__main__":
183
  demo.launch()