Marcel0123 commited on
Commit
40f6ba2
·
verified ·
1 Parent(s): a561ed9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +527 -406
app.py CHANGED
@@ -1,10 +1,15 @@
1
- # app.py — GGZ Agressie (synthetisch) — stabiele, uitlegbare versie
2
- # Fixes:
3
- # - Geen isotone overkalibratie; standaard 'sigmoid' (of pure LR)
4
- # - KeywordBoost-kenmerken (geweld-lexicon) voor robuuste signalen
5
- # - FeatureUnion werkt voor TF-IDF én BERT/XLM-R (dense→sparse adapter)
6
- # - Top-woorden uitlegbaarheid blijft bruikbaar met FeatureUnion-slice
7
- # - UI/outputs exact 22; auto-train faalpad matched
 
 
 
 
 
8
 
9
  import os
10
  import typing as _t
@@ -25,21 +30,17 @@ from sklearn.metrics import (
25
  )
26
  from sklearn.feature_extraction.text import TfidfVectorizer
27
  from sklearn.linear_model import LogisticRegression
28
- from sklearn.pipeline import Pipeline, FeatureUnion
29
  from sklearn.decomposition import TruncatedSVD
30
  from sklearn.manifold import TSNE
31
  from sklearn.base import BaseEstimator, TransformerMixin
32
- from sklearn.calibration import calibration_curve, CalibratedClassifierCV
33
- from sklearn.metrics import brier_score_loss
34
 
35
- from scipy import sparse
36
- import re
37
-
38
- # --- MLflow + LIME ---
39
  import mlflow, mlflow.sklearn
40
  from lime.lime_text import LimeTextExplainer
41
 
42
- # --- Optionele DL-deps voor BERT ---
43
  try:
44
  import torch
45
  from transformers import AutoTokenizer, AutoModel
@@ -51,73 +52,80 @@ except Exception:
51
  # ============ Config & Intro ============
52
  DEFAULT_CSV = "synthetische_ggz_agressie_dataset_1000.csv"
53
 
 
54
  INFO_IMAGE = str(Path(__file__).resolve().parent / "imglk;l;kl.png")
55
- if not os.path.exists(INFO_IMAGE):
56
- INFO_IMAGE = None
57
 
 
58
  SLOGAN = "Studieobject Marcel Ooms: Veiligere zorg begint hier: het 30-dagenrisico op agressie onderbouwd en uitlegbaar."
59
 
 
60
  INTRO = """
61
  **Van verslag naar risico: kans op agressie in de komende 30 dagen**
62
- Plak een stukje verslag in het tekstvak en je krijgt een kans terug plus een labeladvies.
63
- Je kunt hertrainen met TF-IDF (woord/char) of BERT/XLM-R, en je ziet scheiding (AUROC/AUPRC),
64
- kalibratie (Brier/reliability) en uitleg (LIME).
 
 
 
 
 
 
 
 
65
  """
66
 
 
67
  WHAT_YOU_SEE = """
68
- **Wat zie je?**
69
- **Status & prestaties** — AUROC/AUPRC; extra: kalibratie (Brier), gains, lift, KS.
70
- **Trainen** kies featurizer en vergelijk.
71
- **Visualisatie** 2D/3D-projecties (label/kans).
72
- **Evaluatie** drempel schuiven; confusion matrix met uitleg.
73
- **Predict** — rapportage + (optionele) context.
74
- **Hertrain** CSV met `rapportage`, optioneel `context`, en `agressie_volgende30d` (0/1).
 
 
 
 
 
 
75
  """
76
 
 
77
  ML_STORY = """
78
  **Van ruwe data naar beslisinformatie**
79
- Tekst kenmerken (TF-IDF of BERT) kans op 30-dagenrisico. Kalibratie (Brier/reliability)
80
- maakt kans→actie betrouwbaar; Gains/Lift/KS helpen bij triage (top x%). LIME toont bijdragen.
81
  """
82
 
83
  FOOTER = """
84
- **Technisch**
85
- Modellen: TF-IDF/char TF-IDF/XLM-R/DutchBERT/ClinicalBERT → Logistic Regression (+ opt. CalibratedClassifierCV(method='sigmoid'), class_weight='balanced')
86
- Visualisatie: SVD(50) → t-SNE(2D/3D)
87
- CSV: `rapportage` (str), optioneel `context` (str), `agressie_volgende30d` (0/1)
88
  """
89
 
 
90
  mlflow.set_experiment("ggz-agressie")
91
 
92
- # ============ Context helpers ============
93
- CTX_START = "[CTX]"
94
- TARGET_START = "[TARGET]"
95
- SEP = "[SEP]"
96
-
97
- def concat_with_context(context: str, current: str) -> str:
98
- context = (context or "").strip()
99
- current = (current or "").strip()
100
- if context:
101
- return f"{CTX_START} {context} {SEP} {TARGET_START} {current}"
102
- return f"{TARGET_START} {current}"
103
-
104
  # ============ Data loading ============
105
  def _resolve_csv_path(uploaded=None):
106
  if uploaded is not None:
107
  return uploaded.name if hasattr(uploaded, "name") else uploaded
108
- for p in [
109
  os.path.join(os.getcwd(), DEFAULT_CSV),
110
  os.path.join(os.path.dirname(__file__), DEFAULT_CSV),
111
  DEFAULT_CSV,
112
- ]:
 
113
  if os.path.exists(p):
114
  return p
115
  repo_id = os.environ.get("SPACE_ID")
116
  if repo_id:
117
  return hf_hub_download(repo_id=repo_id, filename=DEFAULT_CSV)
118
  raise FileNotFoundError(
119
- f"Kon {DEFAULT_CSV} niet vinden. Zet het in de repo-root of upload een CSV met "
120
- "`rapportage` en `agressie_volgende30d` (en optioneel `context`)."
121
  )
122
 
123
  def load_dataset(file_obj=None):
@@ -127,54 +135,22 @@ def load_dataset(file_obj=None):
127
  missing = required - set(df.columns)
128
  if missing:
129
  raise ValueError(f"CSV mist verplichte kolommen: {missing}")
130
- if "context" not in df.columns:
131
- df["context"] = ""
132
  df = df.dropna(subset=["rapportage", "agressie_volgende30d"]).copy()
133
  df["agressie_volgende30d"] = (df["agressie_volgende30d"].astype(int) > 0).astype(int)
134
- df["rapportage_ctx"] = df.apply(lambda r: concat_with_context(r.get("context",""), r["rapportage"]), axis=1)
135
  return df
136
 
137
- # ============ Extra features ============
138
- class KeywordBoost(BaseEstimator, TransformerMixin):
139
- """Kleine lexicon-feature: vangt 'geweld'-signalen. Geeft 2 kolommen: count, binary."""
140
- def __init__(self, lexicon=None):
141
- self.lexicon = lexicon or [
142
- r"\bgewelddadig(e|heid)?\b",
143
- r"\bgeweld\b",
144
- r"\bextreem\s+gewelddadig\b",
145
- r"\b(ontzettend|heel|zeer)\s+boos\b",
146
- r"\bwoedend\b",
147
- r"\bbedreig\w*\b", r"\bbedreigend\b",
148
- r"\b(sla(an|at|gen|g)|slaan)\b", r"\bschop(pen|t|te)?\b",
149
- r"\baanval(len|lig)\b",
150
- r"\bagressie(f|viteit)?\b", r"\bagressief\b", # vaak spelfout met 1 g/s
151
- ]
152
- self._pat = re.compile("|".join(self.lexicon), flags=re.IGNORECASE)
153
-
154
- def fit(self, X, y=None): return self
155
- def transform(self, X):
156
- texts = pd.Series(X).astype(str)
157
- counts = texts.str.count(self._pat).fillna(0).to_numpy().reshape(-1,1)
158
- binary = (counts > 0).astype(int)
159
- # output als sparse voor compatibiliteit met TF-IDF
160
- return sparse.csr_matrix(np.hstack([counts, binary]).astype("float32"))
161
-
162
- class DenseAdapter(BaseEstimator, TransformerMixin):
163
- """Wrapt een dense transformer (bv. HFTextEmbedder) en zet uitkomst om naar CSR-sparse."""
164
- def __init__(self, base):
165
- self.base = base
166
- def fit(self, X, y=None):
167
- self.base.fit(X, y)
168
- return self
169
- def transform(self, X):
170
- arr = self.base.transform(X)
171
- if sparse.issparse(arr):
172
- return arr
173
- return sparse.csr_matrix(arr)
174
-
175
  # ============ HF Text Embedder ============
176
  class HFTextEmbedder(BaseEstimator, TransformerMixin):
177
- def __init__(self, model_name="emilyalsentzer/Bio_ClinicalBERT", max_length=128, batch_size=16, device=None):
 
 
 
 
 
 
 
 
 
178
  self.model_name = model_name
179
  self.max_length = max_length
180
  self.batch_size = batch_size
@@ -182,20 +158,21 @@ class HFTextEmbedder(BaseEstimator, TransformerMixin):
182
  self._tokenizer = None
183
  self._model = None
184
  self._dev = None
 
185
  def _ensure_backend(self):
186
  if torch is None or AutoTokenizer is None or AutoModel is None:
187
  raise RuntimeError("BERT-embeddings vereisen 'torch' en 'transformers'.")
188
  self._dev = self.device or ("cuda" if torch.cuda.is_available() else "cpu")
189
  if self._tokenizer is None:
190
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
191
- self._tokenizer.add_special_tokens({"additional_special_tokens":[CTX_START, TARGET_START, SEP]})
192
  if self._model is None:
193
- self._model = AutoModel.from_pretrained(self.model_name)
194
- self._model.resize_token_embeddings(len(self._tokenizer))
195
- self._model.to(self._dev)
196
  self._model.eval()
 
197
  def fit(self, X, y=None):
198
- self._ensure_backend(); return self
 
 
199
  @torch.no_grad()
200
  def transform(self, X):
201
  self._ensure_backend()
@@ -205,70 +182,50 @@ class HFTextEmbedder(BaseEstimator, TransformerMixin):
205
  embs = []
206
  for i in range(0, len(texts), self.batch_size):
207
  batch = texts[i:i+self.batch_size]
208
- toks = self._tokenizer(batch, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt").to(self._dev)
209
- outs = self._model(**toks).last_hidden_state
210
- mask = toks.attention_mask.unsqueeze(-1)
211
- pooled = (outs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
 
 
 
 
 
212
  embs.append(pooled.cpu().numpy())
213
  return np.vstack(embs)
214
 
215
  # ============ Explainability helpers ============
216
- _lime_explainer = LimeTextExplainer(class_names=["geen agressie", "agressie"])
217
-
218
- def lime_explain_text(pipe, text, num_features=8):
219
- def predict_proba_text(texts):
220
- p1 = pipe.predict_proba(texts)[:, 1]
221
- p0 = 1 - p1
222
- return np.vstack([p0, p1]).T
223
- exp = _lime_explainer.explain_instance(text, predict_proba_text, num_features=num_features)
224
- return exp.as_html()
225
-
226
  def _clf_and_vectorizer_from_pipe(pipe):
227
  vec = pipe.named_steps.get("txt")
228
  clf = pipe.named_steps.get("clf")
229
  return vec, clf
230
 
231
- def _get_lr_from_calibrator(clf):
232
- # CalibratedClassifierCV(method='sigmoid') expose't estimator
233
- return getattr(clf, "estimator", getattr(clf, "base_estimator", clf))
234
-
235
  def tfidf_global_top_words(pipe, k=20):
236
- """
237
- Haal top pro/anti woorden uit TF-IDF deel, ook als er een FeatureUnion is met KeywordBoost.
238
- """
239
  vec, clf = _clf_and_vectorizer_from_pipe(pipe)
240
- # Zoek TF-IDF transformer en aantal features
241
- tfidf = None
242
- n_tfidf = None
243
- if isinstance(vec, FeatureUnion):
244
- # zoek subtransformer 'tfidf'
245
- for name, tr in vec.transformer_list:
246
- if name == "tfidf":
247
- tfidf = tr
248
- break
249
- if tfidf is None or not hasattr(tfidf, "get_feature_names_out"):
250
- return [], []
251
- feature_names = np.array(tfidf.get_feature_names_out())
252
- n_tfidf = len(feature_names)
253
- else:
254
- if not hasattr(vec, "get_feature_names_out"):
255
- return [], []
256
- tfidf = vec
257
- feature_names = np.array(tfidf.get_feature_names_out())
258
- n_tfidf = len(feature_names)
259
-
260
- lr = _get_lr_from_calibrator(clf)
261
- if not hasattr(lr, "coef_"):
262
  return [], []
263
- coefs = lr.coef_[0]
264
- # Bij FeatureUnion komen TF-IDF kolommen eerst (we voegen KeywordBoost erna toe)
265
- tfidf_coefs = coefs[:n_tfidf]
266
- top_pos_idx = np.argsort(tfidf_coefs)[-k:][::-1]
267
- top_neg_idx = np.argsort(tfidf_coefs)[:k]
268
  return list(feature_names[top_pos_idx]), list(feature_names[top_neg_idx])
269
 
270
- # ============ Metrics helpers & plots ============
 
 
 
 
 
 
 
 
 
271
  def _format_confusion_df(cm: np.ndarray) -> pd.DataFrame:
 
 
 
 
272
  if cm.shape != (2, 2):
273
  return pd.DataFrame(cm, index=["True 0", "True 1"], columns=["Pred 0", "Pred 1"])
274
  tn, fp, fn, tp = cm.ravel()
@@ -288,47 +245,77 @@ def _build_report_markdown(rep: dict, thr: float) -> str:
288
  weighted = rep.get("weighted avg", {})
289
  s0 = int(rep.get("0", {}).get("support", 0))
290
  s1 = int(rep.get("1", {}).get("support", 0))
291
- return f"""
292
  ### ℹ️ Uitleg bij het classification report (drempel = {thr:.2f})
293
- 0 = geen agressie, 1 = agressie. Drempel bepaalt label 1 (≥) of 0 (<).
294
- Support(0/1): {s0}/{s1} Accuracy: {acc:.3f} — Macro F1: {macro.get('f1-score', 0):.3f} Weighted F1: {weighted.get('f1-score', 0):.3f}.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  """
 
296
 
297
- def best_threshold_f1(y_true, y_score):
298
- thr = np.linspace(0, 1, 1001)
299
- best_t, best_f1 = 0.5, -1
300
- for t in thr:
301
- y_pred = (y_score >= t).astype(int)
302
- f1 = f1_score(y_true, y_pred, zero_division=0)
303
- if f1 > best_f1:
304
- best_f1, best_t = f1, float(t)
305
- return best_t, best_f1
306
-
307
  def make_confusion_heatmap(y_true, y_score, thr=0.5):
308
  y_pred = (y_score >= thr).astype(int)
309
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
310
  z = cm.astype(int)
311
- fig = go.Figure(data=go.Heatmap(z=z, x=["Pred 0", "Pred 1"], y=["True 0", "True 1"], colorscale="Blues", showscale=True))
 
 
 
 
 
 
 
 
 
312
  tn, fp, fn, tp = z.ravel()
313
- for (r,c,text) in [(0,0,f"TN: {tn}"), (0,1,f"FP: {fp}"), (1,0,f"FN: {fn}"), (1,1,f"TP: {tp}")]:
314
- fig.add_annotation(x=["Pred 0","Pred 1"][c], y=["True 0","True 1"][r], text=text, showarrow=False)
315
- fig.update_layout(title=f"Confusion matrix (drempel = {thr:.2f})", xaxis_title="Voorspelling", yaxis_title="Werkelijkheid",
316
- template="simple_white", margin=dict(l=10,r=10,t=40,b=10))
 
 
 
 
 
 
 
 
 
 
 
 
317
  return fig
318
 
 
319
  def make_roc_fig(y_true, y_score, auroc=None):
320
  fpr, tpr, _ = roc_curve(y_true, y_score)
321
  title = f"ROC-curve (AUROC={auroc:.3f})" if auroc is not None else "ROC-curve"
322
  fig = px.area(x=fpr, y=tpr, title=title, labels={"x":"False Positive Rate", "y":"True Positive Rate"})
323
  fig.add_shape(type="line", x0=0, x1=1, y0=0, y1=1, line=dict(dash="dash"))
324
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white")
325
  return fig
326
 
327
  def make_pr_fig(y_true, y_score, auprc=None):
328
  prec, rec, _ = precision_recall_curve(y_true, y_score)
329
  title = f"Precision–Recall (AUPRC={auprc:.3f})" if auprc is not None else "Precision–Recall"
330
  fig = px.area(x=rec, y=prec, title=title, labels={"x":"Recall", "y":"Precision"})
331
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white")
332
  return fig
333
 
334
  def make_prob_hist(y_true, y_score):
@@ -337,7 +324,7 @@ def make_prob_hist(y_true, y_score):
337
  title="Verdeling voorspelde kansen per werkelijke klasse",
338
  labels={"kans":"Voorspelde kans"})
339
  fig.update_traces(opacity=0.6)
340
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white")
341
  return fig
342
 
343
  def make_threshold_metrics_fig(y_true, y_score, thr_line=0.5):
@@ -345,105 +332,129 @@ def make_threshold_metrics_fig(y_true, y_score, thr_line=0.5):
345
  rows = []
346
  for t in thresholds:
347
  y_pred = (y_score >= t).astype(int)
348
- rows.append({"threshold": t,
349
- "precision": precision_score(y_true, y_pred, zero_division=0),
350
- "recall": recall_score(y_true, y_pred, zero_division=0),
351
- "f1": f1_score(y_true, y_pred, zero_division=0)})
 
 
352
  df = pd.DataFrame(rows)
353
  df_m = df.melt(id_vars="threshold", value_vars=["precision","recall","f1"], var_name="metric", value_name="score")
354
  fig = px.line(df_m, x="threshold", y="score", color="metric",
355
  title="Metrics vs. drempel (precision/recall/F1)",
356
  labels={"threshold":"Drempel", "score":"Score"})
357
  fig.add_vline(x=float(thr_line), line_dash="dash", annotation_text=f"drempel={thr_line:.2f}", annotation_position="top")
358
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white", yaxis=dict(range=[0,1]))
359
  return fig
360
 
361
- # -------- Nieuwe diagnostische plots --------
362
- def make_calibration_fig(y_true, y_score):
363
- prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=10, strategy="quantile")
364
- brier = brier_score_loss(y_true, y_score)
365
  fig = go.Figure()
366
- fig.add_trace(go.Scatter(x=prob_pred, y=prob_true, mode="lines+markers", name="gekalibreerde kansen"))
367
- fig.add_trace(go.Scatter(x=[0,1], y=[0,1], mode="lines", name="perfect", line=dict(dash="dash")))
368
- fig.update_layout(title=f"Kalibratie (reliability) — Brier={brier:.3f}",
369
- xaxis_title="Voorspelde kans", yaxis_title="Werkelijke kans",
370
- template="simple_white", margin=dict(l=10,r=10,t=40,b=10))
371
- return fig, float(brier)
372
-
373
- def make_cumulative_gains_fig(y_true, y_score):
 
 
 
 
374
  df = pd.DataFrame({"y": y_true, "p": y_score}).sort_values("p", ascending=False).reset_index(drop=True)
375
  df["cum_pos"] = df["y"].cumsum()
376
  total_pos = df["y"].sum()
377
- n = len(df)
378
- x = (np.arange(1, n+1) / n)
379
- y = (df["cum_pos"] / max(total_pos, 1e-9))
 
 
 
 
380
  fig = go.Figure()
381
- fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="model"))
382
- fig.add_trace(go.Scatter(x=x, y=x, mode="lines", name="random", line=dict(dash="dash")))
383
- fig.update_layout(title="Cumulative Gains",
384
- xaxis_title="Fractie van populatie (gesorteerd op kans)",
385
- yaxis_title="Fractie gevangen positieven",
386
- template="simple_white", margin=dict(l=10,r=10,t=40,b=10))
 
 
 
 
387
  return fig
388
 
389
  def make_lift_fig(y_true, y_score):
390
- df = pd.DataFrame({"y": y_true, "p": y_score}).sort_values("p", ascending=False).reset_index(drop=True)
391
- total_pos = df["y"].sum()
392
- base_rate = total_pos / max(len(df), 1e-9)
393
- x, lift = [], []
394
- for k in range(1, len(df)+1):
395
- frac = k / len(df)
396
- captured = df.loc[:k-1, "y"].sum() / max(k, 1e-9)
397
- lift.append(captured / max(base_rate, 1e-9))
398
- x.append(frac)
399
  fig = go.Figure()
400
- fig.add_trace(go.Scatter(x=x, y=lift, mode="lines", name="lift"))
401
- fig.add_hline(y=1.0, line_dash="dash")
402
- fig.update_layout(title="Lift-curve",
403
- xaxis_title="Fractie van populatie (gesorteerd op kans)",
404
- yaxis_title="Lift (t.o.v. random=1)",
405
- template="simple_white", margin=dict(l=10,r=10,t=40,b=10))
 
 
 
406
  return fig
407
 
408
  def make_ks_fig(y_true, y_score):
409
- df = pd.DataFrame({"y": y_true, "p": y_score}).sort_values("p")
410
- pos = df[df["y"]==1]["p"].values
411
- neg = df[df["y"]==0]["p"].values
412
- grid = np.linspace(0, 1, 201)
413
- cdf_pos = np.searchsorted(np.sort(pos), grid, side="right") / max(len(pos), 1e-9)
414
- cdf_neg = np.searchsorted(np.sort(neg), grid, side="right") / max(len(neg), 1e-9)
415
- diff = np.abs(cdf_pos - cdf_neg)
416
- ks_idx = int(np.argmax(diff))
417
- ks_x, ks_y = float(grid[ks_idx]), float(diff[ks_idx])
 
418
  fig = go.Figure()
419
- fig.add_trace(go.Scatter(x=grid, y=cdf_pos, mode="lines", name="CDF positief"))
420
- fig.add_trace(go.Scatter(x=grid, y=cdf_neg, mode="lines", name="CDF negatief"))
421
- fig.add_vline(x=ks_x, line_dash="dot", annotation_text=f"KS={ks_y:.3f} @ {ks_x:.2f}")
422
- fig.update_layout(title="KS-curve (Kolmogorov–Smirnov)",
423
- xaxis_title="Score", yaxis_title="Cumulatieve fractie",
424
- template="simple_white", margin=dict(l=10,r=10,t=40,b=10))
425
- return fig, ks_y, ks_x
 
 
 
 
 
 
 
426
 
427
- def metrics_table(y_true, y_score, thr):
428
- y_pred = (y_score >= thr).astype(int)
429
- rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
430
- rep_df = pd.DataFrame(rep).T
431
- rep_df_disp = rep_df.copy()
432
- for col in ["precision", "recall", "f1-score"]:
433
- if col in rep_df_disp:
434
- rep_df_disp[col] = (rep_df_disp[col] * 100).round(1).map(lambda v: f"{v:.1f}%" if pd.notnull(v) else "")
435
- if "support" in rep_df_disp:
436
- rep_df_disp["support"] = rep_df_disp["support"].map(lambda v: f"{int(v)}" if pd.notnull(v) else "")
437
- if "accuracy" in rep:
438
- acc_pct = f"{rep['accuracy'] * 100:.1f}%"
439
- rep_df_disp["accuracy_%"] = ""
440
- if "accuracy" in rep_df_disp.index:
441
- rep_df_disp.loc["accuracy", "accuracy_%"] = acc_pct
442
- rep_df_disp = rep_df_disp.fillna("")
443
- cm = confusion_matrix(y_true, y_pred)
444
- cm_df = _format_confusion_df(cm)
445
- rep_md = _build_report_markdown(rep, thr)
446
- return rep_df_disp, cm_df, rep_md
 
 
 
 
 
 
 
447
 
448
  # ============ Model & Viz ============
449
  def build_and_train(
@@ -456,104 +467,56 @@ def build_and_train(
456
  bert_maxlen=128,
457
  bert_batch=16
458
  ):
459
- X = df["rapportage_ctx"].astype(str).values
460
  y = df["agressie_volgende30d"].values
461
  X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
462
  X, y, np.arange(len(X)),
463
  test_size=test_size, random_state=random_state, stratify=y
464
  )
465
 
466
- kb = KeywordBoost()
467
-
468
- def make_lr(sigmoid=True):
469
- base = LogisticRegression(max_iter=3000, class_weight="balanced")
470
- if sigmoid:
471
- # stabielere kalibratie dan isotonic op kleine sets
472
- return CalibratedClassifierCV(estimator=base, method="sigmoid", cv=3)
473
- return base
474
-
475
  if featurizer == "TF-IDF":
476
- tfidf = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
477
- feats = FeatureUnion([("tfidf", tfidf), ("kb", kb)])
478
- clf = make_lr(sigmoid=True)
479
- pipe = Pipeline([("txt", feats), ("clf", clf)])
480
- pipe.fit(X_train, y_train)
481
- y_score = pipe.predict_proba(X_test)[:, 1]
482
- # alleen TF-IDF deel voor SVD/TSNE (eerste blok kolommen)
483
- n_tfidf = len(tfidf.get_feature_names_out())
484
- txt_all = feats.transform(X)
485
- X_tfidf_only = txt_all[:, :n_tfidf]
486
-
487
- elif featurizer == "TF-IDF (char 3–5)":
488
- tfidf = TfidfVectorizer(analyzer="char_wb", ngram_range=(3,5), min_df=2, max_features=max_features)
489
- feats = FeatureUnion([("tfidf", tfidf), ("kb", kb)])
490
- clf = make_lr(sigmoid=True)
491
- pipe = Pipeline([("txt", feats), ("clf", clf)])
492
  pipe.fit(X_train, y_train)
493
  y_score = pipe.predict_proba(X_test)[:, 1]
494
- n_tfidf = len(tfidf.get_feature_names_out())
495
- txt_all = feats.transform(X)
496
- X_tfidf_only = txt_all[:, :n_tfidf]
497
-
498
  elif featurizer == "ClinicalBERT":
499
  emb = HFTextEmbedder(model_name="emilyalsentzer/Bio_ClinicalBERT",
500
  max_length=bert_maxlen, batch_size=bert_batch)
501
- emb_sparse = DenseAdapter(emb)
502
- feats = FeatureUnion([("bert", emb_sparse), ("kb", kb)])
503
- clf = make_lr(sigmoid=True)
504
- pipe = Pipeline([("txt", feats), ("clf", clf)])
505
  pipe.fit(X_train, y_train)
506
  y_score = pipe.predict_proba(X_test)[:, 1]
507
- # voor visualisatie gebruiken we de BERT-embeddings (eerste blok kolommen)
508
- emb_sample = emb.transform(["x"])
509
- n_emb = emb_sample.shape[1]
510
- X_all = feats.transform(X)
511
- X_tfidf_only = X_all[:, :n_emb] # hier: "embeddings-only" voor SVD
512
-
513
  elif featurizer == "DutchBERT":
514
  emb = HFTextEmbedder(model_name="wietsedv/bert-base-dutch-cased",
515
  max_length=bert_maxlen, batch_size=bert_batch)
516
- emb_sparse = DenseAdapter(emb)
517
- feats = FeatureUnion([("bert", emb_sparse), ("kb", kb)])
518
- clf = make_lr(sigmoid=True)
519
- pipe = Pipeline([("txt", feats), ("clf", clf)])
520
  pipe.fit(X_train, y_train)
521
  y_score = pipe.predict_proba(X_test)[:, 1]
522
- emb_sample = emb.transform(["x"])
523
- n_emb = emb_sample.shape[1]
524
- X_all = feats.transform(X)
525
- X_tfidf_only = X_all[:, :n_emb]
526
-
527
- elif featurizer == "XLM-RoBERTa":
528
- emb = HFTextEmbedder(model_name="xlm-roberta-base",
529
- max_length=bert_maxlen, batch_size=bert_batch)
530
- emb_sparse = DenseAdapter(emb)
531
- feats = FeatureUnion([("bert", emb_sparse), ("kb", kb)])
532
- clf = make_lr(sigmoid=True)
533
- pipe = Pipeline([("txt", feats), ("clf", clf)])
534
- pipe.fit(X_train, y_train)
535
- y_score = pipe.predict_proba(X_test)[:, 1]
536
- emb_sample = emb.transform(["x"])
537
- n_emb = emb_sample.shape[1]
538
- X_all = feats.transform(X)
539
- X_tfidf_only = X_all[:, :n_emb]
540
-
541
  else:
542
- raise ValueError("Onbekende featurizer. Kies 'TF-IDF', 'TF-IDF (char 3–5)', 'ClinicalBERT', 'DutchBERT' of 'XLM-RoBERTa'.")
543
 
544
  auroc = float(roc_auc_score(y_test, y_score))
545
  auprc = float(average_precision_score(y_test, y_score))
546
 
547
- # 2D/3D embedding: SVD (50) -> t-SNE (2D/3D) op gekozen basisfeatures
548
  svd = TruncatedSVD(n_components=50, random_state=random_state)
549
- X50 = svd.fit_transform(X_tfidf_only)
550
 
551
- tsne2 = TSNE(n_components=2, random_state=random_state, perplexity=30, learning_rate="auto", init="pca")
 
 
552
  X2 = tsne2.fit_transform(X50)
553
  x2 = (X2[:, 0] - np.min(X2[:, 0])) / (np.ptp(X2[:, 0]) + 1e-9)
554
  y2 = (X2[:, 1] - np.min(X2[:, 1])) / (np.ptp(X2[:, 1]) + 1e-9)
555
 
556
- tsne3 = TSNE(n_components=3, random_state=random_state, perplexity=30, learning_rate="auto", init="pca")
 
 
557
  X3 = tsne3.fit_transform(X50)
558
  x3 = (X3[:, 0] - np.min(X3[:, 0])) / (np.ptp(X3[:, 0]) + 1e-9)
559
  y3 = (X3[:, 1] - np.min(X3[:, 1])) / (np.ptp(X3[:, 1]) + 1e-9)
@@ -565,7 +528,7 @@ def build_and_train(
565
  "x3": x3, "y3": y3, "z3": z3,
566
  "label": df["agressie_volgende30d"].values,
567
  "kans": proba_all,
568
- "rapportage": df["rapportage"].astype(str).str.slice(0, 180) + "..."
569
  })
570
  for col in ["PHQ9_baseline","GAD7_baseline","stress_niveau_1_5","slaap_uren","sociale_steun_0_10","zorgsetting"]:
571
  if col in df.columns:
@@ -575,55 +538,180 @@ def build_and_train(
575
  test_mask[idx_test] = True
576
  plot_df["split"] = np.where(test_mask, "test", "train")
577
 
578
- return pipe, (X_test, y_test, y_score), plot_df, auroc, auprc, df
579
 
580
  def make_scatter(plot_df, color_mode="label", dim="2D"):
 
 
 
 
 
581
  hover_cols = ["rapportage", "kans", "split"]
582
  if color_mode == "label":
583
  color = plot_df["label"].map({0: "geen agressie", 1: "agressie"})
584
  title_2d = "2D projectie (t-SNE) — kleur = werkelijk label"
585
  title_3d = "3D projectie (t-SNE) — kleur = werkelijk label"
586
  if dim == "2D":
587
- fig = px.scatter(plot_df, x="x", y="y", color=color, hover_data=hover_cols, title=title_2d, opacity=0.85)
 
 
 
588
  else:
589
- fig = px.scatter_3d(plot_df, x="x3", y="y3", z="z3", color=color, hover_data=hover_cols, title=title_3d, opacity=0.9)
590
- else:
 
 
 
591
  title_2d = "2D projectie (t-SNE) — kleur = voorspelde kans"
592
  title_3d = "3D projectie (t-SNE) — kleur = voorspelde kans"
593
  if dim == "2D":
594
- fig = px.scatter(plot_df, x="x", y="y", color="kans", hover_data=hover_cols, title=title_2d, color_continuous_scale="Turbo", opacity=0.9)
 
 
 
 
595
  else:
596
- fig = px.scatter_3d(plot_df, x="x3", y="y3", z="z3", color="kans", hover_data=hover_cols, title=title_3d, color_continuous_scale="Turbo", opacity=0.9)
 
 
 
 
 
597
  if dim == "2D":
598
  fig.update_traces(marker=dict(size=8, line=dict(width=0)))
599
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white", xaxis_title="x (t-SNE)", yaxis_title="y (t-SNE)")
 
 
 
 
 
600
  else:
601
  fig.update_traces(marker=dict(size=4))
602
- fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white",
603
- scene=dict(xaxis_title="x (t-SNE)", yaxis_title="y (t-SNE)", zaxis_title="z (t-SNE)"))
 
 
 
 
 
 
 
604
  return fig
605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
  # ============ State & Train ============
607
  GLOBAL = {
608
  "pipe": None, "plot_df": None, "eval": None,
609
  "auroc": None, "auprc": None,
610
  "featurizer": "TF-IDF",
611
- "df": None,
612
- "thr_suggested": 0.5
613
  }
614
 
615
  def do_train(file_obj=None, test_size=0.2, seed=42,
616
  featurizer="TF-IDF", max_features=4000, ngram_max=2,
617
  bert_maxlen=128, bert_batch=16):
618
  df = load_dataset(file_obj)
619
- pipe, eval_pack, plot_df, auroc, auprc, full_df = build_and_train(
620
  df, test_size, seed, featurizer, max_features, ngram_max, bert_maxlen, bert_batch
621
  )
622
 
 
623
  with mlflow.start_run(run_name=f"{featurizer}"):
624
  mlflow.log_param("featurizer", featurizer)
625
  mlflow.log_param("test_size", test_size)
626
- if featurizer.startswith("TF-IDF"):
627
  mlflow.log_param("tfidf_max_features", max_features)
628
  mlflow.log_param("tfidf_ngram_max", ngram_max)
629
  else:
@@ -634,86 +722,103 @@ def do_train(file_obj=None, test_size=0.2, seed=42,
634
  mlflow.sklearn.log_model(pipe, artifact_path="model")
635
 
636
  GLOBAL.update(pipe=pipe, plot_df=plot_df, eval=eval_pack,
637
- auroc=auroc, auprc=auprc, featurizer=featurizer, df=full_df)
638
 
 
639
  rep_df, cm_df, rep_md = metrics_table(eval_pack[1], eval_pack[2], thr=0.5)
640
 
 
641
  roc_fig = make_roc_fig(eval_pack[1], eval_pack[2], auroc)
642
  pr_fig = make_pr_fig(eval_pack[1], eval_pack[2], auprc)
643
  hist_fig = make_prob_hist(eval_pack[1], eval_pack[2])
644
  thr_fig = make_threshold_metrics_fig(eval_pack[1], eval_pack[2], thr_line=0.5)
645
 
 
646
  fig_label = make_scatter(plot_df, color_mode="label", dim="2D")
647
  fig_prob = make_scatter(plot_df, color_mode="kans", dim="2D")
648
 
649
- cm_plot = make_confusion_heatmap(eval_pack[1], eval_pack[2], thr=0.5)
650
-
651
- preview_df = full_df.head(10)[["rapportage","context","agressie_volgende30d"]]
652
-
653
- t_star, _ = best_threshold_f1(eval_pack[1], eval_pack[2])
654
- GLOBAL["thr_suggested"] = t_star
 
655
 
656
- cal_fig, brier = make_calibration_fig(eval_pack[1], eval_pack[2])
657
- gains_fig = make_cumulative_gains_fig(eval_pack[1], eval_pack[2])
658
- lift_fig = make_lift_fig(eval_pack[1], eval_pack[2])
659
- ks_fig, ks_val, ks_at = make_ks_fig(eval_pack[1], eval_pack[2])
660
- cls_bar = make_class_balance_bar(full_df)
661
- len_hist = make_text_length_hist(full_df)
662
 
663
- thr_md = f"**Aanbevolen drempel (F1):** `{t_star:.2f}` · **Brier:** `{brier:.3f}` · **KS:** `{ks_val:.3f}` (score ≈ `{ks_at:.2f}`)"
664
- status_msg = f"✅ Model getraind met {featurizer}. AUROC: {auroc:.3f} | AUPRC: {auprc:.3f} | Suggested thr (F1): {t_star:.2f}"
665
 
 
666
  return (
667
  status_msg, auroc, auprc,
668
- preview_df,
669
  fig_label, fig_prob,
670
  rep_df, cm_df, cm_plot, rep_md,
671
  roc_fig, pr_fig, hist_fig, thr_fig,
672
- cal_fig, brier, gains_fig, lift_fig, ks_fig, thr_md, cls_bar, len_hist
673
  )
674
 
675
- def predict_one(text, ctx=None):
676
  if GLOBAL["pipe"] is None:
677
  return "Nog geen model getraind.", None
678
  if not text or text.strip() == "":
679
  return "Voer een rapportage in.", None
680
- merged = concat_with_context(ctx or "", text)
681
- proba = float(GLOBAL["pipe"].predict_proba([merged])[:, 1][0])
682
- thr_use = float(GLOBAL.get("thr_suggested", 0.5))
683
- label = int(proba >= thr_use)
684
  md = (
685
  f"Kans op agressie (30d): {proba:.3f} — "
686
- f"voorspelde klasse: {label} (drempel {thr_use:.2f})\n"
687
  f"Featurizer: {GLOBAL.get('featurizer','?')}"
688
  )
689
  return md, proba
690
 
691
  # ============ UI ============
692
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as demo:
 
693
  gr.Markdown(f"# {SLOGAN}")
694
 
 
695
  gr.HTML("""
696
  <style>
 
697
  #train-btn, #retrain-btn, #predict-btn {
698
  background: linear-gradient(90deg, #ef4444 0%, #f97316 100%);
699
- color: white !important; font-weight: 700; border: none !important;
 
 
 
 
 
 
 
 
 
 
700
  }
701
- #train-btn:hover, #retrain-btn:hover, #predict-btn:hover { filter: brightness(0.95); }
702
- #data-preview { max-height: 320px; overflow: auto; }
703
- #data-preview table { width: 100%; }
 
704
  #viz-img { margin-top: 0 !important; padding-top: 0 !important; }
705
  #viz-img img { display: block; margin-top: 0 !important; }
706
  </style>
707
  """)
708
 
 
709
  with gr.Row():
710
- with gr.Column(scale=1): gr.Markdown(INTRO)
711
- with gr.Column(scale=1): gr.Markdown(WHAT_YOU_SEE)
 
 
712
 
 
713
  gr.Markdown("## 🛠️ Handmatig trainen (zonder CSV upload)")
714
  with gr.Row():
715
  featur_quick = gr.Radio(
716
- choices=["TF-IDF", "TF-IDF (char 3–5)", "ClinicalBERT", "DutchBERT", "XLM-RoBERTa"],
717
  value="TF-IDF",
718
  label="Kies featurizer"
719
  )
@@ -729,11 +834,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
729
  with gr.Row():
730
  auroc_box = gr.Number(label="AUROC", precision=3)
731
  auprc_box = gr.Number(label="AUPRC", precision=3)
732
- thr_badge = gr.Markdown()
733
 
 
734
  with gr.Row():
735
  with gr.Column(scale=3):
736
  gr.Markdown("### 🔍 Visualisatie")
 
737
  proj_dim = gr.Radio(choices=["2D", "3D"], value="2D", label="Projectiedimensie (geldt voor beide projecties)")
738
  with gr.Column():
739
  fig_out_label = gr.Plot(label="Projectie — kleur = werkelijk label")
@@ -742,7 +848,11 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
742
  gr.Markdown(ML_STORY)
743
  with gr.Column(scale=2):
744
  gr.Markdown("### 📄 Datavoorbeeld")
745
- data_preview_mode = gr.Radio(choices=["Eerste 10 rijen", "Gehele dataset (scrollbaar)"], value="Eerste 10 rijen", label="Weergave")
 
 
 
 
746
  data_preview = gr.Dataframe(label="Dataset", interactive=False, elem_id="data-preview")
747
 
748
  gr.Markdown("### ⚙️ Evaluatie (tabellen & drempel)")
@@ -752,6 +862,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
752
  cm_plot = gr.Plot(label="Confusion matrix (heatmap)")
753
  rep_md = gr.Markdown(label="Uitleg classification report")
754
 
 
755
  with gr.Row():
756
  with gr.Column(scale=3):
757
  with gr.Tabs():
@@ -763,11 +874,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
763
  roc_plot = gr.Plot(label="ROC-curve")
764
  with gr.TabItem("Precision–Recall"):
765
  pr_plot = gr.Plot(label="PR-curve")
766
- with gr.Tabs():
767
  with gr.TabItem("Kalibratie"):
768
- calib_plot = gr.Plot(label="Reliability diagram")
769
- brier_box = gr.Number(label="Brier-score", precision=3)
770
- gr.Markdown("Lagere Brier is beter; lijn dicht bij de diagonaal = goede kalibratie.")
771
  with gr.TabItem("Cumulative Gains"):
772
  gains_plot = gr.Plot(label="Cumulative Gains")
773
  with gr.TabItem("Lift"):
@@ -775,29 +884,30 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
775
  with gr.TabItem("KS-curve"):
776
  ks_plot = gr.Plot(label="KS-curve")
777
  with gr.TabItem("Dataset-profiel"):
778
- cls_bar_plot = gr.Plot(label="Klassenbalans")
779
- len_hist_plot = gr.Plot(label="Tekstlengteverdeling")
780
  with gr.Column(scale=2):
781
  gr.Markdown("### 🗣️ Predict (vrije tekst)")
782
  with gr.Row():
783
- txt = gr.Textbox(lines=12, label="Rapportage (NL)",
784
- placeholder="Bijv.: Patiënt is extreem geagiteerd, weigert medicatie...")
785
- with gr.Row():
786
- ctx_txt = gr.Textbox(lines=6, label="Vorige context (optioneel)",
787
- placeholder="Bijv.: eerdere observaties of citaten uit het dossier...")
788
  btn = gr.Button("Voorspel", elem_id="predict-btn")
789
  md_out = gr.Markdown()
790
  proba_out = gr.Number(label="Kans", precision=3)
791
 
 
792
  gr.Markdown("## 🔁 Hertrain met eigen CSV")
793
- gr.Markdown("Upload een CSV met `rapportage` (tekst), optioneel `context`, en `agressie_volgende30d` (0/1).")
794
- csv_in = gr.File(label="Upload CSV (kolommen: rapportage, [context], agressie_volgende30d)")
 
 
 
795
  with gr.Row():
796
  test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test set grootte")
797
  seed = gr.Slider(1, 999, value=42, step=1, label="Random seed")
798
  with gr.Row():
799
- featur = gr.Radio(choices=["TF-IDF", "TF-IDF (char 3–5)", "ClinicalBERT", "DutchBERT", "XLM-RoBERTa"],
800
- value="TF-IDF", label="Tekst-featurizer")
801
  with gr.Row(visible=True) as tfidf_row:
802
  max_features = gr.Slider(1000, 12000, value=4000, step=1000, label="TF-IDF max_features")
803
  ngram_max = gr.Radio(choices=[1, 2], value=2, label="n-gram max")
@@ -806,33 +916,35 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
806
  bert_batch = gr.Slider(4, 64, value=16, step=4, label="BERT batch_size")
807
  retrain_btn = gr.Button("Train opnieuw (met upload)", elem_id="retrain-btn")
808
 
 
809
  with gr.Row():
810
  with gr.Column(scale=2, min_width=0):
811
  gr.Markdown(
812
  "### ℹ️ Over de evaluatieplots\n\n"
813
- "- Metrics vs. drempel precision, recall, F1 over de drempel.\n"
814
- "- Kansverdelingvoorspelde kansen per werkelijke klasse.\n"
815
- "- ROC & PR scheidingskracht, nuttig bij onbalans.\n"
816
- "- Kalibratiebetrouwbaarheid van kansen (Brier lager is beter).\n"
817
- "- Gains/Lift/KSinzicht voor triage (top x% casussen)."
 
818
  )
819
 
820
- # Toggles
821
  def _toggle_quick(choice):
822
  return (
823
- gr.update(visible=(choice.startswith("TF-IDF"))),
824
- gr.update(visible=(choice in ("ClinicalBERT", "DutchBERT", "XLM-RoBERTa")))
825
  )
826
  featur_quick.change(_toggle_quick, inputs=featur_quick, outputs=[tfidf_quick_row, bert_quick_row])
827
 
828
  def _toggle_rows(choice):
829
  return (
830
- gr.update(visible=(choice.startswith("TF-IDF"))),
831
- gr.update(visible=(choice in ("ClinicalBERT", "DutchBERT", "XLM-RoBERTa")))
832
  )
833
  featur.change(_toggle_rows, inputs=featur, outputs=[tfidf_row, bert_row])
834
 
835
- # Interactie
836
  def _update_eval(t):
837
  if GLOBAL["eval"] is None:
838
  return None, None, None, None, None
@@ -841,19 +953,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
841
  thr_fig_new = make_threshold_metrics_fig(y_true, y_score, thr_line=float(t))
842
  cm_plot_new = make_confusion_heatmap(y_true, y_score, thr=float(t))
843
  return rep, cm, cm_plot_new, rep_md_text, thr_fig_new
 
844
  thr.release(_update_eval, inputs=thr, outputs=[rep_df, cm_df, cm_plot, rep_md, thr_plot])
845
 
 
846
  def _refresh_preview(mode):
847
  df = GLOBAL.get("df")
848
  if df is None or not isinstance(df, pd.DataFrame):
849
  return None
850
- cols = [c for c in ["rapportage","context","agressie_volgende30d"] if c in df.columns]
851
- show = df[cols]
852
- return show.head(10) if mode.startswith("Eerste") else show
853
  data_preview_mode.change(_refresh_preview, inputs=data_preview_mode, outputs=data_preview)
854
 
855
- btn.click(predict_one, inputs=[txt, ctx_txt], outputs=[md_out, proba_out])
856
 
 
857
  def _train_quick(featur, max_features_q, ngram_max_q, bert_maxlen_q, bert_batch_q):
858
  return do_train(None, 0.2, 42, featur, int(max_features_q), int(ngram_max_q),
859
  int(bert_maxlen_q), int(bert_batch_q))
@@ -864,9 +979,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
864
  fig_out_label, fig_out_prob,
865
  rep_df, cm_df, cm_plot, rep_md,
866
  roc_plot, pr_plot, hist_plot, thr_plot,
867
- calib_plot, brier_box, gains_plot, lift_plot, ks_plot, thr_badge, cls_bar_plot, len_hist_plot]
868
  )
869
 
 
870
  def _retrain(csv_in, test_size, seed, featur, max_features, ngram_max, bert_maxlen, bert_batch):
871
  return do_train(csv_in, test_size, int(seed), featur, int(max_features), int(ngram_max),
872
  int(bert_maxlen), int(bert_batch))
@@ -877,57 +993,62 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as
877
  fig_out_label, fig_out_prob,
878
  rep_df, cm_df, cm_plot, rep_md,
879
  roc_plot, pr_plot, hist_plot, thr_plot,
880
- calib_plot, brier_box, gains_plot, lift_plot, ks_plot, thr_badge, cls_bar_plot, len_hist_plot]
881
  )
882
 
 
883
  def _update_projection(dim):
884
  pdf = GLOBAL.get("plot_df")
885
  if pdf is None:
886
  return None, None
887
- return make_scatter(pdf, color_mode="label", dim=dim), make_scatter(pdf, color_mode="kans", dim=dim)
 
 
 
888
  proj_dim.change(_update_projection, inputs=proj_dim, outputs=[fig_out_label, fig_out_prob])
889
 
890
- # Auto-train: exact 22 outputs bij error
891
  def _auto_train():
892
  try:
893
  return do_train(None, 0.2, 42, "TF-IDF", 4000, 2, 128, 16)
894
  except Exception as e:
895
- return (
896
- f"❌ Fout bij laden/trainen: `{e}`",
897
- None, None, None, None, None,
898
- None, None, None, None,
899
- None, None, None, None,
900
- None, None, None, None, None, None, None, None
901
- )
902
 
903
  demo.load(_auto_train, inputs=None,
904
  outputs=[status, auroc_box, auprc_box, data_preview,
905
  fig_out_label, fig_out_prob,
906
  rep_df, cm_df, cm_plot, rep_md,
907
  roc_plot, pr_plot, hist_plot, thr_plot,
908
- calib_plot, brier_box, gains_plot, lift_plot, ks_plot, thr_badge, cls_bar_plot, len_hist_plot])
909
 
910
- # Explainability
911
  with gr.Accordion("🪄 Uitleg (Explainability)", open=False):
912
- gr.Markdown("LIME legt afzonderlijke voorspellingen uit. 'Top woorden' werken voor TF-IDF-varianten.")
913
  with gr.Row():
914
- txt_explain = gr.Textbox(lines=4, label="Tekst om uit te leggen", placeholder="Plak hier een rapportage voor uitleg")
 
915
  btn_explain = gr.Button("Genereer uitleg")
916
  lime_html = gr.HTML(label="LIME uitleg (per voorbeeld)")
 
 
917
  top_pos_df = gr.Dataframe(headers=["Top pro-agressie woorden"], row_count=5)
918
  top_neg_df = gr.Dataframe(headers=["Top anti-agressie woorden"], row_count=5)
919
 
920
- def _do_explain(text, ctx=""):
921
  if GLOBAL["pipe"] is None:
922
  return "Train eerst een model.", None, None
923
- merged = concat_with_context(ctx or "", text or "")
924
- html = lime_explain_text(GLOBAL["pipe"], merged, num_features=8)
925
  pos, neg = tfidf_global_top_words(GLOBAL["pipe"], k=15)
926
  pos = [[w] for w in pos] if pos else None
927
  neg = [[w] for w in neg] if neg else None
928
  return html, pos, neg
929
 
930
- btn_explain.click(_do_explain, inputs=[txt_explain, ctx_txt], outputs=[lime_html, top_pos_df, top_neg_df])
931
 
932
  gr.Markdown(FOOTER)
933
 
 
1
+ # app.py — GGZ Agressie (synthetisch) — One-page UI
2
+ # - Auto-train bij openen met TF-IDF
3
+ # - Handmatig trainen zonder CSV upload: kies TF-IDF / ClinicalBERT / DutchBERT
4
+ # - (Optioneel) Hertrain met eigen CSV (nu altijd zichtbaar)
5
+ # - MLflow experiment tracking + LIME explainability tab
6
+ # - Confusion matrix met betekenislabels + Markdown-uitleg bij classification report
7
+ # - Extra: Confusion-matrix heatmap-plot onder de tabel
8
+ # - Evaluatieplots links (met datavoorbeeld erboven); Predict rechts
9
+ # - Visualisatie: 2D/3D-projecties (label & kans) + afbeelding direct onder kans-plot
10
+ # - Classification report met eenheden (% en aantallen)
11
+ # - Datavoorbeeld: eerste 10 rijen of hele dataset (scrollbaar via CSS)
12
+ # - Extra tabs: Kalibratie, Cumulative Gains, Lift, KS-curve, Dataset-profiel
13
 
14
  import os
15
  import typing as _t
 
30
  )
31
  from sklearn.feature_extraction.text import TfidfVectorizer
32
  from sklearn.linear_model import LogisticRegression
33
+ from sklearn.pipeline import Pipeline
34
  from sklearn.decomposition import TruncatedSVD
35
  from sklearn.manifold import TSNE
36
  from sklearn.base import BaseEstimator, TransformerMixin
37
+ from sklearn.calibration import calibration_curve
 
38
 
39
+ # --- NEW: experiment tracking + explainability ---
 
 
 
40
  import mlflow, mlflow.sklearn
41
  from lime.lime_text import LimeTextExplainer
42
 
43
+ # --- Optional DL deps (voor BERT) ---
44
  try:
45
  import torch
46
  from transformers import AutoTokenizer, AutoModel
 
52
  # ============ Config & Intro ============
53
  DEFAULT_CSV = "synthetische_ggz_agressie_dataset_1000.csv"
54
 
55
+ # Afbeelding die direct onder de 2D/3D-kans-plot verschijnt (bestand naast app.py)
56
  INFO_IMAGE = str(Path(__file__).resolve().parent / "imglk;l;kl.png")
 
 
57
 
58
+ # Volledige-breedte koptekst
59
  SLOGAN = "Studieobject Marcel Ooms: Veiligere zorg begint hier: het 30-dagenrisico op agressie onderbouwd en uitlegbaar."
60
 
61
+ # Gebruikersvriendelijke intro: alleen kop vet
62
  INTRO = """
63
  **Van verslag naar risico: kans op agressie in de komende 30 dagen**
64
+ Wat doet deze pagina voor jou?
65
+ Deze demo helpt om uit vrije-tekstrapportages snel een inschatting van het risico op agressief gedrag in de komende 30 dagen te krijgen. Plak een stukje verslag in het tekstvak en je krijgt een kans (probabiliteit) terug, plus een voorgesteld label op basis van een drempel die je zelf kunt verschuiven. Zo kun je risico vroegtijdig signaleren en bepalen welke acties passen: extra observatie, bijsturing in het behandelplan of overleg in het team.
66
+ Hoe werkt het in grote lijnen (zonder technisch gedoe):
67
+ - Bij het openen staat er al een startmodel klaar.
68
+ - Je kunt hertrainen met drie aanpakken: TF-IDF, ClinicalBERT of DutchBERT.
69
+ - De grafieken laten zien hoe nauwkeurig het model is en hoe de drempel precision en recall beïnvloedt.
70
+ - Met LIME zie je welke woorden in de tekst het meest hebben bijgedragen aan de inschatting; dat maakt de uitkomst uitlegbaar.
71
+ Belangrijk om te weten:
72
+ - Dit is een demonstratie op synthetische data. De uitkomst is een waarschijnlijkheid, geen zekerheid.
73
+ - Het systeem voorspelt niet of iemand agressief wordt, maar schat de kans binnen 30 dagen in op basis van tekstsignalen.
74
+ - Gebruik de uitkomst altijd naast klinische expertise en bestaande veiligheidsprotocollen.
75
  """
76
 
77
+ # Herschreven rechter tekstblok: alleen kopjes vet
78
  WHAT_YOU_SEE = """
79
+ **Wat zie je op deze pagina?**
80
+ **Status & prestaties**
81
+ Hier zie je hoe goed het model onderscheid maakt. AUROC en AUPRC tonen in één oogopslag hoe betrouwbaar de inschatting is; hoger is beter.
82
+ **Handmatig trainen (zonder upload)**
83
+ Kies een featurizer (TF-IDF, ClinicalBERT of DutchBERT) en klik op Train algoritme. Je kunt opties aanpassen en direct vergelijken wat in jouw setting het beste werkt.
84
+ **Visualisatie**
85
+ De interactieve 2D/3D-plot laat elke tekst als een punt zien. Kleur en positie helpen om patronen te herkennen; met de muis zie je extra uitleg per punt. Er zijn twee weergaven: kleur naar werkelijk label en kleur naar voorspelde kans.
86
+ **Evaluatie**
87
+ Met de drempel-schuif bepaal je wanneer “hoog risico” wordt toegekend. Je ziet wat dat betekent voor precision, recall en F1. Zo kun je kiezen tussen minder valse alarmen of meer signalen oppikken.
88
+ **Predict**
89
+ Plak een rapportage in het tekstvak en krijg meteen een kans en een voorgesteld label. Het is een hulpmiddel voor vroegtijdige signalering, geen definitieve uitspraak.
90
+ **Hertrain met eigen CSV**
91
+ Upload een CSV met de juiste kolommen en train het model opnieuw. De nieuwe prestaties en grafieken worden direct bijgewerkt.
92
  """
93
 
94
+ # Verhaal over ML dat direct onder de afbeelding komt: alleen kop vet
95
  ML_STORY = """
96
  **Van ruwe data naar beslisinformatie**
97
+ De afbeelding schetst de weg van ruwe data naar beslisinformatie. We starten met tekst: observaties, verslagen en notities. Met historische labels leert een algoritme patronen herkennen. In de verwerking wordt tekst omgezet naar kenmerken (bijvoorbeeld TF-IDF of BERT-embeddings) en leert het model welke combinaties iets zeggen over het risico op agressie binnen 30 dagen.
98
+ Het resultaat is een waarschijnlijkheid, geen absolute waarheid. Die kans helpt teams om eerder te signaleren en bewust te kiezen: wil je minder valse alarmen (hogere precision) of juist meer signaal oppikken (hogere recall)? De mens blijft aan het roer: de uitkomst is uitlegbaar met LIME, meetbaar met AUROC/AUPRC en bedoeld om het klinisch oordeel te ondersteunen.
99
  """
100
 
101
  FOOTER = """
102
+ **Technische noot**
103
+ Modellen: TF-IDF → Logistic Regression; ClinicalBERT/DutchBERT Logistic Regression
104
+ Visualisatie: SVD(50) → t-SNE(2D/3D) op de gekozen tekstfeatures
105
+ CSV-loader: lokaal (map van dit bestand) of via Hugging Face Hub
106
  """
107
 
108
+ # MLflow experiment
109
  mlflow.set_experiment("ggz-agressie")
110
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # ============ Data loading ============
112
  def _resolve_csv_path(uploaded=None):
113
  if uploaded is not None:
114
  return uploaded.name if hasattr(uploaded, "name") else uploaded
115
+ candidates = [
116
  os.path.join(os.getcwd(), DEFAULT_CSV),
117
  os.path.join(os.path.dirname(__file__), DEFAULT_CSV),
118
  DEFAULT_CSV,
119
+ ]
120
+ for p in candidates:
121
  if os.path.exists(p):
122
  return p
123
  repo_id = os.environ.get("SPACE_ID")
124
  if repo_id:
125
  return hf_hub_download(repo_id=repo_id, filename=DEFAULT_CSV)
126
  raise FileNotFoundError(
127
+ f"Kon {DEFAULT_CSV} niet vinden. Zet het bestand in de repo-root "
128
+ "of upload een CSV met kolommen `rapportage` en `agressie_volgende30d`."
129
  )
130
 
131
  def load_dataset(file_obj=None):
 
135
  missing = required - set(df.columns)
136
  if missing:
137
  raise ValueError(f"CSV mist verplichte kolommen: {missing}")
 
 
138
  df = df.dropna(subset=["rapportage", "agressie_volgende30d"]).copy()
139
  df["agressie_volgende30d"] = (df["agressie_volgende30d"].astype(int) > 0).astype(int)
 
140
  return df
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # ============ HF Text Embedder ============
143
  class HFTextEmbedder(BaseEstimator, TransformerMixin):
144
+ """
145
+ Sklearn-compatibele transformer die sentence-embeddings maakt met een HF encoder.
146
+ - Mean-pooling over token embeddings (mask-aware)
147
+ - Batching en device auto-select
148
+ """
149
+ def __init__(self,
150
+ model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
151
+ max_length: int = 128,
152
+ batch_size: int = 16,
153
+ device: _t.Optional[str] = None):
154
  self.model_name = model_name
155
  self.max_length = max_length
156
  self.batch_size = batch_size
 
158
  self._tokenizer = None
159
  self._model = None
160
  self._dev = None
161
+
162
  def _ensure_backend(self):
163
  if torch is None or AutoTokenizer is None or AutoModel is None:
164
  raise RuntimeError("BERT-embeddings vereisen 'torch' en 'transformers'.")
165
  self._dev = self.device or ("cuda" if torch.cuda.is_available() else "cpu")
166
  if self._tokenizer is None:
167
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
168
  if self._model is None:
169
+ self._model = AutoModel.from_pretrained(self.model_name).to(self._dev)
 
 
170
  self._model.eval()
171
+
172
  def fit(self, X, y=None):
173
+ self._ensure_backend()
174
+ return self
175
+
176
  @torch.no_grad()
177
  def transform(self, X):
178
  self._ensure_backend()
 
182
  embs = []
183
  for i in range(0, len(texts), self.batch_size):
184
  batch = texts[i:i+self.batch_size]
185
+ toks = self._tokenizer(
186
+ batch, padding=True, truncation=True,
187
+ max_length=self.max_length, return_tensors="pt"
188
+ ).to(self._dev)
189
+ outs = self._model(**toks).last_hidden_state # (B, T, H)
190
+ mask = toks.attention_mask.unsqueeze(-1) # (B, T, 1)
191
+ summed = (outs * mask).sum(dim=1) # (B, H)
192
+ counts = mask.sum(dim=1).clamp(min=1) # (B, 1)
193
+ pooled = summed / counts # (B, H)
194
  embs.append(pooled.cpu().numpy())
195
  return np.vstack(embs)
196
 
197
  # ============ Explainability helpers ============
 
 
 
 
 
 
 
 
 
 
198
  def _clf_and_vectorizer_from_pipe(pipe):
199
  vec = pipe.named_steps.get("txt")
200
  clf = pipe.named_steps.get("clf")
201
  return vec, clf
202
 
 
 
 
 
203
  def tfidf_global_top_words(pipe, k=20):
204
+ """Top-k 'pro-agressie' en 'anti-agressie' woorden (alleen bij TF-IDF)."""
 
 
205
  vec, clf = _clf_and_vectorizer_from_pipe(pipe)
206
+ if not hasattr(vec, "get_feature_names_out"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  return [], []
208
+ feature_names = np.array(vec.get_feature_names_out())
209
+ coefs = clf.coef_[0]
210
+ top_pos_idx = np.argsort(coefs)[-k:][::-1]
211
+ top_neg_idx = np.argsort(coefs)[:k]
 
212
  return list(feature_names[top_pos_idx]), list(feature_names[top_neg_idx])
213
 
214
+ _lime_explainer = LimeTextExplainer(class_names=["geen agressie", "agressie"])
215
+ def lime_explain_text(pipe, text, num_features=8):
216
+ def predict_proba_text(texts):
217
+ p1 = pipe.predict_proba(texts)[:, 1]
218
+ p0 = 1 - p1
219
+ return np.vstack([p0, p1]).T
220
+ exp = _lime_explainer.explain_instance(text, predict_proba_text, num_features=num_features)
221
+ return exp.as_html()
222
+
223
+ # ============ Metrics helpers ============
224
  def _format_confusion_df(cm: np.ndarray) -> pd.DataFrame:
225
+ """
226
+ Maakt een confusion-matrix dataframe met uitleg per cel (TN/FP/FN/TP).
227
+ Klassen: 0 = 'geen agressie', 1 = 'agressie'.
228
+ """
229
  if cm.shape != (2, 2):
230
  return pd.DataFrame(cm, index=["True 0", "True 1"], columns=["Pred 0", "Pred 1"])
231
  tn, fp, fn, tp = cm.ravel()
 
245
  weighted = rep.get("weighted avg", {})
246
  s0 = int(rep.get("0", {}).get("support", 0))
247
  s1 = int(rep.get("1", {}).get("support", 0))
248
+ md = f"""
249
  ### ℹ️ Uitleg bij het classification report (drempel = {thr:.2f})
250
+ Klasselabels
251
+ 0 = geen agressie, 1 = agressie.
252
+ De drempel bepaalt wanneer de kans wordt omgezet naar label 1 (≥ drempel) of 0 (< drempel).
253
+ Velden in het rapport
254
+ Precision: van alle voorspelde positieven (label 1), welk deel was echt positief?
255
+ Recall (sensitiviteit): van alle werkelijk positieven (label 1), welk deel hebben we gevonden?
256
+ F1-score: harmonisch gemiddelde van precision en recall.
257
+ Support: aantal voorbeelden per klasse.
258
+ Accuracy: (TP + TN) / totaal — gevoelig voor class imbalance.
259
+ Macro avg: ongewogen gemiddelde over klassen.
260
+ Weighted avg: gewogen gemiddelde (weging = support).
261
+ Huidige set (support/accuracy)
262
+ Support klasse 0: {s0}, klasse 1: {s1}.
263
+ Accuracy (totaal): {acc:.3f}.
264
+ Macro avg F1: {macro.get('f1-score', 0):.3f}, Weighted avg F1: {weighted.get('f1-score', 0):.3f}.
265
+ Drempel-tips
266
+ Drempel omhoog → vaak hogere precision maar lagere recall.
267
+ Drempel omlaag → vaak hogere recall maar lagere precision.
268
  """
269
+ return md
270
 
271
+ # Visuele confusion-matrix (heatmap)
 
 
 
 
 
 
 
 
 
272
  def make_confusion_heatmap(y_true, y_score, thr=0.5):
273
  y_pred = (y_score >= thr).astype(int)
274
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
275
  z = cm.astype(int)
276
+ xlabels = ["Pred 0", "Pred 1"]
277
+ ylabels = ["True 0", "True 1"]
278
+
279
+ fig = go.Figure(
280
+ data=go.Heatmap(
281
+ z=z, x=xlabels, y=ylabels,
282
+ colorscale="Blues", showscale=True
283
+ )
284
+ )
285
+ # Annotaties (TN, FP, FN, TP)
286
  tn, fp, fn, tp = z.ravel()
287
+ annotations = [
288
+ (0, 0, f"TN: {tn}"),
289
+ (0, 1, f"FP: {fp}"),
290
+ (1, 0, f"FN: {fn}"),
291
+ (1, 1, f"TP: {tp}"),
292
+ ]
293
+ for r, c, text in annotations:
294
+ fig.add_annotation(x=xlabels[c], y=ylabels[r], text=text, showarrow=False)
295
+
296
+ fig.update_layout(
297
+ title=f"Confusion matrix (drempel = {thr:.2f})",
298
+ xaxis_title="Voorspelling",
299
+ yaxis_title="Werkelijkheid",
300
+ template="simple_white",
301
+ margin=dict(l=10, r=10, t=40, b=10)
302
+ )
303
  return fig
304
 
305
+ # -------- Eval-plots --------
306
  def make_roc_fig(y_true, y_score, auroc=None):
307
  fpr, tpr, _ = roc_curve(y_true, y_score)
308
  title = f"ROC-curve (AUROC={auroc:.3f})" if auroc is not None else "ROC-curve"
309
  fig = px.area(x=fpr, y=tpr, title=title, labels={"x":"False Positive Rate", "y":"True Positive Rate"})
310
  fig.add_shape(type="line", x0=0, x1=1, y0=0, y1=1, line=dict(dash="dash"))
311
+ fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), template="simple_white")
312
  return fig
313
 
314
  def make_pr_fig(y_true, y_score, auprc=None):
315
  prec, rec, _ = precision_recall_curve(y_true, y_score)
316
  title = f"Precision–Recall (AUPRC={auprc:.3f})" if auprc is not None else "Precision–Recall"
317
  fig = px.area(x=rec, y=prec, title=title, labels={"x":"Recall", "y":"Precision"})
318
+ fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), template="simple_white")
319
  return fig
320
 
321
  def make_prob_hist(y_true, y_score):
 
324
  title="Verdeling voorspelde kansen per werkelijke klasse",
325
  labels={"kans":"Voorspelde kans"})
326
  fig.update_traces(opacity=0.6)
327
+ fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), template="simple_white")
328
  return fig
329
 
330
  def make_threshold_metrics_fig(y_true, y_score, thr_line=0.5):
 
332
  rows = []
333
  for t in thresholds:
334
  y_pred = (y_score >= t).astype(int)
335
+ rows.append({
336
+ "threshold": t,
337
+ "precision": precision_score(y_true, y_pred, zero_division=0),
338
+ "recall": recall_score(y_true, y_pred, zero_division=0),
339
+ "f1": f1_score(y_true, y_pred, zero_division=0),
340
+ })
341
  df = pd.DataFrame(rows)
342
  df_m = df.melt(id_vars="threshold", value_vars=["precision","recall","f1"], var_name="metric", value_name="score")
343
  fig = px.line(df_m, x="threshold", y="score", color="metric",
344
  title="Metrics vs. drempel (precision/recall/F1)",
345
  labels={"threshold":"Drempel", "score":"Score"})
346
  fig.add_vline(x=float(thr_line), line_dash="dash", annotation_text=f"drempel={thr_line:.2f}", annotation_position="top")
347
+ fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), template="simple_white", yaxis=dict(range=[0,1]))
348
  return fig
349
 
350
+ # -------- Extra evaluaties: Kalibratie / Gains / Lift / KS --------
351
+ def make_calibration_fig(y_true, y_score, n_bins=10):
352
+ frac_pos, mean_pred = calibration_curve(y_true, y_score, n_bins=n_bins, strategy="quantile")
 
353
  fig = go.Figure()
354
+ fig.add_trace(go.Scatter(x=[0,1], y=[0,1], mode="lines", name="Perfect gekalibreerd", line=dict(dash="dash")))
355
+ fig.add_trace(go.Scatter(x=mean_pred, y=frac_pos, mode="lines+markers", name="Model"))
356
+ fig.update_layout(
357
+ title="Kalibratie (Reliability Diagram)",
358
+ xaxis_title="Gemiddelde voorspelde kans",
359
+ yaxis_title="Werkelijk aandeel positieven",
360
+ template="simple_white",
361
+ margin=dict(l=10, r=10, t=40, b=10)
362
+ )
363
+ return fig
364
+
365
+ def _gains_data(y_true, y_score):
366
  df = pd.DataFrame({"y": y_true, "p": y_score}).sort_values("p", ascending=False).reset_index(drop=True)
367
  df["cum_pos"] = df["y"].cumsum()
368
  total_pos = df["y"].sum()
369
+ total = len(df)
370
+ pct_samples = (np.arange(1, total+1) / total)
371
+ cum_gain = (df["cum_pos"] / (total_pos if total_pos > 0 else 1))
372
+ return pct_samples, cum_gain
373
+
374
+ def make_gains_fig(y_true, y_score):
375
+ x, gains = _gains_data(y_true, y_score)
376
  fig = go.Figure()
377
+ fig.add_trace(go.Scatter(x=x, y=x, mode="lines", name="Baseline (random)", line=dict(dash="dash")))
378
+ fig.add_trace(go.Scatter(x=x, y=gains, mode="lines", name="Cumulative Gains"))
379
+ fig.update_layout(
380
+ title="Cumulative Gains",
381
+ xaxis_title="Percentage van populatie (gesorteerd op kans)",
382
+ yaxis_title="Percentage van positieven gedekt",
383
+ template="simple_white",
384
+ margin=dict(l=10, r=10, t=40, b=10),
385
+ yaxis=dict(range=[0,1]), xaxis=dict(range=[0,1])
386
+ )
387
  return fig
388
 
389
  def make_lift_fig(y_true, y_score):
390
+ x, gains = _gains_data(y_true, y_score)
391
+ lift = gains / np.clip(x, 1e-9, None)
 
 
 
 
 
 
 
392
  fig = go.Figure()
393
+ fig.add_trace(go.Scatter(x=x, y=np.ones_like(x), mode="lines", name="Baseline (lift=1)", line=dict(dash="dash")))
394
+ fig.add_trace(go.Scatter(x=x, y=lift, mode="lines", name="Lift"))
395
+ fig.update_layout(
396
+ title="Lift-curve",
397
+ xaxis_title="Percentage van populatie (gesorteerd op kans)",
398
+ yaxis_title="Lift",
399
+ template="simple_white",
400
+ margin=dict(l=10, r=10, t=40, b=10)
401
+ )
402
  return fig
403
 
404
  def make_ks_fig(y_true, y_score):
405
+ df = pd.DataFrame({"y": y_true, "p": y_score}).sort_values("p", ascending=False).reset_index(drop=True)
406
+ total_pos = df["y"].sum()
407
+ total_neg = len(df) - total_pos
408
+ df["tp_cum"] = df["y"].cumsum() / (total_pos if total_pos > 0 else 1)
409
+ df["fp_cum"] = ((1 - df["y"]).cumsum()) / (total_neg if total_neg > 0 else 1)
410
+ ks_series = (df["tp_cum"] - df["fp_cum"]).abs()
411
+ ks_max_idx = int(ks_series.values.argmax()) if len(ks_series) else 0
412
+ ks_value = float(ks_series.iloc[ks_max_idx]) if len(ks_series) else 0.0
413
+ x = (np.arange(1, len(df)+1) / len(df)) if len(df) else np.array([0])
414
+
415
  fig = go.Figure()
416
+ fig.add_trace(go.Scatter(x=x, y=df["tp_cum"], mode="lines", name="TPR cumulatief"))
417
+ fig.add_trace(go.Scatter(x=x, y=df["fp_cum"], mode="lines", name="FPR cumulatief"))
418
+ if len(x):
419
+ fig.add_vline(x=float(x[ks_max_idx]), line_dash="dash",
420
+ annotation_text=f"KS={ks_value:.3f}", annotation_position="top")
421
+ fig.update_layout(
422
+ title="KS-curve",
423
+ xaxis_title="Percentage van populatie (gesorteerd op kans)",
424
+ yaxis_title="Cumulatieve ratio",
425
+ template="simple_white",
426
+ margin=dict(l=10, r=10, t=40, b=10),
427
+ yaxis=dict(range=[0,1]), xaxis=dict(range=[0,1])
428
+ )
429
+ return fig
430
 
431
+ def make_dataset_profile(df):
432
+ text = df["rapportage"].astype(str)
433
+ lengths = text.str.len()
434
+ pos = df["agressie_volgende30d"].astype(int)
435
+ prof = pd.DataFrame({
436
+ "kenmerk": [
437
+ "Aantal rijen",
438
+ "Aantal positieven (1)",
439
+ "Aantal negatieven (0)",
440
+ "Positiefratio",
441
+ "Tekstlengte gemiddeld",
442
+ "Tekstlengte mediaan",
443
+ "Tekstlengte p10",
444
+ "Tekstlengte p90",
445
+ ],
446
+ "waarde": [
447
+ int(len(df)),
448
+ int(pos.sum()),
449
+ int((1 - pos).sum()),
450
+ f"{(pos.mean()*100):.1f}%",
451
+ f"{lengths.mean():.1f}",
452
+ int(lengths.median()),
453
+ int(np.percentile(lengths, 10)),
454
+ int(np.percentile(lengths, 90)),
455
+ ]
456
+ })
457
+ return prof
458
 
459
  # ============ Model & Viz ============
460
  def build_and_train(
 
467
  bert_maxlen=128,
468
  bert_batch=16
469
  ):
470
+ X = df["rapportage"].astype(str).values
471
  y = df["agressie_volgende30d"].values
472
  X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
473
  X, y, np.arange(len(X)),
474
  test_size=test_size, random_state=random_state, stratify=y
475
  )
476
 
 
 
 
 
 
 
 
 
 
477
  if featurizer == "TF-IDF":
478
+ txt = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
479
+ clf = LogisticRegression(max_iter=3000)
480
+ pipe = Pipeline([("txt", txt), ("clf", clf)])
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  pipe.fit(X_train, y_train)
482
  y_score = pipe.predict_proba(X_test)[:, 1]
483
+ txt_all = pipe.named_steps["txt"].transform(X) # sparse
 
 
 
484
  elif featurizer == "ClinicalBERT":
485
  emb = HFTextEmbedder(model_name="emilyalsentzer/Bio_ClinicalBERT",
486
  max_length=bert_maxlen, batch_size=bert_batch)
487
+ clf = LogisticRegression(max_iter=3000)
488
+ pipe = Pipeline([("txt", emb), ("clf", clf)])
 
 
489
  pipe.fit(X_train, y_train)
490
  y_score = pipe.predict_proba(X_test)[:, 1]
491
+ txt_all = pipe.named_steps["txt"].transform(X) # dense
 
 
 
 
 
492
  elif featurizer == "DutchBERT":
493
  emb = HFTextEmbedder(model_name="wietsedv/bert-base-dutch-cased",
494
  max_length=bert_maxlen, batch_size=bert_batch)
495
+ clf = LogisticRegression(max_iter=3000)
496
+ pipe = Pipeline([("txt", emb), ("clf", clf)])
 
 
497
  pipe.fit(X_train, y_train)
498
  y_score = pipe.predict_proba(X_test)[:, 1]
499
+ txt_all = pipe.named_steps["txt"].transform(X) # dense
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  else:
501
+ raise ValueError("Onbekende featurizer. Kies 'TF-IDF', 'ClinicalBERT' of 'DutchBERT'.")
502
 
503
  auroc = float(roc_auc_score(y_test, y_score))
504
  auprc = float(average_precision_score(y_test, y_score))
505
 
506
+ # 2D/3D embedding: SVD (50) -> t-SNE (2D en 3D)
507
  svd = TruncatedSVD(n_components=50, random_state=random_state)
508
+ X50 = svd.fit_transform(txt_all)
509
 
510
+ # t-SNE 2D
511
+ tsne2 = TSNE(n_components=2, random_state=random_state, perplexity=30,
512
+ learning_rate="auto", init="pca")
513
  X2 = tsne2.fit_transform(X50)
514
  x2 = (X2[:, 0] - np.min(X2[:, 0])) / (np.ptp(X2[:, 0]) + 1e-9)
515
  y2 = (X2[:, 1] - np.min(X2[:, 1])) / (np.ptp(X2[:, 1]) + 1e-9)
516
 
517
+ # t-SNE 3D
518
+ tsne3 = TSNE(n_components=3, random_state=random_state, perplexity=30,
519
+ learning_rate="auto", init="pca")
520
  X3 = tsne3.fit_transform(X50)
521
  x3 = (X3[:, 0] - np.min(X3[:, 0])) / (np.ptp(X3[:, 0]) + 1e-9)
522
  y3 = (X3[:, 1] - np.min(X3[:, 1])) / (np.ptp(X3[:, 1]) + 1e-9)
 
528
  "x3": x3, "y3": y3, "z3": z3,
529
  "label": df["agressie_volgende30d"].values,
530
  "kans": proba_all,
531
+ "rapportage": df["rapportage"].str.slice(0, 180) + "..."
532
  })
533
  for col in ["PHQ9_baseline","GAD7_baseline","stress_niveau_1_5","slaap_uren","sociale_steun_0_10","zorgsetting"]:
534
  if col in df.columns:
 
538
  test_mask[idx_test] = True
539
  plot_df["split"] = np.where(test_mask, "test", "train")
540
 
541
+ return pipe, (X_test, y_test, y_score), plot_df, auroc, auprc
542
 
543
  def make_scatter(plot_df, color_mode="label", dim="2D"):
544
+ """
545
+ Algemene scattermaker:
546
+ - color_mode: 'label' of 'kans'
547
+ - dim: '2D' of '3D'
548
+ """
549
  hover_cols = ["rapportage", "kans", "split"]
550
  if color_mode == "label":
551
  color = plot_df["label"].map({0: "geen agressie", 1: "agressie"})
552
  title_2d = "2D projectie (t-SNE) — kleur = werkelijk label"
553
  title_3d = "3D projectie (t-SNE) — kleur = werkelijk label"
554
  if dim == "2D":
555
+ fig = px.scatter(
556
+ plot_df, x="x", y="y", color=color,
557
+ hover_data=hover_cols, title=title_2d, opacity=0.85
558
+ )
559
  else:
560
+ fig = px.scatter_3d(
561
+ plot_df, x="x3", y="y3", z="z3", color=color,
562
+ hover_data=hover_cols, title=title_3d, opacity=0.9
563
+ )
564
+ else: # 'kans'
565
  title_2d = "2D projectie (t-SNE) — kleur = voorspelde kans"
566
  title_3d = "3D projectie (t-SNE) — kleur = voorspelde kans"
567
  if dim == "2D":
568
+ fig = px.scatter(
569
+ plot_df, x="x", y="y", color="kans",
570
+ hover_data=hover_cols, title=title_2d,
571
+ color_continuous_scale="Turbo", opacity=0.9
572
+ )
573
  else:
574
+ fig = px.scatter_3d(
575
+ plot_df, x="x3", y="y3", z="z3", color="kans",
576
+ hover_data=hover_cols, title=title_3d,
577
+ color_continuous_scale="Turbo", opacity=0.9
578
+ )
579
+ # Styling + ASTITELS
580
  if dim == "2D":
581
  fig.update_traces(marker=dict(size=8, line=dict(width=0)))
582
+ fig.update_layout(
583
+ margin=dict(l=10, r=10, t=40, b=10),
584
+ template="simple_white",
585
+ xaxis_title="x (t-SNE)",
586
+ yaxis_title="y (t-SNE)"
587
+ )
588
  else:
589
  fig.update_traces(marker=dict(size=4))
590
+ fig.update_layout(
591
+ margin=dict(l=10, r=10, t=40, b=10),
592
+ template="simple_white",
593
+ scene=dict(
594
+ xaxis_title="x (t-SNE)",
595
+ yaxis_title="y (t-SNE)",
596
+ zaxis_title="z (t-SNE)"
597
+ )
598
+ )
599
  return fig
600
 
601
+ # --- (Niet meer gebruikt) Beslissingslandschap-overlay ---
602
+ def make_prob_with_decision_landscape(plot_df, grid_n=150):
603
+ """
604
+ Achtergrond: LR(x,y)->label geeft per gridcel P(klasse=1).
605
+ Voorgrond: punten gekleurd naar model-kans (plot_df['kans']).
606
+ Wordt behouden voor referentie, maar niet meer gebruikt in de UI.
607
+ """
608
+ X2 = plot_df[["x", "y"]].values
609
+ y = plot_df["label"].values.astype(int)
610
+
611
+ clf = LogisticRegression(max_iter=2000)
612
+ clf.fit(X2, y)
613
+
614
+ gx = np.linspace(0.0, 1.0, grid_n)
615
+ gy = np.linspace(0.0, 1.0, grid_n)
616
+ XX, YY = np.meshgrid(gx, gy)
617
+ grid = np.c_[XX.ravel(), YY.ravel()]
618
+ proba = clf.predict_proba(grid)[:, 1].reshape(XX.shape)
619
+
620
+ heat = go.Heatmap(
621
+ x=gx, y=gy, z=proba,
622
+ zmin=0, zmax=1,
623
+ colorscale="Turbo",
624
+ showscale=True,
625
+ colorbar=dict(title="kans (landschap)")
626
+ )
627
+ fig = go.Figure(data=[heat])
628
+ fig.update_layout(
629
+ title="2D projectie (t-SNE) — kleur = voorspelde kans (met beslissingslandschap)",
630
+ template="simple_white",
631
+ margin=dict(l=10, r=10, t=40, b=10),
632
+ xaxis_title="x (t-SNE)", yaxis_title="y (t-SNE)"
633
+ )
634
+ fig.add_trace(go.Scatter(
635
+ x=plot_df["x"], y=plot_df["y"],
636
+ mode="markers",
637
+ marker=dict(
638
+ size=8,
639
+ opacity=0.85,
640
+ color=plot_df["kans"],
641
+ colorscale="Turbo",
642
+ showscale=False,
643
+ line=dict(width=0)
644
+ ),
645
+ text=(
646
+ "kans=" + plot_df["kans"].round(3).astype(str) +
647
+ " | split=" + plot_df["split"].astype(str)
648
+ ),
649
+ hovertemplate="x=%{x:.3f}, y=%{y:.3f}<br>%{text}<extra></extra>",
650
+ name="punten"
651
+ ))
652
+ fig.update_xaxes(range=[0, 1])
653
+ fig.update_yaxes(range=[0, 1])
654
+ return fig
655
+
656
+ def metrics_table(y_true, y_score, thr):
657
+ """
658
+ Maakt het classification report met eenheden (%, aantallen) voor compacte weergave.
659
+ - precision/recall/f1: percentages met 1 decimaal (bijv. 87.5%)
660
+ - support: integer
661
+ - accuracy: extra kolom 'accuracy_%' met percentage
662
+ """
663
+ y_pred = (y_score >= thr).astype(int)
664
+ rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
665
+
666
+ rep_df = pd.DataFrame(rep).T
667
+ rep_df_disp = rep_df.copy()
668
+
669
+ for col in ["precision", "recall", "f1-score"]:
670
+ if col in rep_df_disp:
671
+ rep_df_disp[col] = (rep_df_disp[col] * 100).round(1).map(
672
+ lambda v: f"{v:.1f}%" if pd.notnull(v) else ""
673
+ )
674
+
675
+ if "support" in rep_df_disp:
676
+ rep_df_disp["support"] = rep_df_disp["support"].map(
677
+ lambda v: f"{int(v)}" if pd.notnull(v) else ""
678
+ )
679
+
680
+ if "accuracy" in rep:
681
+ acc_pct = f"{rep['accuracy'] * 100:.1f}%"
682
+ rep_df_disp["accuracy_%"] = ""
683
+ if "accuracy" in rep_df_disp.index:
684
+ rep_df_disp.loc["accuracy", "accuracy_%"] = acc_pct
685
+
686
+ rep_df_disp = rep_df_disp.fillna("")
687
+
688
+ cm = confusion_matrix(y_true, y_pred)
689
+ cm_df = _format_confusion_df(cm)
690
+ rep_md = _build_report_markdown(rep, thr)
691
+
692
+ return rep_df_disp, cm_df, rep_md
693
+
694
  # ============ State & Train ============
695
  GLOBAL = {
696
  "pipe": None, "plot_df": None, "eval": None,
697
  "auroc": None, "auprc": None,
698
  "featurizer": "TF-IDF",
699
+ "df": None, # bewaar dataset voor datavoorbeeld
 
700
  }
701
 
702
  def do_train(file_obj=None, test_size=0.2, seed=42,
703
  featurizer="TF-IDF", max_features=4000, ngram_max=2,
704
  bert_maxlen=128, bert_batch=16):
705
  df = load_dataset(file_obj)
706
+ pipe, eval_pack, plot_df, auroc, auprc = build_and_train(
707
  df, test_size, seed, featurizer, max_features, ngram_max, bert_maxlen, bert_batch
708
  )
709
 
710
+ # MLflow logging
711
  with mlflow.start_run(run_name=f"{featurizer}"):
712
  mlflow.log_param("featurizer", featurizer)
713
  mlflow.log_param("test_size", test_size)
714
+ if featurizer == "TF-IDF":
715
  mlflow.log_param("tfidf_max_features", max_features)
716
  mlflow.log_param("tfidf_ngram_max", ngram_max)
717
  else:
 
722
  mlflow.sklearn.log_model(pipe, artifact_path="model")
723
 
724
  GLOBAL.update(pipe=pipe, plot_df=plot_df, eval=eval_pack,
725
+ auroc=auroc, auprc=auprc, featurizer=featurizer, df=df)
726
 
727
+ # Tabel + uitleg
728
  rep_df, cm_df, rep_md = metrics_table(eval_pack[1], eval_pack[2], thr=0.5)
729
 
730
+ # Plots basis
731
  roc_fig = make_roc_fig(eval_pack[1], eval_pack[2], auroc)
732
  pr_fig = make_pr_fig(eval_pack[1], eval_pack[2], auprc)
733
  hist_fig = make_prob_hist(eval_pack[1], eval_pack[2])
734
  thr_fig = make_threshold_metrics_fig(eval_pack[1], eval_pack[2], thr_line=0.5)
735
 
736
+ # Standaard visualisaties: 2D
737
  fig_label = make_scatter(plot_df, color_mode="label", dim="2D")
738
  fig_prob = make_scatter(plot_df, color_mode="kans", dim="2D")
739
 
740
+ # Extra evaluaties
741
+ y_true, y_score = eval_pack[1], eval_pack[2]
742
+ cal_fig = make_calibration_fig(y_true, y_score, n_bins=10)
743
+ gains_fig = make_gains_fig(y_true, y_score)
744
+ lift_fig = make_lift_fig(y_true, y_score)
745
+ ks_fig = make_ks_fig(y_true, y_score)
746
+ profile_df = make_dataset_profile(df)
747
 
748
+ # Confusion heatmap op basis van default drempel
749
+ cm_plot = make_confusion_heatmap(y_true, y_score, thr=0.5)
 
 
 
 
750
 
751
+ # Datavoorbeeld (standaard: eerste 10 rijen)
752
+ preview_df = df.head(10)
753
 
754
+ status_msg = f"✅ Model getraind met {featurizer}. AUROC: {auroc:.3f} | AUPRC: {auprc:.3f}"
755
  return (
756
  status_msg, auroc, auprc,
757
+ preview_df, # datavoorbeeld output
758
  fig_label, fig_prob,
759
  rep_df, cm_df, cm_plot, rep_md,
760
  roc_fig, pr_fig, hist_fig, thr_fig,
761
+ cal_fig, gains_fig, lift_fig, ks_fig, profile_df
762
  )
763
 
764
+ def predict_one(text):
765
  if GLOBAL["pipe"] is None:
766
  return "Nog geen model getraind.", None
767
  if not text or text.strip() == "":
768
  return "Voer een rapportage in.", None
769
+ proba = float(GLOBAL["pipe"].predict_proba([text])[:, 1][0])
770
+ label = int(proba >= 0.5)
 
 
771
  md = (
772
  f"Kans op agressie (30d): {proba:.3f} — "
773
+ f"voorspelde klasse: {label} (drempel 0.50)\n"
774
  f"Featurizer: {GLOBAL.get('featurizer','?')}"
775
  )
776
  return md, proba
777
 
778
  # ============ UI ============
779
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as demo:
780
+ # Volledige-breedte kopregel (h1)
781
  gr.Markdown(f"# {SLOGAN}")
782
 
783
+ # --- opvallende styling voor de knoppen + scrollbare data-preview ---
784
  gr.HTML("""
785
  <style>
786
+ /* Zelfde gradient-stijl voor alle 3 knoppen */
787
  #train-btn, #retrain-btn, #predict-btn {
788
  background: linear-gradient(90deg, #ef4444 0%, #f97316 100%);
789
+ color: white !important;
790
+ font-weight: 700;
791
+ border: none !important;
792
+ }
793
+ #train-btn:hover, #retrain-btn:hover, #predict-btn:hover {
794
+ filter: brightness(0.95);
795
+ }
796
+ /* Scrollbare DataFrame container */
797
+ #data-preview {
798
+ max-height: 320px;
799
+ overflow: auto;
800
  }
801
+ #data-preview table {
802
+ width: 100%;
803
+ }
804
+ /* Afbeelding direct onder projectie zonder top-ruimte */
805
  #viz-img { margin-top: 0 !important; padding-top: 0 !important; }
806
  #viz-img img { display: block; margin-top: 0 !important; }
807
  </style>
808
  """)
809
 
810
+ # Introductie & overzicht naast elkaar
811
  with gr.Row():
812
+ with gr.Column(scale=1):
813
+ gr.Markdown(INTRO)
814
+ with gr.Column(scale=1):
815
+ gr.Markdown(WHAT_YOU_SEE)
816
 
817
+ # ---- Handmatig trainen (zonder CSV upload) ----
818
  gr.Markdown("## 🛠️ Handmatig trainen (zonder CSV upload)")
819
  with gr.Row():
820
  featur_quick = gr.Radio(
821
+ choices=["TF-IDF", "ClinicalBERT", "DutchBERT"],
822
  value="TF-IDF",
823
  label="Kies featurizer"
824
  )
 
834
  with gr.Row():
835
  auroc_box = gr.Number(label="AUROC", precision=3)
836
  auprc_box = gr.Number(label="AUPRC", precision=3)
 
837
 
838
+ # Visualisatie + evaluatie-tabellen
839
  with gr.Row():
840
  with gr.Column(scale=3):
841
  gr.Markdown("### 🔍 Visualisatie")
842
+ # Gezamenlijke toggle voor dimensie
843
  proj_dim = gr.Radio(choices=["2D", "3D"], value="2D", label="Projectiedimensie (geldt voor beide projecties)")
844
  with gr.Column():
845
  fig_out_label = gr.Plot(label="Projectie — kleur = werkelijk label")
 
848
  gr.Markdown(ML_STORY)
849
  with gr.Column(scale=2):
850
  gr.Markdown("### 📄 Datavoorbeeld")
851
+ data_preview_mode = gr.Radio(
852
+ choices=["Eerste 10 rijen", "Gehele dataset (scrollbaar)"],
853
+ value="Eerste 10 rijen",
854
+ label="Weergave"
855
+ )
856
  data_preview = gr.Dataframe(label="Dataset", interactive=False, elem_id="data-preview")
857
 
858
  gr.Markdown("### ⚙️ Evaluatie (tabellen & drempel)")
 
862
  cm_plot = gr.Plot(label="Confusion matrix (heatmap)")
863
  rep_md = gr.Markdown(label="Uitleg classification report")
864
 
865
+ # === Twee kolommen — links plots (met tabs), rechts predict ===
866
  with gr.Row():
867
  with gr.Column(scale=3):
868
  with gr.Tabs():
 
874
  roc_plot = gr.Plot(label="ROC-curve")
875
  with gr.TabItem("Precision–Recall"):
876
  pr_plot = gr.Plot(label="PR-curve")
877
+ # ---- Nieuw: extra tabs ----
878
  with gr.TabItem("Kalibratie"):
879
+ cal_plot = gr.Plot(label="Kalibratie (Reliability Diagram)")
 
 
880
  with gr.TabItem("Cumulative Gains"):
881
  gains_plot = gr.Plot(label="Cumulative Gains")
882
  with gr.TabItem("Lift"):
 
884
  with gr.TabItem("KS-curve"):
885
  ks_plot = gr.Plot(label="KS-curve")
886
  with gr.TabItem("Dataset-profiel"):
887
+ profile_df_out = gr.Dataframe(label="Dataset-profiel", interactive=False)
 
888
  with gr.Column(scale=2):
889
  gr.Markdown("### 🗣️ Predict (vrije tekst)")
890
  with gr.Row():
891
+ txt = gr.Textbox(
892
+ lines=12, label="Rapportage (NL)",
893
+ placeholder="Bijv.: Patiënt oogt geagiteerd, slaapt slecht, weigert medicatie..."
894
+ )
 
895
  btn = gr.Button("Voorspel", elem_id="predict-btn")
896
  md_out = gr.Markdown()
897
  proba_out = gr.Number(label="Kans", precision=3)
898
 
899
+ # ===== Hertrain met eigen CSV — ALTIJD ZICHTBAAR =====
900
  gr.Markdown("## 🔁 Hertrain met eigen CSV")
901
+ gr.Markdown(
902
+ "Upload een CSV met kolommen `rapportage` (tekst) en `agressie_volgende30d` (0/1). "
903
+ "Kies je parameters en klik Train opnieuw (met upload)."
904
+ )
905
+ csv_in = gr.File(label="Upload CSV (kolommen: rapportage, agressie_volgende30d)")
906
  with gr.Row():
907
  test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test set grootte")
908
  seed = gr.Slider(1, 999, value=42, step=1, label="Random seed")
909
  with gr.Row():
910
+ featur = gr.Radio(choices=["TF-IDF", "ClinicalBERT", "DutchBERT"], value="TF-IDF", label="Tekst-featurizer")
 
911
  with gr.Row(visible=True) as tfidf_row:
912
  max_features = gr.Slider(1000, 12000, value=4000, step=1000, label="TF-IDF max_features")
913
  ngram_max = gr.Radio(choices=[1, 2], value=2, label="n-gram max")
 
916
  bert_batch = gr.Slider(4, 64, value=16, step=4, label="BERT batch_size")
917
  retrain_btn = gr.Button("Train opnieuw (met upload)", elem_id="retrain-btn")
918
 
919
+ # << VERPLAATST: uitleg over de evaluatieplots — lager in dezelfde kolom >>
920
  with gr.Row():
921
  with gr.Column(scale=2, min_width=0):
922
  gr.Markdown(
923
  "### ℹ️ Over de evaluatieplots\n\n"
924
+ "De onderstaande grafieken laten zien hoe het model presteert bij verschillende drempels en uitkomsten:\n\n"
925
+ "- Metrics vs. drempel toont hoe precision, recall en F1-score veranderen als je de drempel aanpast.\n"
926
+ "- Kansverdeling laat zien hoe voorspelde kansen verdeeld zijn over de echte klassen (0/1).\n"
927
+ "- ROC-curvevergelijkt True Positive Rate met False Positive Rate (AUROC = scheidingskracht).\n"
928
+ "- Precision–Recall-curvenuttig bij ongebalanceerde data; focust op de positieve klasse.\n\n"
929
+ "Gebruik ze samen om te bepalen waar je drempel moet liggen en hoe betrouwbaar het model is."
930
  )
931
 
932
+ # Toggle zichtbaarheid param-rijen
933
  def _toggle_quick(choice):
934
  return (
935
+ gr.update(visible=(choice == "TF-IDF")),
936
+ gr.update(visible=(choice in ("ClinicalBERT", "DutchBERT")))
937
  )
938
  featur_quick.change(_toggle_quick, inputs=featur_quick, outputs=[tfidf_quick_row, bert_quick_row])
939
 
940
  def _toggle_rows(choice):
941
  return (
942
+ gr.update(visible=(choice == "TF-IDF")),
943
+ gr.update(visible=(choice in ("ClinicalBERT", "DutchBERT")))
944
  )
945
  featur.change(_toggle_rows, inputs=featur, outputs=[tfidf_row, bert_row])
946
 
947
+ # ===== Interactie-functies =====
948
  def _update_eval(t):
949
  if GLOBAL["eval"] is None:
950
  return None, None, None, None, None
 
953
  thr_fig_new = make_threshold_metrics_fig(y_true, y_score, thr_line=float(t))
954
  cm_plot_new = make_confusion_heatmap(y_true, y_score, thr=float(t))
955
  return rep, cm, cm_plot_new, rep_md_text, thr_fig_new
956
+
957
  thr.release(_update_eval, inputs=thr, outputs=[rep_df, cm_df, cm_plot, rep_md, thr_plot])
958
 
959
+ # Datavoorbeeld wisselen
960
  def _refresh_preview(mode):
961
  df = GLOBAL.get("df")
962
  if df is None or not isinstance(df, pd.DataFrame):
963
  return None
964
+ if mode.startswith("Eerste"):
965
+ return df.head(10)
966
+ return df
967
  data_preview_mode.change(_refresh_preview, inputs=data_preview_mode, outputs=data_preview)
968
 
969
+ btn.click(predict_one, inputs=txt, outputs=[md_out, proba_out])
970
 
971
+ # Handmatig trainen (zonder CSV upload)
972
  def _train_quick(featur, max_features_q, ngram_max_q, bert_maxlen_q, bert_batch_q):
973
  return do_train(None, 0.2, 42, featur, int(max_features_q), int(ngram_max_q),
974
  int(bert_maxlen_q), int(bert_batch_q))
 
979
  fig_out_label, fig_out_prob,
980
  rep_df, cm_df, cm_plot, rep_md,
981
  roc_plot, pr_plot, hist_plot, thr_plot,
982
+ cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out]
983
  )
984
 
985
+ # Upload-hertrain
986
  def _retrain(csv_in, test_size, seed, featur, max_features, ngram_max, bert_maxlen, bert_batch):
987
  return do_train(csv_in, test_size, int(seed), featur, int(max_features), int(ngram_max),
988
  int(bert_maxlen), int(bert_batch))
 
993
  fig_out_label, fig_out_prob,
994
  rep_df, cm_df, cm_plot, rep_md,
995
  roc_plot, pr_plot, hist_plot, thr_plot,
996
+ cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out]
997
  )
998
 
999
+ # ---- Dimensie-toggle werkt op beide projecties ----
1000
  def _update_projection(dim):
1001
  pdf = GLOBAL.get("plot_df")
1002
  if pdf is None:
1003
  return None, None
1004
+ fig_lbl = make_scatter(pdf, color_mode="label", dim=dim)
1005
+ fig_prb = make_scatter(pdf, color_mode="kans", dim=dim)
1006
+ return fig_lbl, fig_prb
1007
+
1008
  proj_dim.change(_update_projection, inputs=proj_dim, outputs=[fig_out_label, fig_out_prob])
1009
 
1010
+ # ---- Auto-train bij openen met TF-IDF ----
1011
  def _auto_train():
1012
  try:
1013
  return do_train(None, 0.2, 42, "TF-IDF", 4000, 2, 128, 16)
1014
  except Exception as e:
1015
+ return (f"❌ Fout bij laden/trainen: `{e}`",
1016
+ None, None, None,
1017
+ None, None,
1018
+ None, None, None, None,
1019
+ None, None, None, None,
1020
+ None, None, None, None, None)
 
1021
 
1022
  demo.load(_auto_train, inputs=None,
1023
  outputs=[status, auroc_box, auprc_box, data_preview,
1024
  fig_out_label, fig_out_prob,
1025
  rep_df, cm_df, cm_plot, rep_md,
1026
  roc_plot, pr_plot, hist_plot, thr_plot,
1027
+ cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out])
1028
 
1029
+ # --- Explainability tab/accordion ---
1030
  with gr.Accordion("🪄 Uitleg (Explainability)", open=False):
1031
+ gr.Markdown("Leg uit waarom het model een voorspelling maakt (LIME).")
1032
  with gr.Row():
1033
+ txt_explain = gr.Textbox(lines=4, label="Tekst om uit te leggen",
1034
+ placeholder="Plak hier een rapportage voor uitleg")
1035
  btn_explain = gr.Button("Genereer uitleg")
1036
  lime_html = gr.HTML(label="LIME uitleg (per voorbeeld)")
1037
+
1038
+ # Optioneel: globale top-woorden (alleen TF-IDF)
1039
  top_pos_df = gr.Dataframe(headers=["Top pro-agressie woorden"], row_count=5)
1040
  top_neg_df = gr.Dataframe(headers=["Top anti-agressie woorden"], row_count=5)
1041
 
1042
+ def _do_explain(text):
1043
  if GLOBAL["pipe"] is None:
1044
  return "Train eerst een model.", None, None
1045
+ html = lime_explain_text(GLOBAL["pipe"], text, num_features=8)
 
1046
  pos, neg = tfidf_global_top_words(GLOBAL["pipe"], k=15)
1047
  pos = [[w] for w in pos] if pos else None
1048
  neg = [[w] for w in neg] if neg else None
1049
  return html, pos, neg
1050
 
1051
+ btn_explain.click(_do_explain, inputs=txt_explain, outputs=[lime_html, top_pos_df, top_neg_df])
1052
 
1053
  gr.Markdown(FOOTER)
1054