Marcel0123 commited on
Commit
2766592
·
verified ·
1 Parent(s): fa7883a

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +172 -226
  2. ggz_depressie_synth_1000_modeling.csv +0 -0
app.py CHANGED
@@ -16,80 +16,67 @@ from sklearn.model_selection import train_test_split
16
  DESCRIPTION = """
17
  # Interactieve Scatter (2D/3D) — Gradio + Plotly + Live Train
18
 
19
- - **Upload** een CSV/TSV/Parquet (of gebruik de demo data)
 
 
20
  - **Kies** x, y (en z) kolommen voor de scatter
21
  - **Kleur** op cluster/categorie of continue variabele
22
  - **Hover** toont gekozen kenmerken
23
- - **Train (live)**: KMeans clustering of Logistic Regression classificatie
24
- - **Auto-train bij start**: model traint automatisch wanneer de Space start
25
  """
26
 
27
  MODEL_PATH = Path("model.joblib")
 
28
 
29
  # -----------------------------
30
- # Demo dataset
31
  # -----------------------------
32
- def make_demo_df(n=400, seed=7):
33
- rng = np.random.default_rng(seed)
34
- centers = np.array([
35
- [0, 0, 0],
36
- [5, 5, 2],
37
- [-4, 3, -3],
38
- ])
39
- labels = rng.integers(0, len(centers), size=n)
40
- points = centers[labels] + rng.normal(0, 1.1, size=(n, 3))
41
-
42
- df = pd.DataFrame(points, columns=["x", "y", "z"])
 
43
  df["cluster"] = pd.Categorical(["A" if l == 0 else ("B" if l == 1 else "C") for l in labels])
44
- df["age"] = rng.integers(20, 90, size=n)
45
- df["sex"] = pd.Categorical(rng.choice(["F", "M"], size=n))
46
- df["diagnosis"] = pd.Categorical(rng.choice(["Type I", "Type II", "Control"], size=n, p=[0.35, 0.35, 0.30]))
47
- df["patient_id"] = [f"P{1000+i}" for i in range(n)]
48
  return df
49
 
50
- DEMO_DF = make_demo_df()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # -----------------------------
53
- # Helpers
54
  # -----------------------------
55
- def parse_file(file_obj):
56
- if file_obj is None:
57
- return DEMO_DF.copy(), "Demo dataset geladen (geen upload)."
58
- name = getattr(file_obj, "name", str(file_obj))
59
- path = name
60
-
61
- if name.lower().endswith(".csv"):
62
- df = pd.read_csv(path)
63
- elif name.lower().endswith(".tsv"):
64
- df = pd.read_csv(path, sep="\t")
65
- elif name.lower().endswith(".parquet"):
66
- df = pd.read_parquet(path)
67
- else:
68
- df = pd.read_csv(path)
69
-
70
- return df, f"Bestand geladen: {Path(name).name} — {df.shape[0]} rijen, {df.shape[1]} kolommen."
71
-
72
- def detect_columns(df):
73
- numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
74
- all_cols = df.columns.tolist()
75
-
76
- x_default = next((c for c in ["x", "X", "dim1", "pc1", "tsne1", "umap1"] if c in df.columns),
77
- (numeric_cols[0] if numeric_cols else None))
78
- y_default = next((c for c in ["y", "Y", "dim2", "pc2", "tsne2", "umap2"] if c in df.columns and c != x_default),
79
- (numeric_cols[1] if len(numeric_cols) > 1 else None))
80
- z_default = next((c for c in ["z", "Z", "dim3", "pc3", "tsne3", "umap3"] if c in df.columns and c not in {x_default, y_default}),
81
- (numeric_cols[2] if len(numeric_cols) > 2 else None))
82
-
83
- cat_candidates = [c for c in df.columns if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
84
- and c not in {x_default, y_default, z_default}]
85
- color_default = next((c for c in ["cluster", "label", "group", "diagnosis", "category"] if c in df.columns),
86
- (cat_candidates[0] if cat_candidates else (numeric_cols[0] if numeric_cols else None)))
87
- return numeric_cols, all_cols, x_default, y_default, z_default, color_default
88
-
89
  def build_hovertemplate(hover_cols):
90
  if not hover_cols:
91
  return "%{x}, %{y}<extra></extra>"
92
-
93
  lines = []
94
  for i, col in enumerate(hover_cols):
95
  lines.append(f"<b>{col}</b>: %{{customdata[{i}]}}")
@@ -108,30 +95,22 @@ def make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, point_s
108
  fig = go.Figure()
109
  if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
110
  for cat_val, dsub in df.groupby(color_col):
111
- fig.add_trace(
112
- go.Scattergl(
113
- x=dsub[x_col], y=dsub[y_col], mode='markers', name=str(cat_val),
114
- marker=dict(size=point_size), opacity=opacity,
115
- customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
116
- hovertemplate=hovertemplate,
117
- )
118
- )
119
  else:
120
- fig.add_trace(
121
- go.Scattergl(
122
- x=df[x_col], y=df[y_col], mode='markers', name=color_col if color_col else "data",
123
- marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
124
- opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
125
- )
126
- )
127
  fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
128
- fig.update_layout(
129
- template="plotly_white",
130
- margin=dict(l=10, r=10, t=30, b=10),
131
- legend=dict(itemsizing='trace', title=color_col if color_col else None),
132
- xaxis_title=x_col, yaxis_title=y_col,
133
- dragmode='pan',
134
- )
135
  return fig
136
 
137
  # 3D
@@ -140,66 +119,63 @@ def make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, point_s
140
  if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
141
  fig = go.Figure()
142
  for cat_val, dsub in df.groupby(color_col):
143
- fig.add_trace(
144
- go.Scatter3d(
145
- x=dsub[x_col], y=dsub[y_col], z=dsub[z_col], mode='markers', name=str(cat_val),
146
- marker=dict(size=point_size), opacity=opacity,
147
- customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
148
- hovertemplate=hovertemplate,
149
- )
150
- )
151
  else:
152
- fig = go.Figure(
153
- data=[
154
- go.Scatter3d(
155
- x=df[x_col], y=df[y_col], z=df[z_col], mode='markers', name=color_col if color_col else "data",
156
- marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
157
- opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
158
- )
159
- ]
160
- )
161
  fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
162
- fig.update_layout(
163
- template="plotly_white",
164
- margin=dict(l=10, r=10, t=30, b=10),
165
- legend=dict(itemsizing='trace', title=color_col if color_col else None),
166
- scene=dict(xaxis_title=x_col, yaxis_title=y_col, zaxis_title=z_col),
167
- )
168
  return fig
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def _plot_confusion_matrix(y_true, y_pred, title="Confusion matrix"):
171
  labels = sorted(pd.Series(y_true).unique())
172
  cm = confusion_matrix(y_true, y_pred, labels=labels)
173
- fig = go.Figure(
174
- data=go.Heatmap(
175
- z=cm, x=labels, y=labels, text=cm, texttemplate="%{text}",
176
- hovertemplate="Pred=%{x}<br>True=%{y}<br>Count=%{z}<extra></extra>"
177
- )
178
- )
179
- fig.update_layout(title=title, xaxis_title="Predicted", yaxis_title="True",
180
- template="plotly_white", margin=dict(l=10,r=10,t=40,b=10))
181
  return fig
182
 
183
  def _fmt_metrics(metrics: dict) -> str:
184
  lines = []
185
  for k,v in metrics.items():
186
- if isinstance(v, float):
187
- lines.append(f"{k}: {v:.4f}")
188
- else:
189
- lines.append(f"{k}: {v}")
190
  return "\n".join(lines)
191
 
192
- # -----------------------------
193
- # Training
194
- # -----------------------------
195
  def train_live(df, task, feature_cols, label_col, k_clusters, seed):
196
- if df is None or len(df) == 0:
197
  raise gr.Error("Geen data om te trainen.")
198
  if not feature_cols:
199
  raise gr.Error("Selecteer minimaal één featurekolom.")
200
-
201
  X = df[feature_cols].select_dtypes(include=[np.number])
202
- if X.shape[1] == 0:
203
  raise gr.Error("De gekozen features bevatten geen numerieke kolommen.")
204
 
205
  log_lines = []
@@ -207,68 +183,41 @@ def train_live(df, task, feature_cols, label_col, k_clusters, seed):
207
  eval_fig = None
208
 
209
  if task == "Clustering (KMeans)":
210
- pipe = Pipeline([
211
- ("scaler", StandardScaler()),
212
- ("kmeans", KMeans(n_clusters=k_clusters, random_state=seed, n_init="auto")),
213
- ])
214
  pipe.fit(X)
215
  labels = pipe["kmeans"].labels_
216
- alpha = [chr(ord('A') + (i % 26)) for i in labels]
217
- df["cluster_model"] = pd.Categorical(alpha)
218
  color_col_suggestion = "cluster_model"
219
  dump(pipe, MODEL_PATH)
220
- log_lines.append(f"✅ KMeans getraind met k={k_clusters} op {X.shape[0]} rijen en {X.shape[1]} features.")
221
- log_lines.append(f"Model opgeslagen: {MODEL_PATH.resolve()}")
222
 
223
  elif task == "Classificatie (Logistic Regression)":
224
  if not label_col:
225
  raise gr.Error("Kies een labelkolom voor classificatie.")
226
  y = df[label_col].astype(str)
227
-
228
- # Stratified split
229
- Xtr, Xva, ytr, yva = train_test_split(
230
- X, y, test_size=0.2, random_state=seed, stratify=y if y.nunique()>1 else None
231
- )
232
-
233
- pipe = Pipeline([
234
- ("scaler", StandardScaler()),
235
- ("logreg", LogisticRegression(max_iter=1000, random_state=seed, class_weight="balanced"))
236
- ])
237
  pipe.fit(Xtr, ytr)
238
-
239
- # Validatie
240
  yhat = pipe.predict(Xva)
241
- metrics = {
242
- "accuracy": accuracy_score(yva, yhat),
243
- "f1_weighted": f1_score(yva, yhat, average="weighted"),
244
- }
245
- if y.nunique() == 2:
246
  try:
247
- proba = pipe.predict_proba(Xva)[:, 1]
248
  uniq = list(pd.Series(y).unique())
249
  mapping = {uniq[0]:0, uniq[1]:1}
250
  metrics["roc_auc"] = roc_auc_score(yva.map(mapping), proba)
251
  except Exception:
252
  pass
253
-
254
  eval_fig = _plot_confusion_matrix(yva, yhat, title="Confusion matrix (validatie)")
255
-
256
- # Train op alle data en voorspel voor visualisatie
257
  pipe.fit(X, y)
258
  preds = pipe.predict(X)
259
  df["pred_model"] = pd.Categorical(preds)
260
- if y.nunique() == 2:
261
- try:
262
- df["pred_proba"] = pipe.predict_proba(X)[:, 1]
263
- except Exception:
264
- pass
265
-
266
  color_col_suggestion = "pred_model"
267
  dump(pipe, MODEL_PATH)
268
- log_lines.append(f"✅ LogisticRegression getraind (split 80/20).")
269
- log_lines.append(f"Model opgeslagen: {MODEL_PATH.resolve()}")
270
- log_lines.append("Metrics (validatie):\n" + _fmt_metrics(metrics))
271
-
272
  else:
273
  raise gr.Error("Onbekende taak.")
274
 
@@ -286,29 +235,49 @@ def try_load_model():
286
  # -----------------------------
287
  # Gradio callbacks
288
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  def init_from_file(file_obj):
290
- df, status = parse_file(file_obj)
 
 
 
 
 
291
  numeric_cols, all_cols, x_d, y_d, z_d, color_d = detect_columns(df)
292
- hover_default = [c for c in ["patient_id", "age", "sex", "diagnosis", "cluster"] if c in df.columns]
293
- feat_default = [c for c in numeric_cols]
 
294
  return (
295
  gr.update(choices=all_cols, value=x_d),
296
  gr.update(choices=all_cols, value=y_d),
297
  gr.update(choices=all_cols, value=z_d),
298
- gr.update(choices=all_cols, value=color_d),
299
  gr.update(choices=all_cols, value=hover_default),
300
  status,
301
  df,
302
- gr.update(choices=numeric_cols, value=feat_default),
303
- gr.update(choices=all_cols, value=None),
304
  )
305
 
306
  def update_plot(df, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
307
- if df is None or (isinstance(df, (list, tuple)) and len(df) == 0):
308
- df = DEMO_DF.copy()
309
  mode_3d = (mode_dim == "3D")
310
- fig = make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, size, opacity)
311
- return fig
312
 
313
  def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
314
  df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), task, feature_cols, label_col, k_clusters, seed)
@@ -317,8 +286,11 @@ def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y
317
  return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
318
 
319
  def startup_auto_train(df, task_default, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
 
320
  try:
321
- df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), task_default, feature_cols, label_col, k_clusters, seed)
 
 
322
  new_color = color_suggestion if color_suggestion else color_col
323
  fig = update_plot(df2, x_col, y_col, z_col, new_color, hover_cols, mode_dim, size, opacity)
324
  return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
@@ -334,87 +306,61 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px !important}") as demo:
334
  with gr.Row():
335
  with gr.Column(scale=1):
336
  data_file = gr.File(label="Upload CSV/TSV/Parquet", file_count="single", type="filepath")
337
- status_box = gr.Markdown("Gebruik de demo data of upload je eigen bestand.")
338
 
339
  with gr.Accordion("Assen & kleur", open=True):
340
- x_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value="x", label="X kolom")
341
- y_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value="y", label="Y kolom")
342
- z_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value="z", label="Z kolom (voor 3D)")
343
- color_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value="cluster", label="Kleur op kolom")
344
- hover_ms = gr.Dropdown(
345
- choices=DEMO_DF.columns.tolist(),
346
- value=["patient_id", "age", "sex", "diagnosis", "cluster"],
347
- multiselect=True,
348
- label="Hover info kolommen"
349
- )
350
 
351
  with gr.Accordion("Weergave", open=True):
352
- mode_dim = gr.Radio(["2D", "3D"], value="2D", label="Dimensie")
353
  size_slider = gr.Slider(3, 18, value=8, step=1, label="Puntgrootte")
354
  opacity_slider = gr.Slider(0.1, 1.0, value=0.8, step=0.05, label="Transparantie (opacity)")
355
 
356
  with gr.Accordion("Training (live)", open=True):
357
- task_radio = gr.Radio(
358
- ["Clustering (KMeans)", "Classificatie (Logistic Regression)"],
359
- value="Clustering (KMeans)",
360
- label="Taak"
361
- )
362
- feat_ms = gr.Dropdown(choices=DEMO_DF.select_dtypes(include=[np.number]).columns.tolist(),
363
- value=["x", "y", "z", "age"],
364
- multiselect=True,
365
- label="Feature kolommen (numeriek)")
366
- label_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value=None, label="Label kolom (alleen voor classificatie)")
367
  k_slider = gr.Slider(2, 12, value=3, step=1, label="K (clusters) — KMeans")
368
  seed_slider = gr.Slider(0, 10_000, value=7, step=1, label="Random seed")
369
  train_btn = gr.Button("🚀 Train (live)")
370
  train_log = gr.Textbox(label="Train log", lines=6, interactive=False)
371
 
372
- hidden_df = gr.State(DEMO_DF.copy())
373
 
374
  with gr.Column(scale=2):
375
- # Zelfde visualisatie (bolletjes) en layout behouden
376
  plot = gr.Plot(label="Scatterplot")
377
  with gr.Accordion("Evaluatie (validatie)", open=False):
378
  cm_plot = gr.Plot(label="Confusion Matrix (validatie)")
379
 
380
  # ===== Events =====
381
- data_file.change(
382
- fn=init_from_file,
383
- inputs=[data_file],
384
- outputs=[x_dd, y_dd, z_dd, color_dd, hover_ms, status_box, hidden_df, feat_ms, label_dd],
385
- show_progress=False,
386
- )
387
 
388
  for comp in [x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider]:
389
- comp.change(
390
- fn=update_plot,
391
- inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
392
- outputs=plot,
393
- show_progress=False,
394
- )
395
-
396
- train_btn.click(
397
- fn=on_train_click,
398
- inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider,
399
- x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
400
- outputs=[hidden_df, train_log, color_dd, plot, cm_plot],
401
- show_progress=True,
402
- )
403
 
404
- demo.load(
405
- fn=update_plot,
406
- inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
407
- outputs=plot,
408
- show_progress=False,
409
- )
410
 
411
- demo.load(
412
- fn=startup_auto_train,
413
- inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider,
414
- x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
415
- outputs=[hidden_df, train_log, color_dd, plot, cm_plot],
416
- show_progress=True,
417
- )
418
 
419
  if __name__ == "__main__":
420
  demo.launch()
 
16
  DESCRIPTION = """
17
  # Interactieve Scatter (2D/3D) — Gradio + Plotly + Live Train
18
 
19
+ - **Bundled dataset**: ggz_depressie_synth_1000_modeling.csv (laadt automatisch)
20
+ - **Auto-train (supervised)** bij start met label: `target_respons50`
21
+ - **Upload** een eigen CSV/TSV/Parquet indien gewenst
22
  - **Kies** x, y (en z) kolommen voor de scatter
23
  - **Kleur** op cluster/categorie of continue variabele
24
  - **Hover** toont gekozen kenmerken
25
+ - **Train (live)**: KMeans of Logistic Regression
 
26
  """
27
 
28
  MODEL_PATH = Path("model.joblib")
29
+ DATA_PATH = Path("data/ggz_depressie_synth_1000_modeling.csv")
30
 
31
  # -----------------------------
32
+ # Data loading
33
  # -----------------------------
34
+ def load_default_df():
35
+ if DATA_PATH.exists():
36
+ try:
37
+ return pd.read_csv(DATA_PATH)
38
+ except Exception as e:
39
+ print("Kon bundled dataset niet laden:", e)
40
+ # Fallback demo
41
+ rng = np.random.default_rng(7)
42
+ centers = np.array([[0,0,0],[5,5,2],[-4,3,-3]])
43
+ labels = rng.integers(0, len(centers), size=400)
44
+ points = centers[labels] + rng.normal(0, 1.1, size=(400,3))
45
+ df = pd.DataFrame(points, columns=["x","y","z"])
46
  df["cluster"] = pd.Categorical(["A" if l == 0 else ("B" if l == 1 else "C") for l in labels])
47
+ df["age"] = rng.integers(20, 90, size=400)
48
+ df["sex"] = pd.Categorical(rng.choice(["F","M"], size=400))
49
+ df["diagnosis"] = pd.Categorical(rng.choice(["Type I","Type II","Control"], size=400, p=[0.35,0.35,0.30]))
50
+ df["patient_id"] = [f"P{1000+i}" for i in range(400)]
51
  return df
52
 
53
+ BASE_DF = load_default_df()
54
+
55
+ # Heuristics for label/features
56
+ def pick_default_label(df):
57
+ # voorkeur: target_respons50 -> anders bekende varianten -> anders eerste target_
58
+ for name in ["target_respons50", "target_remissie", "target_uitval", "target_opname6mnd", "target_rtw3mnd"]:
59
+ if name in df.columns:
60
+ return name
61
+ for c in df.columns:
62
+ if str(c).lower() in ["label","target","y","class","diagnosis","outcome"]:
63
+ return c
64
+ for c in df.columns:
65
+ if str(c).startswith("target_"):
66
+ return c
67
+ return None
68
+
69
+ def default_features(df):
70
+ num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
71
+ exclude = set(["patient_id"] + [c for c in df.columns if str(c).startswith("target_")])
72
+ return [c for c in num_cols if c not in exclude]
73
 
74
  # -----------------------------
75
+ # Plot utils (identiek aan eerdere bolletjesvis)
76
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def build_hovertemplate(hover_cols):
78
  if not hover_cols:
79
  return "%{x}, %{y}<extra></extra>"
 
80
  lines = []
81
  for i, col in enumerate(hover_cols):
82
  lines.append(f"<b>{col}</b>: %{{customdata[{i}]}}")
 
95
  fig = go.Figure()
96
  if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
97
  for cat_val, dsub in df.groupby(color_col):
98
+ fig.add_trace(go.Scattergl(
99
+ x=dsub[x_col], y=dsub[y_col], mode='markers', name=str(cat_val),
100
+ marker=dict(size=point_size), opacity=opacity,
101
+ customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
102
+ hovertemplate=hovertemplate,
103
+ ))
 
 
104
  else:
105
+ fig.add_trace(go.Scattergl(
106
+ x=df[x_col], y=df[y_col], mode='markers', name=color_col if color_col else "data",
107
+ marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
108
+ opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
109
+ ))
 
 
110
  fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
111
+ fig.update_layout(template="plotly_white", margin=dict(l=10,r=10,t=30,b=10),
112
+ legend=dict(itemsizing='trace', title=color_col if color_col else None),
113
+ xaxis_title=x_col, yaxis_title=y_col, dragmode='pan')
 
 
 
 
114
  return fig
115
 
116
  # 3D
 
119
  if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
120
  fig = go.Figure()
121
  for cat_val, dsub in df.groupby(color_col):
122
+ fig.add_trace(go.Scatter3d(
123
+ x=dsub[x_col], y=dsub[y_col], z=dsub[z_col], mode='markers', name=str(cat_val),
124
+ marker=dict(size=point_size), opacity=opacity,
125
+ customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
126
+ hovertemplate=hovertemplate,
127
+ ))
 
 
128
  else:
129
+ fig = go.Figure(data=[go.Scatter3d(
130
+ x=df[x_col], y=df[y_col], z=df[z_col], mode='markers', name=color_col if color_col else "data",
131
+ marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
132
+ opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
133
+ )])
 
 
 
 
134
  fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
135
+ fig.update_layout(template="plotly_white", margin=dict(l=10,r=10,t=30,b=10),
136
+ legend=dict(itemsizing='trace', title=color_col if color_col else None),
137
+ scene=dict(xaxis_title=x_col, yaxis_title=y_col, zaxis_title=z_col))
 
 
 
138
  return fig
139
 
140
+ # -----------------------------
141
+ # App state & defaults
142
+ # -----------------------------
143
+ def detect_columns(df):
144
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
145
+ all_cols = df.columns.tolist()
146
+ # try common embedding columns
147
+ x_default = next((c for c in ["x","X","dim1","pc1","tsne1","umap1"] if c in df.columns), (numeric_cols[0] if numeric_cols else None))
148
+ y_default = next((c for c in ["y","Y","dim2","pc2","tsne2","umap2"] if c in df.columns and c != x_default), (numeric_cols[1] if len(numeric_cols)>1 else None))
149
+ z_default = next((c for c in ["z","Z","dim3","pc3","tsne3","umap3"] if c in df.columns and c not in {x_default,y_default}), (numeric_cols[2] if len(numeric_cols)>2 else None))
150
+ # color default
151
+ cat_candidates = [c for c in df.columns if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category')) and c not in {x_default, y_default, z_default}]
152
+ color_default = next((c for c in ["cluster","label","group","diagnosis","category","pred_model"] if c in df.columns), (cat_candidates[0] if cat_candidates else (numeric_cols[0] if numeric_cols else None)))
153
+ return numeric_cols, all_cols, x_default, y_default, z_default, color_default
154
+
155
+ # -----------------------------
156
+ # Training
157
+ # -----------------------------
158
  def _plot_confusion_matrix(y_true, y_pred, title="Confusion matrix"):
159
  labels = sorted(pd.Series(y_true).unique())
160
  cm = confusion_matrix(y_true, y_pred, labels=labels)
161
+ fig = go.Figure(data=go.Heatmap(z=cm, x=labels, y=labels, text=cm, texttemplate="%{text}",
162
+ hovertemplate="Pred=%{x}<br>True=%{y}<br>Count=%{z}<extra></extra>"))
163
+ fig.update_layout(title=title, xaxis_title="Predicted", yaxis_title="True", template="plotly_white", margin=dict(l=10,r=10,t=40,b=10))
 
 
 
 
 
164
  return fig
165
 
166
  def _fmt_metrics(metrics: dict) -> str:
167
  lines = []
168
  for k,v in metrics.items():
169
+ lines.append(f"{k}: {v:.4f}" if isinstance(v,float) else f"{k}: {v}")
 
 
 
170
  return "\n".join(lines)
171
 
 
 
 
172
  def train_live(df, task, feature_cols, label_col, k_clusters, seed):
173
+ if df is None or len(df)==0:
174
  raise gr.Error("Geen data om te trainen.")
175
  if not feature_cols:
176
  raise gr.Error("Selecteer minimaal één featurekolom.")
 
177
  X = df[feature_cols].select_dtypes(include=[np.number])
178
+ if X.shape[1]==0:
179
  raise gr.Error("De gekozen features bevatten geen numerieke kolommen.")
180
 
181
  log_lines = []
 
183
  eval_fig = None
184
 
185
  if task == "Clustering (KMeans)":
186
+ pipe = Pipeline([("scaler", StandardScaler()), ("kmeans", KMeans(n_clusters=k_clusters, random_state=seed, n_init="auto"))])
 
 
 
187
  pipe.fit(X)
188
  labels = pipe["kmeans"].labels_
189
+ df["cluster_model"] = pd.Categorical([chr(ord('A') + (i % 26)) for i in labels])
 
190
  color_col_suggestion = "cluster_model"
191
  dump(pipe, MODEL_PATH)
192
+ log_lines += [f"✅ KMeans getraind met k={k_clusters} op {X.shape[0]} rijen.", f"Model opgeslagen: {MODEL_PATH.resolve()}"]
 
193
 
194
  elif task == "Classificatie (Logistic Regression)":
195
  if not label_col:
196
  raise gr.Error("Kies een labelkolom voor classificatie.")
197
  y = df[label_col].astype(str)
198
+ Xtr, Xva, ytr, yva = train_test_split(X, y, test_size=0.2, random_state=seed, stratify=y if y.nunique()>1 else None)
199
+ pipe = Pipeline([("scaler", StandardScaler()), ("logreg", LogisticRegression(max_iter=1000, random_state=seed, class_weight="balanced"))])
 
 
 
 
 
 
 
 
200
  pipe.fit(Xtr, ytr)
 
 
201
  yhat = pipe.predict(Xva)
202
+ metrics = {"accuracy": accuracy_score(yva, yhat), "f1_weighted": f1_score(yva, yhat, average="weighted")}
203
+ if y.nunique()==2:
 
 
 
204
  try:
205
+ proba = pipe.predict_proba(Xva)[:,1]
206
  uniq = list(pd.Series(y).unique())
207
  mapping = {uniq[0]:0, uniq[1]:1}
208
  metrics["roc_auc"] = roc_auc_score(yva.map(mapping), proba)
209
  except Exception:
210
  pass
 
211
  eval_fig = _plot_confusion_matrix(yva, yhat, title="Confusion matrix (validatie)")
 
 
212
  pipe.fit(X, y)
213
  preds = pipe.predict(X)
214
  df["pred_model"] = pd.Categorical(preds)
215
+ if y.nunique()==2:
216
+ try: df["pred_proba"] = pipe.predict_proba(X)[:,1]
217
+ except Exception: pass
 
 
 
218
  color_col_suggestion = "pred_model"
219
  dump(pipe, MODEL_PATH)
220
+ log_lines += ["✅ LogisticRegression getraind (split 80/20).", f"Model opgeslagen: {MODEL_PATH.resolve()}", "Metrics (validatie):\n"+_fmt_metrics(metrics)]
 
 
 
221
  else:
222
  raise gr.Error("Onbekende taak.")
223
 
 
235
  # -----------------------------
236
  # Gradio callbacks
237
  # -----------------------------
238
+ def parse_file(file_obj):
239
+ if file_obj is None:
240
+ return BASE_DF.copy(), ("Bundled dataset geladen." if DATA_PATH.exists() else "Demo dataset geladen.")
241
+ name = getattr(file_obj, "name", str(file_obj))
242
+ path = name
243
+ if name.lower().endswith(".csv"):
244
+ df = pd.read_csv(path)
245
+ elif name.lower().endswith(".tsv"):
246
+ df = pd.read_csv(path, sep="\t")
247
+ elif name.lower().endswith(".parquet"):
248
+ df = pd.read_parquet(path)
249
+ else:
250
+ df = pd.read_csv(path)
251
+ return df, f"Bestand geladen: {Path(name).name} — {df.shape[0]} rijen, {df.shape[1]} kolommen."
252
+
253
  def init_from_file(file_obj):
254
+ if file_obj is None:
255
+ df = BASE_DF.copy()
256
+ status = "Bundled dataset geladen." if DATA_PATH.exists() else "Demo dataset geladen."
257
+ else:
258
+ df, status = parse_file(file_obj)
259
+
260
  numeric_cols, all_cols, x_d, y_d, z_d, color_d = detect_columns(df)
261
+ hover_default = [c for c in ["patient_id","age","sex","diagnosis","cluster"] if c in df.columns]
262
+ feat_default = default_features(df)
263
+ label_default = pick_default_label(df)
264
  return (
265
  gr.update(choices=all_cols, value=x_d),
266
  gr.update(choices=all_cols, value=y_d),
267
  gr.update(choices=all_cols, value=z_d),
268
+ gr.update(choices=all_cols, value="pred_model" if "pred_model" in df.columns else color_d),
269
  gr.update(choices=all_cols, value=hover_default),
270
  status,
271
  df,
272
+ gr.update(choices=df.select_dtypes(include=[np.number]).columns.tolist(), value=feat_default),
273
+ gr.update(choices=all_cols, value=label_default),
274
  )
275
 
276
  def update_plot(df, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
277
+ if df is None or (isinstance(df,(list,tuple)) and len(df)==0):
278
+ df = BASE_DF.copy()
279
  mode_3d = (mode_dim == "3D")
280
+ return make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, size, opacity)
 
281
 
282
  def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
283
  df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), task, feature_cols, label_col, k_clusters, seed)
 
286
  return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
287
 
288
  def startup_auto_train(df, task_default, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
289
+ # Forceer supervised classificatie bij start
290
  try:
291
+ chosen_label = label_col or pick_default_label(df)
292
+ chosen_feats = feature_cols or default_features(df)
293
+ df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), "Classificatie (Logistic Regression)", chosen_feats, chosen_label, k_clusters, seed)
294
  new_color = color_suggestion if color_suggestion else color_col
295
  fig = update_plot(df2, x_col, y_col, z_col, new_color, hover_cols, mode_dim, size, opacity)
296
  return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
 
306
  with gr.Row():
307
  with gr.Column(scale=1):
308
  data_file = gr.File(label="Upload CSV/TSV/Parquet", file_count="single", type="filepath")
309
+ status_box = gr.Markdown("Bundled dataset wordt standaard geladen en getraind bij start.")
310
 
311
  with gr.Accordion("Assen & kleur", open=True):
312
+ base_numeric = BASE_DF.select_dtypes(include=[np.number]).columns.tolist()
313
+ x_default = base_numeric[0] if len(base_numeric)>0 else None
314
+ y_default = base_numeric[1] if len(base_numeric)>1 else None
315
+ z_default = base_numeric[2] if len(base_numeric)>2 else None
316
+ x_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=x_default, label="X kolom")
317
+ y_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=y_default, label="Y kolom")
318
+ z_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=z_default, label="Z kolom (voor 3D)")
319
+ color_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value="pred_model" if "pred_model" in BASE_DF.columns else None, label="Kleur op kolom")
320
+ hover_ms = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=[c for c in ["patient_id","age","sex","diagnosis","cluster"] if c in BASE_DF.columns], multiselect=True, label="Hover info kolommen")
 
321
 
322
  with gr.Accordion("Weergave", open=True):
323
+ mode_dim = gr.Radio(["2D","3D"], value="2D", label="Dimensie")
324
  size_slider = gr.Slider(3, 18, value=8, step=1, label="Puntgrootte")
325
  opacity_slider = gr.Slider(0.1, 1.0, value=0.8, step=0.05, label="Transparantie (opacity)")
326
 
327
  with gr.Accordion("Training (live)", open=True):
328
+ task_radio = gr.Radio(["Clustering (KMeans)","Classificatie (Logistic Regression)"], value="Classificatie (Logistic Regression)", label="Taak")
329
+ feat_ms = gr.Dropdown(choices=BASE_DF.select_dtypes(include=[np.number]).columns.tolist(), value=default_features(BASE_DF), multiselect=True, label="Feature kolommen (numeriek)")
330
+ label_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=pick_default_label(BASE_DF), label="Label kolom (alleen voor classificatie)")
 
 
 
 
 
 
 
331
  k_slider = gr.Slider(2, 12, value=3, step=1, label="K (clusters) — KMeans")
332
  seed_slider = gr.Slider(0, 10_000, value=7, step=1, label="Random seed")
333
  train_btn = gr.Button("🚀 Train (live)")
334
  train_log = gr.Textbox(label="Train log", lines=6, interactive=False)
335
 
336
+ hidden_df = gr.State(BASE_DF.copy())
337
 
338
  with gr.Column(scale=2):
 
339
  plot = gr.Plot(label="Scatterplot")
340
  with gr.Accordion("Evaluatie (validatie)", open=False):
341
  cm_plot = gr.Plot(label="Confusion Matrix (validatie)")
342
 
343
  # ===== Events =====
344
+ data_file.change(fn=init_from_file, inputs=[data_file],
345
+ outputs=[x_dd, y_dd, z_dd, color_dd, hover_ms, status_box, hidden_df, feat_ms, label_dd],
346
+ show_progress=False)
 
 
 
347
 
348
  for comp in [x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider]:
349
+ comp.change(fn=update_plot,
350
+ inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
351
+ outputs=plot, show_progress=False)
 
 
 
 
 
 
 
 
 
 
 
352
 
353
+ train_btn.click(fn=on_train_click,
354
+ inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
355
+ outputs=[hidden_df, train_log, color_dd, plot, cm_plot], show_progress=True)
 
 
 
356
 
357
+ # Initial plot
358
+ demo.load(fn=update_plot, inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider], outputs=plot, show_progress=False)
359
+
360
+ # Auto-train bij start (forceer supervised classificatie met default label)
361
+ demo.load(fn=startup_auto_train,
362
+ inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
363
+ outputs=[hidden_df, train_log, color_dd, plot, cm_plot], show_progress=True)
364
 
365
  if __name__ == "__main__":
366
  demo.launch()
ggz_depressie_synth_1000_modeling.csv ADDED
The diff for this file is too large to render. See raw diff