Marcel0123 committed on
Commit
32f22b3
·
verified ·
1 Parent(s): 0dbdb69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -24
app.py CHANGED
@@ -1,22 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def make_base_fig(coords, y, title):
2
- # Helder kleurpalet per klasse
3
  palette = ["#2563eb", "#ef4444", "#10b981", "#f59e0b", "#a855f7", "#06b6d4", "#f97316", "#22c55e"]
4
  fig = go.Figure()
5
-
6
- # Eerst het canvas vormgeven (wit, duidelijke assen)
7
  fig.update_layout(
8
- title=title,
9
- xaxis_title="PC1",
10
- yaxis_title="PC2",
11
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
12
  margin=dict(l=10, r=10, t=60, b=10),
13
- template=None, # geen donker thema
14
- plot_bgcolor="#ffffff", # wit
15
- paper_bgcolor="#ffffff",
16
- height=520
17
  )
18
-
19
- # Daarna de klassen als markers erbovenop
20
  labels = pd.Series(y).astype(str).values
21
  uniq = list(np.unique(labels))
22
  for i, lbl in enumerate(uniq):
@@ -24,35 +69,25 @@ def make_base_fig(coords, y, title):
24
  color = palette[i % len(palette)]
25
  fig.add_trace(go.Scatter(
26
  x=coords[mask, 0], y=coords[mask, 1],
27
- mode="markers",
28
- name=f"Klasse {lbl}",
29
  marker=dict(size=10, opacity=0.95, color=color, line=dict(width=1, color="#111")),
30
  hovertemplate="PC1: %{x:.2f}<br>PC2: %{y:.2f}<extra>" + f"Klasse {lbl}</extra>"
31
  ))
32
  return fig
33
 
34
-
35
  def draw_decision_boundary(fig, clf2d, scaler2d, pca2d, X_scaled):
36
- # Maak mesh in PCA-ruimte
37
  coords = pca2d.transform(X_scaled)
38
  x_min, x_max = coords[:, 0].min() - 0.5, coords[:, 0].max() + 0.5
39
  y_min, y_max = coords[:, 1].min() - 0.5, coords[:, 1].max() + 0.5
40
- xx, yy = np.meshgrid(
41
- np.linspace(x_min, x_max, 200),
42
- np.linspace(y_min, y_max, 200)
43
- )
44
  grid_2d = np.c_[xx.ravel(), yy.ravel()]
45
  coords_grid_s = scaler2d.transform(grid_2d)
46
-
47
- # Score voor contour
48
  if hasattr(clf2d, "predict_proba"):
49
  Z = clf2d.predict_proba(coords_grid_s)[:, -1]
50
  else:
51
  dec = clf2d.decision_function(coords_grid_s)
52
  Z = (dec - np.nanmin(dec)) / (np.nanmax(dec) - np.nanmin(dec) + 1e-9)
53
  Z = np.nan_to_num(Z, nan=0.5, posinf=1.0, neginf=0.0).reshape(xx.shape)
54
-
55
- # Contour als LIJNEN (geen vulling) zodat markers zichtbaar blijven
56
  fig.add_trace(go.Contour(
57
  x=np.linspace(x_min, x_max, 200),
58
  y=np.linspace(y_min, y_max, 200),
@@ -64,3 +99,224 @@ def draw_decision_boundary(fig, clf2d, scaler2d, pca2d, X_scaled):
64
  name="Beslissingslijnen"
65
  ))
66
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+
7
+ import gradio as gr
8
+ from sklearn.preprocessing import StandardScaler
9
+ from sklearn.decomposition import PCA
10
+ from sklearn.linear_model import SGDClassifier, LogisticRegression
11
+ from sklearn.ensemble import RandomForestClassifier
12
+ from sklearn.svm import SVC
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
15
+
16
+
17
# ---------- Built-in synthetic dataset ----------
def load_builtin_dataset(n=1000, seed=42):
    """Create a reproducible synthetic depression-screening dataset.

    Args:
        n: number of rows (patients) to generate.
        seed: RNG seed; identical seeds yield identical frames.

    Returns:
        Tuple ``(df, label_column)``: a DataFrame with eight feature
        columns, a PHQ-9-like score and the binary ``depressed`` label,
        plus the name of that label column.
    """
    gen = np.random.default_rng(seed)

    def bounded(sample):
        # Clamp a raw normal draw into the questionnaire's 1..10 range.
        return np.clip(sample, 1, 10)

    # NOTE: draws happen in the same order as before so the stream stays
    # byte-identical for a given seed.
    age = gen.integers(18, 75, size=n)
    gender = gen.choice([0, 1], size=n)  # dummy feature
    sleep_quality = bounded(gen.normal(6.5, 1.5, size=n))
    energy = bounded(gen.normal(6.0, 1.7, size=n))
    anhedonia = bounded(gen.normal(3.5, 1.8, size=n))
    stress = bounded(gen.normal(4.5, 2.0, size=n))
    social_support = bounded(gen.normal(6.0, 1.8, size=n))
    # Activity is partly driven by energy and stress.
    activity = np.clip(gen.normal(3.0 + 0.4 * energy - 0.2 * stress, 1.5, size=n), 0, 10)
    # PHQ-9-like severity score, clipped to the instrument's [0, 27] range.
    phq9 = np.clip(
        0.8 * anhedonia + 0.7 * stress - 0.5 * sleep_quality - 0.4 * energy
        + gen.normal(0, 1.2, size=n) + 5,
        0, 27,
    )
    # Latent log-odds of depression built from the same risk factors.
    logit = (
        0.65 * anhedonia + 0.55 * stress
        - 0.45 * sleep_quality - 0.40 * energy
        - 0.30 * social_support - 0.20 * activity
        + 0.01 * (age - 40) + 0.05 * gender
        + gen.normal(0, 0.6, size=n)
    )
    logit -= np.median(logit)  # center so classes come out roughly balanced
    prob = 1 / (1 + np.exp(-logit))
    depressed = (prob > 0.5).astype(int)

    columns = {
        "age": age,
        "gender": gender,
        "sleep_quality": sleep_quality,
        "energy": energy,
        "anhedonia": anhedonia,
        "stress": stress,
        "social_support": social_support,
        "activity": activity,
        "phq9": phq9,
        "depressed": depressed,
    }
    return pd.DataFrame(columns), "depressed"
48
+
49
+
50
# ---------- Helpers ----------
def ensure_min_classes(y):
    """Abort with a user-facing Gradio error unless ``y`` contains at least two distinct classes."""
    n_classes = np.unique(y).size
    if n_classes < 2:
        raise gr.Error("Label heeft minder dan 2 unieke klassen.")
54
+
55
  def make_base_fig(coords, y, title):
56
+ # Helder palet + wit canvas
57
  palette = ["#2563eb", "#ef4444", "#10b981", "#f59e0b", "#a855f7", "#06b6d4", "#f97316", "#22c55e"]
58
  fig = go.Figure()
 
 
59
  fig.update_layout(
60
+ title=title, xaxis_title="PC1", yaxis_title="PC2",
 
 
61
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
62
  margin=dict(l=10, r=10, t=60, b=10),
63
+ template=None, plot_bgcolor="#ffffff", paper_bgcolor="#ffffff", height=520
 
 
 
64
  )
 
 
65
  labels = pd.Series(y).astype(str).values
66
  uniq = list(np.unique(labels))
67
  for i, lbl in enumerate(uniq):
 
69
  color = palette[i % len(palette)]
70
  fig.add_trace(go.Scatter(
71
  x=coords[mask, 0], y=coords[mask, 1],
72
+ mode="markers", name=f"Klasse {lbl}",
 
73
  marker=dict(size=10, opacity=0.95, color=color, line=dict(width=1, color="#111")),
74
  hovertemplate="PC1: %{x:.2f}<br>PC2: %{y:.2f}<extra>" + f"Klasse {lbl}</extra>"
75
  ))
76
  return fig
77
 
 
78
  def draw_decision_boundary(fig, clf2d, scaler2d, pca2d, X_scaled):
 
79
  coords = pca2d.transform(X_scaled)
80
  x_min, x_max = coords[:, 0].min() - 0.5, coords[:, 0].max() + 0.5
81
  y_min, y_max = coords[:, 1].min() - 0.5, coords[:, 1].max() + 0.5
82
+ xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200))
 
 
 
83
  grid_2d = np.c_[xx.ravel(), yy.ravel()]
84
  coords_grid_s = scaler2d.transform(grid_2d)
 
 
85
  if hasattr(clf2d, "predict_proba"):
86
  Z = clf2d.predict_proba(coords_grid_s)[:, -1]
87
  else:
88
  dec = clf2d.decision_function(coords_grid_s)
89
  Z = (dec - np.nanmin(dec)) / (np.nanmax(dec) - np.nanmin(dec) + 1e-9)
90
  Z = np.nan_to_num(Z, nan=0.5, posinf=1.0, neginf=0.0).reshape(xx.shape)
 
 
91
  fig.add_trace(go.Contour(
92
  x=np.linspace(x_min, x_max, 200),
93
  y=np.linspace(y_min, y_max, 200),
 
99
  name="Beslissingslijnen"
100
  ))
101
  return fig
102
+
103
def get_model(model_name, params):
    """Build an unfitted scikit-learn classifier for the chosen UI model.

    Args:
        model_name: one of the radio-button labels from the UI.
        params: hyperparameter dict (keys: ``sgd_loss``, ``sgd_alpha``,
            ``sgd_lr``, ``rf_n``, ``rf_depth``, ``svm_c``); all optional.

    Returns:
        A fresh estimator; unrecognized names fall back to logistic regression.
    """
    if model_name == "SGDClassifier (realtime)":
        # max_iter=1 so each fit/partial_fit call advances one epoch,
        # letting the UI stream one frame per epoch.
        return SGDClassifier(
            loss=params.get("sgd_loss", "log_loss"),
            alpha=params.get("sgd_alpha", 1e-4),
            learning_rate=params.get("sgd_lr", "optimal"),
            max_iter=1, random_state=42
        )
    elif model_name == "Logistic Regression":
        return LogisticRegression(max_iter=300)
    elif model_name == "Random Forest":
        # Single lookup; a falsy value (0/None) means "no depth limit".
        # The original looked the key up twice with different defaults,
        # so the nominal default of 8 could never actually apply.
        depth = params.get("rf_depth", 8)
        return RandomForestClassifier(
            n_estimators=int(params.get("rf_n", 250)),
            max_depth=int(depth) if depth else None,
            random_state=42
        )
    elif model_name == "SVM (RBF)":
        return SVC(probability=True, gamma="scale", C=params.get("svm_c", 1.0), random_state=42)
    # Fallback for unrecognized model names.
    return LogisticRegression(max_iter=300)
122
+
123
+
124
# ---------- Train & Stream ----------
def train_and_stream(test_size, model_name, params, epochs, pause_s):
    """Train the chosen model and stream ``(figure, metrics_markdown)`` tuples.

    This function contains ``yield`` and is therefore a generator in BOTH
    branches. BUG FIX: the original batch branch ended with
    ``return fig, metrics_md`` — inside a generator a return value is
    discarded (it only populates ``StopIteration.value``), so Gradio never
    received the batch result. The batch branch now yields its single frame.

    Args:
        test_size: test-split proportion for ``train_test_split``.
        model_name: UI model label, forwarded to ``get_model``.
        params: hyperparameter dict for ``get_model``.
        epochs: number of ``partial_fit`` passes (realtime SGD only).
        pause_s: sleep between streamed epochs, in seconds.

    Yields:
        ``(plotly Figure, markdown string)`` — one per SGD epoch, or one
        final tuple for batch models.
    """

    def metrics_for(model):
        # Test-set metrics; AUC degrades to NaN for models without
        # predict_proba (e.g. hinge-loss SGD).
        y_pred = model.predict(X_test_s)
        acc = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average="weighted")
        try:
            y_proba = model.predict_proba(X_test_s)[:, -1]
            auc = roc_auc_score(y_test, y_proba)
        except Exception:
            auc = np.nan
        return acc, f1, auc

    def metrics_markdown(acc, f1, auc):
        # Markdown panel shown next to the plot (identical text in both
        # branches; previously duplicated).
        return (
            f"### Metrieken (testset)\n"
            f"**Accuracy:** {acc:.3f} \n"
            f"**F1 (gewogen):** {f1:.3f} \n"
            f"**ROC AUC:** {auc:.3f}\n"
        )

    def fit_2d_surrogate():
        # 2D logistic surrogate trained in PCA space — only used to draw
        # decision lines on the scatter, not for the reported metrics.
        scaler2d = StandardScaler().fit(coords_train)
        clf2d = LogisticRegression(max_iter=200).fit(
            scaler2d.transform(coords_train), y_train
        )
        return clf2d, scaler2d

    def add_test_overlay(fig):
        # Open circles mark held-out test points on top of the training
        # scatter (hovertemplate now consistent across both branches).
        fig.add_trace(go.Scatter(
            x=coords_test[:, 0], y=coords_test[:, 1],
            mode="markers", name="Test set",
            marker=dict(size=10, symbol="circle-open", line=dict(width=2, color="#111")),
            hovertemplate="PC1: %{x:.2f}<br>PC2: %{y:.2f}<extra>Test set</extra>"
        ))
        return fig

    df, ycol = load_builtin_dataset()
    X = df.drop(columns=[ycol]).values
    y = df[ycol].values
    ensure_min_classes(y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42, stratify=y
    )
    scaler = StandardScaler().fit(X_train)
    X_train_s = scaler.transform(X_train)
    X_test_s = scaler.transform(X_test)
    pca = PCA(n_components=2, random_state=42).fit(X_train_s)
    coords_train = pca.transform(X_train_s)
    coords_test = pca.transform(X_test_s)

    clf = get_model(model_name, params)

    if model_name == "SGDClassifier (realtime)":
        classes = np.unique(y_train)
        for e in range(1, int(epochs) + 1):
            clf.partial_fit(X_train_s, y_train, classes=classes)
            acc, f1, auc = metrics_for(clf)
            clf2d, scaler2d = fit_2d_surrogate()

            title = f"Epoch {e}/{epochs} • Acc {acc:.2f} • F1 {f1:.2f}"
            fig_epoch = make_base_fig(coords_train, y_train, title=title)
            fig_epoch = draw_decision_boundary(fig_epoch, clf2d, scaler2d, pca, X_train_s)
            add_test_overlay(fig_epoch)

            # Important: yield a real Plotly Figure every epoch so Gradio
            # can stream the animation.
            yield fig_epoch, metrics_markdown(acc, f1, auc)

            if pause_s and float(pause_s) > 0:
                time.sleep(float(pause_s))
        return
    else:
        clf.fit(X_train_s, y_train)
        acc, f1, auc = metrics_for(clf)

        fig = make_base_fig(coords_train, y_train, title=f"Model: {model_name}")
        clf2d, scaler2d = fit_2d_surrogate()
        fig = draw_decision_boundary(fig, clf2d, scaler2d, pca, X_train_s)
        add_test_overlay(fig)

        # BUG FIX: must be yielded, not returned, from a generator.
        yield fig, metrics_markdown(acc, f1, auc)
213
+
214
+
215
# ---------- UI ----------
DESCRIPTION = """
# 🧠 Supervised Leren – Depressie (synthetisch, ingebouwd)
- **Realtime** training (SGD) met **PCA-scatter** (elk bolletje = patiënt) en **beslissingslijnen**.
- Eén pagina, helder wit canvas. Geen uploads nodig.
"""


def _dataset_head():
    """First 10 rows of the built-in dataset, for the preview table."""
    return load_builtin_dataset()[0].head(10)


def _collect_params(sgd_loss_v, sgd_alpha_v, sgd_lr_v, rf_n_v, rf_depth_v, svm_c_v):
    """Normalize raw widget values into the params dict ``get_model`` expects."""
    return dict(
        sgd_loss=sgd_loss_v,
        sgd_alpha=float(sgd_alpha_v),
        sgd_lr=sgd_lr_v,
        rf_n=int(rf_n_v),
        # Slider value 0 means "no depth limit".
        rf_depth=None if int(rf_depth_v) == 0 else int(rf_depth_v),
        svm_c=float(svm_c_v),
    )


def _predict_for_row(model_name_v, sgd_loss_v, sgd_alpha_v, sgd_lr_v,
                     rf_n_v, rf_depth_v, svm_c_v, row_idx):
    """Fit the chosen model on the full dataset and predict one patient row.

    Replaces the original write-only nested-lambda expression. Behavior is
    unchanged, except the feature scaler is now fitted once instead of
    twice on identical data.
    """
    df, ycol = load_builtin_dataset()
    idx = int(row_idx)
    X_raw = df.drop(columns=[ycol]).values
    scaler = StandardScaler().fit(X_raw)
    Xs = scaler.transform(X_raw)
    y = df[ycol]

    base_clf = get_model(model_name_v, _collect_params(
        sgd_loss_v, sgd_alpha_v, sgd_lr_v, rf_n_v, rf_depth_v, svm_c_v
    ))
    # The realtime SGD variant is swapped for plain logistic regression
    # here: a single full fit suits a one-off prediction better.
    clf = LogisticRegression(max_iter=300) if isinstance(base_clf, SGDClassifier) else base_clf
    clf = clf.fit(Xs, y.values)

    x_row = Xs[idx].reshape(1, -1)
    pred = clf.predict(x_row)[0]
    proba = clf.predict_proba(x_row)[0].max() if hasattr(clf, "predict_proba") else None
    pretty = json.dumps(df.iloc[[idx]].to_dict(orient="records")[0], ensure_ascii=False, indent=2)
    return (
        f"### Gekozen patiënt (rij {idx})\n```json\n{pretty}\n```\n**Voorspelling:** {pred} \n"
        + (f"**Zekerheid (max. klasse-prob):** {proba:.3f}" if proba is not None else "")
    )


with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", neutral_hue="slate")) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            ds_preview = gr.Dataframe(label="Voorbeeld van de data (eerste 10 rijen)")
            btn_preview = gr.Button("📄 Dataset preview vernieuwen", variant="secondary")
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                label="Model",
                choices=["SGDClassifier (realtime)", "Logistic Regression", "Random Forest", "SVM (RBF)"],
                value="SGDClassifier (realtime)"
            )
            with gr.Accordion("Hyperparameters", open=False):
                sgd_loss = gr.Dropdown(["log_loss", "hinge", "modified_huber"], value="log_loss", label="SGD loss")
                sgd_alpha = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="SGD alpha (L2)")
                sgd_lr = gr.Dropdown(["optimal", "invscaling", "constant", "adaptive"], value="optimal", label="SGD learning rate")

                rf_n = gr.Slider(50, 500, value=250, step=10, label="RandomForest n_estimators")
                rf_depth = gr.Slider(0, 20, value=8, step=1, label="RandomForest max_depth (0 = None)")
                svm_c = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="SVM C")

            test_size = gr.Slider(0.1, 0.5, value=0.25, step=0.05, label="Testset proportie")
            with gr.Row():
                epochs = gr.Slider(1, 30, value=12, step=1, label="Epochs (alleen realtime SGD)")
                pause_s = gr.Slider(0.0, 1.0, value=0.15, step=0.05, label="Pauze per epoch (s)")

            btn_train = gr.Button("🚀 Train & Visualiseer", variant="primary")

    with gr.Row():
        # Typo fix: the original label ended with an unbalanced ")".
        fig_out = gr.Plot(label="Visualisatie (PCA 2D) met beslissingslijnen")
        metrics_out = gr.Markdown(label="Metrieken")

    with gr.Row():
        with gr.Column():
            row_index = gr.Slider(0, 999, value=0, step=1, label="Kies een patiënt (rij-index) voor voorspelling")
            btn_predict = gr.Button("🔮 Voorspel voor gekozen patiënt", variant="secondary")
            pred_md = gr.Markdown(label="Voorspelling")

    # Preload: preview table first, then train immediately with defaults.
    demo.load(_dataset_head, inputs=None, outputs=[ds_preview])

    def _proxy_train(test_size_v, model_name_v,
                     sgd_loss_v, sgd_alpha_v, sgd_lr_v, rf_n_v, rf_depth_v, svm_c_v,
                     epochs_v, pause_v):
        # Thin adapter: gather widget values into a params dict and stream.
        params = _collect_params(sgd_loss_v, sgd_alpha_v, sgd_lr_v, rf_n_v, rf_depth_v, svm_c_v)
        yield from train_and_stream(test_size_v, model_name_v, params, epochs_v, pause_v)

    # Same input list is used for the autorun-on-load and the button.
    _train_inputs = [test_size, model_choice, sgd_loss, sgd_alpha, sgd_lr,
                     rf_n, rf_depth, svm_c, epochs, pause_s]
    demo.load(_proxy_train, inputs=_train_inputs, outputs=[fig_out, metrics_out])

    btn_preview.click(_dataset_head, inputs=None, outputs=[ds_preview])
    btn_train.click(_proxy_train, inputs=_train_inputs, outputs=[fig_out, metrics_out])

    btn_predict.click(
        _predict_for_row,
        inputs=[model_choice, sgd_loss, sgd_alpha, sgd_lr, rf_n, rf_depth, svm_c, row_index],
        outputs=[pred_md]
    )

if __name__ == "__main__":
    demo.launch()