Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import numpy as np
|
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
from sklearn import datasets
|
| 6 |
from sklearn.utils import shuffle
|
| 7 |
-
import csv, os, tempfile
|
| 8 |
|
| 9 |
EXPLAIN_MD = """
|
| 10 |
### Wat testen we hier?
|
|
@@ -79,8 +78,6 @@ def sgd_train_generator(lr, epochs, batch_size, seed, split_seed):
|
|
| 79 |
|
| 80 |
rng = np.random.RandomState(int(seed))
|
| 81 |
|
| 82 |
-
state_for_download = None
|
| 83 |
-
|
| 84 |
for epoch in range(1, int(epochs) + 1):
|
| 85 |
# shuffle train set
|
| 86 |
x_tr, y_tr = shuffle(x_tr, y_tr, random_state=rng)
|
|
@@ -147,51 +144,9 @@ def sgd_train_generator(lr, epochs, batch_size, seed, split_seed):
|
|
| 147 |
f"{CONCLUSION_MD}"
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
state_for_download = {
|
| 152 |
-
"x_test": x_te,
|
| 153 |
-
"y_test": y_te,
|
| 154 |
-
"y_pred": y_te_pred,
|
| 155 |
-
"w": w,
|
| 156 |
-
"b": b,
|
| 157 |
-
"mse_train": mse_tr,
|
| 158 |
-
"mse_test": mse_te,
|
| 159 |
-
"r2_test": r2_te,
|
| 160 |
-
}
|
| 161 |
-
|
| 162 |
-
yield fig_main, fig_loss, summary, state_for_download
|
| 163 |
-
|
| 164 |
-
def prepare_download(state):
|
| 165 |
-
if not state:
|
| 166 |
-
return None
|
| 167 |
-
# Schrijf CSV met test set en voorspellingen
|
| 168 |
-
fd, path = tempfile.mkstemp(suffix="_diabetes_bmi_results.csv")
|
| 169 |
-
os.close(fd)
|
| 170 |
-
with open(path, "w", newline="", encoding="utf-8") as f:
|
| 171 |
-
writer = csv.writer(f)
|
| 172 |
-
writer.writerow(["bmi_normalized", "y_true", "y_pred", "residual"])
|
| 173 |
-
for x, yt, yp in zip(state["x_test"], state["y_test"], state["y_pred"]):
|
| 174 |
-
writer.writerow([float(x), float(yt), float(yp), float(yt-yp)])
|
| 175 |
-
# Voeg onderaan een lege regel + metrics toe
|
| 176 |
-
writer.writerow([])
|
| 177 |
-
writer.writerow(["w", state["w"]])
|
| 178 |
-
writer.writerow(["b", state["b"]])
|
| 179 |
-
writer.writerow(["mse_train", state["mse_train"]])
|
| 180 |
-
writer.writerow(["mse_test", state["mse_test"]])
|
| 181 |
-
writer.writerow(["r2_test", state["r2_test"]])
|
| 182 |
-
return path
|
| 183 |
|
| 184 |
with gr.Blocks(title="Diabetes: BMI β Progressiescore (Live Regressie)") as demo:
|
| 185 |
-
# Custom CSS to color the buttons (using elem_id selectors)
|
| 186 |
-
gr.HTML("""
|
| 187 |
-
<style>
|
| 188 |
-
#train-btn button { background:#2563eb; color:white; border:none; }
|
| 189 |
-
#train-btn button:hover { filter: brightness(0.95); }
|
| 190 |
-
#download-btn button { background:#059669; color:white; border:none; }
|
| 191 |
-
#download-btn button:hover { filter: brightness(0.95); }
|
| 192 |
-
</style>
|
| 193 |
-
""")
|
| 194 |
-
|
| 195 |
gr.Markdown("# Diabetes: BMI β Progressiescore (Live Lineaire Regressie)")
|
| 196 |
gr.Markdown(EXPLAIN_MD)
|
| 197 |
|
|
@@ -202,35 +157,27 @@ with gr.Blocks(title="Diabetes: BMI β Progressiescore (Live Regressie)") as de
|
|
| 202 |
batch = gr.Slider(8, 256, value=64, step=1, label="Batchgrootte")
|
| 203 |
seed = gr.Slider(0, 9999, value=42, step=1, label="Training seed")
|
| 204 |
split_seed = gr.Slider(0, 9999, value=7, step=1, label="Train/test split seed")
|
| 205 |
-
train_btn = gr.Button("Train live"
|
| 206 |
-
|
| 207 |
-
label="Download resultaten (CSV)", elem_id="download-btn", file_name="diabetes_bmi_results.csv"
|
| 208 |
-
)
|
| 209 |
-
# Story direct onder de knoppen
|
| 210 |
gr.Markdown(STORY_MD)
|
| 211 |
with gr.Column(scale=2):
|
| 212 |
plot_main = gr.Plot(label="Data (train/test) & regressielijn (live)")
|
| 213 |
plot_loss = gr.Plot(label="Loss-curve (MSE per epoch) β train vs test")
|
| 214 |
results = gr.Markdown()
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# Training starten via knop (streaming)
|
| 219 |
train_btn.click(
|
| 220 |
fn=sgd_train_generator,
|
| 221 |
inputs=[lr, epochs, batch, seed, split_seed],
|
| 222 |
-
outputs=[plot_main, plot_loss, results
|
| 223 |
)
|
| 224 |
|
| 225 |
# Auto-train bij laden met default-waarden
|
| 226 |
demo.load(
|
| 227 |
fn=sgd_train_generator,
|
| 228 |
inputs=[lr, epochs, batch, seed, split_seed],
|
| 229 |
-
outputs=[plot_main, plot_loss, results
|
| 230 |
)
|
| 231 |
|
| 232 |
-
# Download-button: maak CSV vanuit state
|
| 233 |
-
download_btn.click(fn=prepare_download, inputs=[results_state], outputs=download_btn)
|
| 234 |
-
|
| 235 |
if __name__ == "__main__":
|
| 236 |
demo.launch()
|
|
|
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
from sklearn import datasets
|
| 6 |
from sklearn.utils import shuffle
|
|
|
|
| 7 |
|
| 8 |
EXPLAIN_MD = """
|
| 9 |
### Wat testen we hier?
|
|
|
|
| 78 |
|
| 79 |
rng = np.random.RandomState(int(seed))
|
| 80 |
|
|
|
|
|
|
|
| 81 |
for epoch in range(1, int(epochs) + 1):
|
| 82 |
# shuffle train set
|
| 83 |
x_tr, y_tr = shuffle(x_tr, y_tr, random_state=rng)
|
|
|
|
| 144 |
f"{CONCLUSION_MD}"
|
| 145 |
)
|
| 146 |
|
| 147 |
+
yield fig_main, fig_loss, summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
with gr.Blocks(title="Diabetes: BMI β Progressiescore (Live Regressie)") as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
gr.Markdown("# Diabetes: BMI β Progressiescore (Live Lineaire Regressie)")
|
| 151 |
gr.Markdown(EXPLAIN_MD)
|
| 152 |
|
|
|
|
| 157 |
batch = gr.Slider(8, 256, value=64, step=1, label="Batchgrootte")
|
| 158 |
seed = gr.Slider(0, 9999, value=42, step=1, label="Training seed")
|
| 159 |
split_seed = gr.Slider(0, 9999, value=7, step=1, label="Train/test split seed")
|
| 160 |
+
train_btn = gr.Button("Train live")
|
| 161 |
+
# Story direct onder de knop
|
|
|
|
|
|
|
|
|
|
| 162 |
gr.Markdown(STORY_MD)
|
| 163 |
with gr.Column(scale=2):
|
| 164 |
plot_main = gr.Plot(label="Data (train/test) & regressielijn (live)")
|
| 165 |
plot_loss = gr.Plot(label="Loss-curve (MSE per epoch) β train vs test")
|
| 166 |
results = gr.Markdown()
|
| 167 |
|
| 168 |
+
# Training starten via knop
|
|
|
|
|
|
|
| 169 |
train_btn.click(
|
| 170 |
fn=sgd_train_generator,
|
| 171 |
inputs=[lr, epochs, batch, seed, split_seed],
|
| 172 |
+
outputs=[plot_main, plot_loss, results]
|
| 173 |
)
|
| 174 |
|
| 175 |
# Auto-train bij laden met default-waarden
|
| 176 |
demo.load(
|
| 177 |
fn=sgd_train_generator,
|
| 178 |
inputs=[lr, epochs, batch, seed, split_seed],
|
| 179 |
+
outputs=[plot_main, plot_loss, results]
|
| 180 |
)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
if __name__ == "__main__":
|
| 183 |
demo.launch()
|