Marcel0123 commited on
Commit
9eb3654
·
verified ·
1 Parent(s): 81acd21

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +0 -16
  2. app.py +1 -391
  3. requirements.txt +1 -8
README.md CHANGED
@@ -1,17 +1 @@
1
- ---
2
- title: "Synthetische depressiedata – Supervised ML demo"
3
- emoji: "🧠"
4
- colorFrom: "blue"
5
- colorTo: "purple"
6
- sdk: gradio
7
- sdk_version: "4.0.0"
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- # Supervised ML demo – synthetische depressiedata
13
-
14
- Volledig synthetische data. Niet voor klinisch gebruik.
15
-
16
- ## Gebruik
17
- Upload `app.py`, `requirements.txt` (en desgewenst `runtime.txt`) naar een nieuwe **Gradio** Space.
 
 
 
 
 
 
 
 
 
 
 
1
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,391 +1 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
- from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
5
- from sklearn.preprocessing import OneHotEncoder, StandardScaler
6
- from sklearn.compose import ColumnTransformer
7
- from sklearn.pipeline import Pipeline
8
- from sklearn.linear_model import LogisticRegression
9
- from sklearn.ensemble import RandomForestClassifier
10
- from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, RocCurveDisplay, precision_recall_curve, average_precision_score
11
- from sklearn.inspection import permutation_importance
12
- from sklearn.calibration import CalibratedClassifierCV
13
- import matplotlib.pyplot as plt
14
- import io
15
- import joblib
16
-
17
- # Optionele afhankelijkheid
18
- try:
19
- import shap
20
- SHAP_AVAILABLE = True
21
- except Exception:
22
- SHAP_AVAILABLE = False
23
-
24
- # -----------------------------
25
- # 1) Synthetische datageneratie
26
- # -----------------------------
27
-
28
- def generate_synthetic_dataset(n_samples=1000, seed=42):
29
- rng = np.random.default_rng(seed)
30
-
31
- age = rng.integers(18, 81, size=n_samples)
32
- sex = rng.choice(["man", "vrouw"], size=n_samples, p=[0.48, 0.52])
33
- bmi = np.clip(rng.normal(26, 5, size=n_samples), 16, 45)
34
- sleep_hours = np.clip(rng.normal(7, 1.5, size=n_samples), 3, 12)
35
- activity_min = np.clip(rng.normal(30, 25, size=n_samples), 0, 180)
36
- phq9 = np.clip(np.round(rng.normal(9, 6, size=n_samples)), 0, 27)
37
- gad7 = np.clip(np.round(rng.normal(7, 5, size=n_samples)), 0, 21)
38
- prior_depr = rng.integers(0, 2, size=n_samples)
39
- family_hist = rng.integers(0, 2, size=n_samples)
40
- chronic_ill = rng.integers(0, 2, size=n_samples)
41
- substance_use = rng.integers(0, 2, size=n_samples)
42
- stressful_events = np.clip(rng.poisson(1.2, size=n_samples), 0, 6)
43
- social_support = rng.integers(1, 6, size=n_samples)
44
- employment = rng.choice(["werkend", "student", "werkloos", "ziekverlof"], size=n_samples, p=[0.56, 0.16, 0.18, 0.10])
45
-
46
- z = (
47
- 0.35 * (phq9 / 27) +
48
- 0.12 * (gad7 / 21) +
49
- 0.18 * (1 - (sleep_hours - 3) / 9) +
50
- 0.10 * (1 - np.sqrt(np.maximum(activity_min,1e-6) / 180)) +
51
- 0.10 * (stressful_events / 6) +
52
- 0.08 * (1 - (social_support - 1) / 4) +
53
- 0.10 * prior_depr +
54
- 0.05 * family_hist +
55
- 0.03 * chronic_ill +
56
- 0.02 * (bmi - 25) / 20 +
57
- 0.03 * substance_use
58
- )
59
- z = z + rng.normal(0, 0.05, size=n_samples)
60
- p = 1 / (1 + np.exp(-(z * 4 - 2)))
61
- label = (rng.random(n_samples) < p).astype(int)
62
-
63
- df = pd.DataFrame({
64
- "age": age,
65
- "sex": sex,
66
- "bmi": np.round(bmi, 1),
67
- "sleep_hours": np.round(sleep_hours, 1),
68
- "activity_minutes": np.round(activity_min, 0).astype(int),
69
- "phq9": phq9.astype(int),
70
- "gad7": gad7.astype(int),
71
- "prior_depression": prior_depr,
72
- "family_history": family_hist,
73
- "chronic_illness": chronic_ill,
74
- "substance_use": substance_use,
75
- "stressful_events": stressful_events,
76
- "social_support": social_support,
77
- "employment_status": employment,
78
- "current_depression": label
79
- })
80
- return df
81
-
82
- # -----------------------------
83
- # 2) Pipeline helpers
84
- # -----------------------------
85
-
86
- def make_preprocessor():
87
- numeric_cols = [
88
- "age","bmi","sleep_hours","activity_minutes","phq9","gad7",
89
- "prior_depression","family_history","chronic_illness","substance_use",
90
- "stressful_events","social_support"
91
- ]
92
- cat_cols = ["sex", "employment_status"]
93
- pre = ColumnTransformer([
94
- ("num", StandardScaler(), numeric_cols),
95
- ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols)
96
- ])
97
- return pre, numeric_cols, cat_cols
98
-
99
- def build_pipeline(model_type="Logistic Regression", seed=42, calibration=None):
100
- pre, *_ = make_preprocessor()
101
- if model_type == "Random Forest":
102
- base_model = RandomForestClassifier(n_estimators=300, random_state=seed)
103
- else:
104
- base_model = LogisticRegression(max_iter=300)
105
-
106
- if calibration in ("Platt (sigmoid)", "Isotonic"):
107
- method = "sigmoid" if calibration.startswith("Platt") else "isotonic"
108
- model = CalibratedClassifierCV(base_model, cv=3, method=method)
109
- else:
110
- model = base_model
111
-
112
- return Pipeline([("prep", pre), ("clf", model)])
113
-
114
- def train_model(df, model_type="Logistic Regression", test_size=0.2, seed=42, threshold=0.5, calibration=None):
115
- y = df["current_depression"]
116
- X = df.drop(columns=["current_depression"])
117
-
118
- pipe = build_pipeline(model_type, seed, calibration=calibration)
119
- X_train, X_test, y_train, y_test = train_test_split(
120
- X, y, test_size=test_size, random_state=seed, stratify=y
121
- )
122
- pipe.fit(X_train, y_train)
123
-
124
- y_proba = pipe.predict_proba(X_test)[:, 1]
125
- y_pred = (y_proba >= threshold).astype(int)
126
-
127
- acc = float(accuracy_score(y_test, y_pred))
128
- auc = float(roc_auc_score(y_test, y_proba))
129
- ap = float(average_precision_score(y_test, y_proba))
130
- cm = confusion_matrix(y_test, y_pred)
131
-
132
- # ROC
133
- fig, ax = plt.subplots()
134
- RocCurveDisplay.from_predictions(y_test, y_proba, ax=ax)
135
- ax.set_title("ROC-curve")
136
- buf = io.BytesIO(); fig.savefig(buf, format="png", bbox_inches="tight"); plt.close(fig)
137
- roc_png = buf.getvalue()
138
-
139
- # PR
140
- precision, recall, _ = precision_recall_curve(y_test, y_proba)
141
- fig3, ax3 = plt.subplots()
142
- ax3.plot(recall, precision)
143
- ax3.set_xlabel("Recall"); ax3.set_ylabel("Precision"); ax3.set_title("Precision–Recall curve")
144
- buf3 = io.BytesIO(); fig3.savefig(buf3, format="png", bbox_inches="tight"); plt.close(fig3)
145
- pr_png = buf3.getvalue()
146
-
147
- # Confusion matrix
148
- fig2, ax2 = plt.subplots()
149
- _ = ax2.imshow(cm, interpolation="nearest")
150
- ax2.set_title(f"Confusion matrix (thr={threshold:.2f})")
151
- ax2.set_xlabel("Voorspeld"); ax2.set_ylabel("Werkelijk")
152
- for (i, j), v in np.ndenumerate(cm):
153
- ax2.text(j, i, str(v), ha="center", va="center")
154
- buf2 = io.BytesIO(); fig2.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2)
155
- cm_png = buf2.getvalue()
156
-
157
- # Permutation importance
158
- try:
159
- r = permutation_importance(pipe, X_test, y_test, n_repeats=10, random_state=seed)
160
- importances = r.importances_mean
161
- feat_names = pipe.named_steps["prep"].get_feature_names_out()
162
- imp_df = pd.DataFrame({"feature": feat_names, "importance": importances}).sort_values("importance", ascending=False).head(20)
163
- figi, axi = plt.subplots(figsize=(6,4))
164
- axi.barh(imp_df["feature"][::-1], imp_df["importance"][::-1])
165
- axi.set_title("Permutation importance (top 20)")
166
- figbuf = io.BytesIO(); figi.savefig(figbuf, format="png", bbox_inches="tight"); plt.close(figi)
167
- imp_png = figbuf.getvalue()
168
- except Exception:
169
- imp_png = None
170
-
171
- shap_png = None
172
- if SHAP_AVAILABLE:
173
- try:
174
- sample_idx = np.random.choice(len(X_test), size=min(200, len(X_test)), replace=False)
175
- X_sample = X_test.iloc[sample_idx]
176
- f = lambda data: pipe.predict_proba(pd.DataFrame(data, columns=X_test.columns))[:,1]
177
- explainer = shap.KernelExplainer(f, shap.sample(X_train, 50, random_state=seed))
178
- shap_values = explainer.shap_values(X_sample, nsamples=100)
179
- figshap = plt.figure()
180
- shap.summary_plot(shap_values, X_sample, show=False)
181
- bufshap = io.BytesIO(); figshap.savefig(bufshap, format="png", bbox_inches="tight"); plt.close(figshap)
182
- shap_png = bufshap.getvalue()
183
- except Exception:
184
- shap_png = None
185
-
186
- metrics = {"accuracy": round(acc,3), "roc_auc": round(auc,3), "avg_precision": round(ap,3)}
187
- return pipe, metrics, cm, roc_png, cm_png, pr_png, imp_png, shap_png
188
-
189
- def cross_validate(df, model_type="Logistic Regression", seed=42, k=5, calibration=None):
190
- y = df["current_depression"]
191
- X = df.drop(columns=["current_depression"])
192
- pipe = build_pipeline(model_type, seed, calibration=calibration)
193
- cv = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
194
- aucs = cross_val_score(pipe, X, y, scoring="roc_auc", cv=cv)
195
- accs = cross_val_score(pipe, X, y, scoring="accuracy", cv=cv)
196
- return {"cv_auc_mean": float(np.mean(aucs)), "cv_auc_std": float(np.std(aucs)),
197
- "cv_acc_mean": float(np.mean(accs)), "cv_acc_std": float(np.std(accs))}
198
-
199
- # -----------------------------
200
- # 3) Gradio UI
201
- # -----------------------------
202
-
203
- def build_app():
204
- with gr.Blocks(title="Synthetische depressiedata – Supervised ML demo") as demo:
205
- gr.Markdown(
206
- "# Supervised ML demo (synthetische depressiedata)\n"
207
- "**Let op:** Deze app gebruikt *volledig synthetische* data en is alleen voor onderwijs/demonstratie. "
208
- "Niet gebruiken voor klinische beslissingen."
209
- )
210
-
211
- state_df = gr.State()
212
- model_state = gr.State()
213
-
214
- # Tab 0: Exploratie
215
- with gr.Tab("0) Data Exploratie"):
216
- gr.Markdown("Genereer eerst een dataset of laad de standaard. Bekijk distributies en correlaties.")
217
- init_btn = gr.Button("(Re)genereer standaarddataset")
218
- stats_json = gr.JSON(label="Samenvatting")
219
- dist_img = gr.Image(label="Histogrammen (kernvariabelen)")
220
- corr_img = gr.Image(label="Correlatie heatmap (numeriek)")
221
-
222
- def init_and_explore():
223
- df = generate_synthetic_dataset(1000, 42)
224
- desc = df.describe().to_dict()
225
-
226
- fig, ax = plt.subplots(figsize=(8,6))
227
- cols = ["phq9","gad7","sleep_hours","activity_minutes","stressful_events","social_support"]
228
- for c in cols:
229
- df[c].plot(kind="hist", alpha=0.5)
230
- ax.set_title("Distributies kernvariabelen")
231
- buf = io.BytesIO(); fig.savefig(buf, format="png", bbox_inches="tight"); plt.close(fig)
232
- hist_png = buf.getvalue()
233
-
234
- num = df.select_dtypes(include=[np.number]).corr()
235
- fig2, ax2 = plt.subplots(figsize=(6,5))
236
- _ = ax2.imshow(num, aspect='auto')
237
- ax2.set_title("Correlatie (Pearson)")
238
- ax2.set_xticks(range(len(num.columns))); ax2.set_xticklabels(num.columns, rotation=90)
239
- ax2.set_yticks(range(len(num.index))); ax2.set_yticklabels(num.index)
240
- buf2 = io.BytesIO(); fig2.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2)
241
- corr_png = buf2.getvalue()
242
- return df, desc, hist_png, corr_png
243
-
244
- init_btn.click(init_and_explore, inputs=None, outputs=[state_df, stats_json, dist_img, corr_img])
245
-
246
- # Tab 1: Data
247
- with gr.Tab("1) Data"):
248
- n = gr.Slider(200, 5000, value=1000, step=50, label="Aantal voorbeelden")
249
- seed = gr.Slider(0, 9999, value=42, step=1, label="Random seed")
250
- gen_btn = gr.Button("Genereer dataset")
251
- df_out = gr.Dataframe(interactive=False, wrap=True, height=300)
252
- csv = gr.File(label="Download CSV", interactive=False)
253
-
254
- def on_generate(n, seed):
255
- df = generate_synthetic_dataset(int(n), int(seed))
256
- path = "synthetic_depression.csv"; df.to_csv(path, index=False)
257
- return df, df.head(50), path
258
-
259
- gen_btn.click(on_generate, [n, seed], [state_df, df_out, csv])
260
-
261
- # Tab 2: Train & Evaluate
262
- with gr.Tab("2) Train & Evaluate"):
263
- model_type = gr.Radio(["Logistic Regression", "Random Forest"], value="Logistic Regression", label="Model")
264
- calibration = gr.Radio(["Geen", "Platt (sigmoid)", "Isotonic"], value="Geen", label="Calibratie")
265
- test_size = gr.Slider(0.1, 0.5, value=0.2, step=0.05, label="Test set fractie")
266
- threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Beslisdrempel (positief bij p ≥ drempel)")
267
- seed2 = gr.Slider(0, 9999, value=42, step=1, label="Random seed")
268
-
269
- train_btn = gr.Button("Train model")
270
- metrics = gr.JSON(label="Metrics (accuracy, ROC AUC, AP)")
271
- roc_img = gr.Image(label="ROC-curve")
272
- pr_img = gr.Image(label="PR-curve")
273
- cm_img = gr.Image(label="Confusion matrix")
274
- imp_img = gr.Image(label="Permutation importance (top 20)")
275
- shap_img = gr.Image(label="SHAP summary (optioneel)")
276
-
277
- def on_train(model_type, test_size, seed, threshold, calibration, df):
278
- if df is None or len(df)==0:
279
- df = generate_synthetic_dataset(1000, 42)
280
- model, metrics_out, cm, roc_png, cm_png, pr_png, imp_png, shap_png = train_model(
281
- df, model_type=model_type, test_size=float(test_size), seed=int(seed), threshold=float(threshold),
282
- calibration=None if calibration=="Geen" else calibration
283
- )
284
- return model, metrics_out, roc_png, pr_png, cm_png, imp_png, shap_png
285
-
286
- train_btn.click(on_train, [model_type, test_size, seed2, threshold, calibration, state_df],
287
- [model_state, metrics, roc_img, pr_img, cm_img, imp_img, shap_img])
288
-
289
- cv_btn = gr.Button("Cross-validation (k=5) – ROC AUC & Accuracy")
290
- cv_json = gr.JSON(label="CV-resultaten")
291
-
292
- def on_cv(model_type, calibration, df):
293
- if df is None or len(df)==0:
294
- df = generate_synthetic_dataset(1000, 42)
295
- return cross_validate(df, model_type=model_type, calibration=None if calibration=="Geen" else calibration)
296
-
297
- cv_btn.click(on_cv, [model_type, calibration, state_df], [cv_json])
298
-
299
- with gr.Row():
300
- save_btn = gr.Button("Sla model op (.joblib)")
301
- model_file = gr.File(label="Gedownloade model-file", interactive=False)
302
- load_file = gr.File(label="Laad model (.joblib)")
303
- load_btn = gr.Button("Laad model in app")
304
-
305
- def on_save(model):
306
- if model is None:
307
- return None
308
- path = "trained_pipeline.joblib"
309
- joblib.dump(model, path)
310
- return path
311
-
312
- save_btn.click(on_save, [model_state], [model_file])
313
-
314
- def on_load(file_obj):
315
- if file_obj is None:
316
- return None
317
- model = joblib.load(file_obj.name)
318
- return model
319
-
320
- load_btn.click(on_load, [load_file], [model_state])
321
-
322
- # Tab 3: Voorspellen
323
- with gr.Tab("3) Voorspellen (speels)"):
324
- gr.Markdown("Kies kenmerken om een kans op *actuele depressie* te laten berekenen (didactisch, niet klinisch).")
325
- with gr.Row():
326
- age = gr.Slider(18, 80, value=35, step=1, label="Leeftijd")
327
- sex = gr.Radio(["man", "vrouw"], value="vrouw", label="Geslacht")
328
- bmi = gr.Slider(16.0, 45.0, value=25.0, step=0.1, label="BMI")
329
- with gr.Row():
330
- sleep_hours = gr.Slider(3.0, 12.0, value=7.0, step=0.1, label="Slaap (uren/dag)")
331
- activity_minutes = gr.Slider(0, 180, value=30, step=5, label="Lichaamsbeweging (min/dag)")
332
- employment = gr.Radio(["werkend", "student", "werkloos", "ziekverlof"], value="werkend", label="Werkstatus")
333
- with gr.Row():
334
- phq9 = gr.Slider(0, 27, value=10, step=1, label="PHQ-9")
335
- gad7 = gr.Slider(0, 21, value=7, step=1, label="GAD-7")
336
- social_support = gr.Slider(1, 5, value=3, step=1, label="Sociale steun (1-5)")
337
- with gr.Row():
338
- prior_depr = gr.Checkbox(False, label="Eerder depressieve episode")
339
- family_history = gr.Checkbox(False, label="Familiaire voorgeschiedenis")
340
- chronic_ill = gr.Checkbox(False, label="Chronische somatische aandoening")
341
- substance_use = gr.Checkbox(False, label="Middelengebruik (actueel)")
342
- stressful_events = gr.Slider(0, 6, value=1, step=1, label="Belastende levensgebeurtenissen (0-6)")
343
-
344
- pred_btn = gr.Button("Bereken kans")
345
- pred_json = gr.JSON(label="Voorspelling")
346
-
347
- def predict_fn(age, sex, bmi, sleep_hours, activity_minutes, employment, phq9, gad7, social_support, prior_depr, family_history, chronic_ill, substance_use, stressful_events, model):
348
- if model is None:
349
- df = generate_synthetic_dataset(1000, 42)
350
- model, *_ = train_model(df)
351
- input_df = pd.DataFrame([{
352
- "age": age,
353
- "sex": sex,
354
- "bmi": bmi,
355
- "sleep_hours": sleep_hours,
356
- "activity_minutes": activity_minutes,
357
- "phq9": phq9,
358
- "gad7": gad7,
359
- "prior_depression": int(prior_depr),
360
- "family_history": int(family_history),
361
- "chronic_illness": int(chronic_ill),
362
- "substance_use": int(substance_use),
363
- "stressful_events": stressful_events,
364
- "social_support": social_support,
365
- "employment_status": employment
366
- }])
367
- try:
368
- prob = float(model.predict_proba(input_df)[0,1])
369
- except Exception:
370
- prob = float(model.predict(input_df)[0])
371
- return {"probability_current_depression": round(prob, 3)}
372
-
373
- pred_inputs = [age, sex, bmi, sleep_hours, activity_minutes, employment, phq9, gad7, social_support,
374
- prior_depr, family_history, chronic_ill, substance_use, stressful_events, model_state]
375
- pred_btn.click(predict_fn, pred_inputs, [pred_json])
376
-
377
- gr.Markdown(
378
- "---\n"
379
- "### Ethische noot\n"
380
- "- Data zijn **geheel synthetisch** en bevatten geen persoonsgegevens.\n"
381
- "- Model is **niet** gevalideerd voor klinisch gebruik.\n"
382
- "- Gebruik dit uitsluitend voor onderwijs/demonstratie."
383
- )
384
-
385
- return demo
386
-
387
- # Heel belangrijk voor Hugging Face Spaces: maak een **globale** `demo` variabele.
388
- demo = build_app()
389
-
390
- if __name__ == "__main__":
391
- demo.launch()
 
1
+ # Hugging Face Space — Live Supervised Training Visualizer (Student WOW Edition)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1 @@
1
- gradio>=4.0.0
2
- pandas
3
- numpy
4
- scikit-learn
5
- matplotlib
6
- shap
7
- scipy
8
- joblib
 
1
+