File size: 46,085 Bytes
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
be95d70
 
a2cdcd5
38df674
a2cdcd5
 
 
b73f440
 
be95d70
 
38df674
be95d70
 
328c923
 
 
be95d70
38df674
 
40f6ba2
38df674
 
6925d79
40f6ba2
6925d79
40f6ba2
a2cdcd5
 
 
40f6ba2
6925d79
 
 
 
 
 
 
 
a2cdcd5
38df674
 
40f6ba2
b73f440
 
40f6ba2
8e4e1ea
38df674
40f6ba2
8e4e1ea
66f101c
40f6ba2
 
 
 
 
 
 
 
 
 
 
9939972
38df674
40f6ba2
9939972
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
69fa73d
 
40f6ba2
69fa73d
 
40f6ba2
 
38df674
 
be95d70
40f6ba2
 
 
 
38df674
 
40f6ba2
a2cdcd5
 
 
be95d70
 
 
40f6ba2
be95d70
 
 
40f6ba2
 
be95d70
 
6925d79
be95d70
66238e3
be95d70
40f6ba2
 
be95d70
 
38df674
be95d70
 
 
 
38df674
be95d70
38df674
 
 
 
a2cdcd5
c93ba77
40f6ba2
 
 
 
 
 
 
 
 
 
6925d79
 
 
 
 
 
 
40f6ba2
6925d79
 
6233851
6925d79
 
 
 
40f6ba2
9a53826
40f6ba2
6925d79
40f6ba2
 
 
6925d79
 
 
 
c93ba77
 
6925d79
 
 
40f6ba2
 
 
 
 
 
 
 
 
6925d79
c93ba77
6925d79
a2cdcd5
 
 
 
 
 
 
40f6ba2
a2cdcd5
40f6ba2
c6fa454
40f6ba2
 
 
 
a2cdcd5
 
40f6ba2
 
 
 
 
 
 
 
 
 
2b0df54
40f6ba2
 
 
 
2b0df54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40f6ba2
2b0df54
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b0df54
40f6ba2
2b0df54
40f6ba2
9ba0865
 
 
 
40f6ba2
 
 
 
 
 
 
 
 
 
9ba0865
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba0865
 
40f6ba2
328c923
 
 
 
 
40f6ba2
328c923
 
 
 
 
 
40f6ba2
328c923
 
 
 
 
 
 
 
40f6ba2
328c923
 
 
 
 
 
 
40f6ba2
 
 
 
 
 
328c923
 
 
 
 
 
40f6ba2
328c923
 
40f6ba2
 
 
c6fa454
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
c6fa454
 
 
40f6ba2
 
 
 
 
 
 
c6fa454
40f6ba2
 
 
 
 
 
 
 
 
 
c6fa454
 
 
40f6ba2
 
c6fa454
40f6ba2
 
 
 
 
 
 
 
 
c6fa454
 
 
40f6ba2
 
 
 
 
 
 
 
 
 
c6fa454
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
c6fa454
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6fa454
a2cdcd5
6925d79
 
 
 
328c923
6925d79
 
 
 
 
40f6ba2
38df674
 
be95d70
 
38df674
be95d70
a561ed9
40f6ba2
 
 
c6fa454
 
40f6ba2
6925d79
6233851
 
40f6ba2
 
6925d79
 
40f6ba2
c93ba77
6233851
 
40f6ba2
 
c6fa454
 
40f6ba2
6925d79
40f6ba2
38df674
 
 
 
40f6ba2
38df674
40f6ba2
66238e3
40f6ba2
 
 
d881680
a2cdcd5
66238e3
38df674
40f6ba2
 
 
d881680
 
fca118f
 
d881680
38df674
 
a2cdcd5
d881680
38df674
 
40f6ba2
38df674
be95d70
38df674
 
 
 
 
 
 
40f6ba2
38df674
d881680
40f6ba2
 
 
 
 
d881680
38df674
 
d881680
 
 
40f6ba2
 
 
 
d881680
40f6ba2
 
 
 
 
d881680
 
 
40f6ba2
 
 
 
 
d881680
40f6ba2
 
 
 
 
 
d881680
 
40f6ba2
 
 
 
 
 
38df674
d881680
40f6ba2
 
 
 
 
 
 
 
 
38df674
 
40f6ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cdcd5
6925d79
 
 
c147389
40f6ba2
6925d79
 
 
 
 
38df674
40f6ba2
6925d79
be95d70
a2cdcd5
40f6ba2
a2cdcd5
 
 
40f6ba2
a2cdcd5
 
 
 
 
 
 
 
 
6925d79
40f6ba2
328c923
40f6ba2
78d674f
328c923
40f6ba2
78d674f
 
 
 
328c923
40f6ba2
d881680
 
c147389
40f6ba2
 
 
 
 
 
 
c6fa454
40f6ba2
 
c6fa454
40f6ba2
 
c147389
40f6ba2
0228efc
78d674f
40f6ba2
78d674f
9ba0865
c6fa454
40f6ba2
0228efc
38df674
40f6ba2
38df674
 
 
 
40f6ba2
 
6925d79
69fa73d
40f6ba2
69fa73d
6925d79
38df674
 
a2cdcd5
38df674
40f6ba2
8e4e1ea
 
40f6ba2
0228efc
 
40f6ba2
27fee84
0228efc
40f6ba2
 
 
 
 
 
 
 
 
 
 
0228efc
40f6ba2
 
 
 
5a94fc8
 
0228efc
 
 
40f6ba2
c934531
40f6ba2
 
 
 
38df674
40f6ba2
6233851
 
 
40f6ba2
a2cdcd5
6233851
 
 
 
 
 
 
 
0228efc
6233851
9939972
38df674
78d674f
 
38df674
40f6ba2
9939972
 
 
40f6ba2
d881680
519c226
d881680
 
9a53826
69fa73d
9939972
c147389
40f6ba2
 
 
 
 
9a53826
 
6459344
9939972
38df674
2b0df54
9ba0865
2b0df54
38df674
40f6ba2
6459344
 
69fa73d
7bfb1de
 
6459344
 
328c923
 
 
 
40f6ba2
c6fa454
40f6ba2
c6fa454
 
 
 
 
 
 
40f6ba2
6459344
 
 
40f6ba2
 
 
 
27fee84
6459344
 
9939972
40f6ba2
2b0df54
40f6ba2
 
 
 
 
2b0df54
 
 
 
40f6ba2
2b0df54
 
 
 
 
 
27fee84
6233851
40f6ba2
5fa54dc
 
 
 
40f6ba2
 
 
 
 
 
5fa54dc
 
40f6ba2
6233851
 
40f6ba2
 
6233851
 
9939972
6925d79
 
40f6ba2
 
6925d79
 
 
40f6ba2
9939972
 
9ba0865
9939972
2b0df54
328c923
9ba0865
 
40f6ba2
9ba0865
6233851
40f6ba2
c147389
 
 
 
40f6ba2
 
 
c147389
 
40f6ba2
9939972
40f6ba2
6233851
 
 
 
 
 
78d674f
9ba0865
 
c6fa454
40f6ba2
6233851
 
40f6ba2
6925d79
 
 
9939972
 
6925d79
78d674f
9ba0865
 
c6fa454
40f6ba2
9939972
38df674
40f6ba2
d881680
 
 
 
40f6ba2
 
 
 
d881680
 
40f6ba2
38df674
be95d70
6925d79
be95d70
40f6ba2
 
 
 
 
 
38df674
328c923
78d674f
9ba0865
 
c6fa454
40f6ba2
38df674
40f6ba2
a2cdcd5
40f6ba2
a2cdcd5
40f6ba2
 
a2cdcd5
 
40f6ba2
 
a2cdcd5
 
 
40f6ba2
a2cdcd5
 
40f6ba2
a2cdcd5
 
 
 
 
40f6ba2
a2cdcd5
38df674
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
# app.py — GGZ Agressie (synthetisch) — One-page UI
# - Auto-train bij openen met TF-IDF
# - Handmatig trainen zonder CSV upload: kies TF-IDF / ClinicalBERT / DutchBERT
# - (Optioneel) Hertrain met eigen CSV (nu altijd zichtbaar)
# - MLflow experiment tracking + LIME explainability tab
# - Confusion matrix met betekenislabels + Markdown-uitleg bij classification report
# - Extra: Confusion-matrix heatmap-plot onder de tabel
# - Evaluatieplots links (met datavoorbeeld erboven); Predict rechts
# - Visualisatie: 2D/3D-projecties (label & kans) + afbeelding direct onder kans-plot
# - Classification report met eenheden (% en aantallen)
# - Datavoorbeeld: eerste 10 rijen of hele dataset (scrollbaar via CSS)
# - Extra tabs: Kalibratie, Cumulative Gains, Lift, KS-curve, Dataset-profiel

import os
import typing as _t
import numpy as np
import pandas as pd
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

from huggingface_hub import hf_hub_download
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    classification_report, confusion_matrix,
    precision_score, recall_score, f1_score,
    roc_curve, precision_recall_curve
)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.calibration import calibration_curve

# --- NEW: experiment tracking + explainability ---
import mlflow, mlflow.sklearn
from lime.lime_text import LimeTextExplainer

# --- Optional DL deps (voor BERT) ---
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
except Exception:
    torch = None
    AutoTokenizer = None
    AutoModel = None

# ============ Config & Intro ============
# File name of the bundled default dataset; _resolve_csv_path searches for it.
DEFAULT_CSV = "synthetische_ggz_agressie_dataset_1000.csv"

# Image shown directly below the 2D/3D probability plot (file located next to app.py).
# The path is resolved relative to the directory that contains this module.
INFO_IMAGE = str(Path(__file__).resolve().parent / "imglk;l;kl.png")

# Full-width header text
SLOGAN = "Studieobject Marcel Ooms: Veiligere zorg begint hier: het 30-dagenrisico op agressie onderbouwd en uitlegbaar."

# User-friendly introduction (Markdown, Dutch UI text): only the heading is bold.
INTRO = """
**Van verslag naar risico: kans op agressie in de komende 30 dagen**
Wat doet deze pagina voor jou?
Deze demo helpt om uit vrije-tekstrapportages snel een inschatting van het risico op agressief gedrag in de komende 30 dagen te krijgen. Plak een stukje verslag in het tekstvak en je krijgt een kans (probabiliteit) terug, plus een voorgesteld label op basis van een drempel die je zelf kunt verschuiven. Zo kun je risico vroegtijdig signaleren en bepalen welke acties passen: extra observatie, bijsturing in het behandelplan of overleg in het team.
Hoe werkt het in grote lijnen (zonder technisch gedoe):
- Bij het openen staat er al een startmodel klaar.
- Je kunt hertrainen met drie aanpakken: TF-IDF, ClinicalBERT of DutchBERT.
- De grafieken laten zien hoe nauwkeurig het model is en hoe de drempel precision en recall beïnvloedt.
- Met LIME zie je welke woorden in de tekst het meest hebben bijgedragen aan de inschatting; dat maakt de uitkomst uitlegbaar.
Belangrijk om te weten:
- Dit is een demonstratie op synthetische data. De uitkomst is een waarschijnlijkheid, geen zekerheid.
- Het systeem voorspelt niet of iemand agressief wordt, maar schat de kans binnen 30 dagen in op basis van tekstsignalen.
- Gebruik de uitkomst altijd naast klinische expertise en bestaande veiligheidsprotocollen.
"""

# Rewritten right-hand text block (Markdown, Dutch UI text): only the headings are bold.
# The two trailing spaces after each heading are Markdown hard line breaks.
WHAT_YOU_SEE = """
**Wat zie je op deze pagina?**
**Status & prestaties**  
Hier zie je hoe goed het model onderscheid maakt. AUROC en AUPRC tonen in één oogopslag hoe betrouwbaar de inschatting is; hoger is beter.
**Handmatig trainen (zonder upload)**  
Kies een featurizer (TF-IDF, ClinicalBERT of DutchBERT) en klik op Train algoritme. Je kunt opties aanpassen en direct vergelijken wat in jouw setting het beste werkt.
**Visualisatie**  
De interactieve 2D/3D-plot laat elke tekst als een punt zien. Kleur en positie helpen om patronen te herkennen; met de muis zie je extra uitleg per punt. Er zijn twee weergaven: kleur naar werkelijk label en kleur naar voorspelde kans.
**Evaluatie**  
Met de drempel-schuif bepaal je wanneer “hoog risico” wordt toegekend. Je ziet wat dat betekent voor precision, recall en F1. Zo kun je kiezen tussen minder valse alarmen of meer signalen oppikken.
**Predict**  
Plak een rapportage in het tekstvak en krijg meteen een kans en een voorgesteld label. Het is een hulpmiddel voor vroegtijdige signalering, geen definitieve uitspraak.
**Hertrain met eigen CSV**  
Upload een CSV met de juiste kolommen en train het model opnieuw. De nieuwe prestaties en grafieken worden direct bijgewerkt.
"""

# Narrative about machine learning placed directly below the image (Markdown, Dutch UI text):
# only the heading is bold.
ML_STORY = """
**Van ruwe data naar beslisinformatie**
De afbeelding schetst de weg van ruwe data naar beslisinformatie. We starten met tekst: observaties, verslagen en notities. Met historische labels leert een algoritme patronen herkennen. In de verwerking wordt tekst omgezet naar kenmerken (bijvoorbeeld TF-IDF of BERT-embeddings) en leert het model welke combinaties iets zeggen over het risico op agressie binnen 30 dagen.
Het resultaat is een waarschijnlijkheid, geen absolute waarheid. Die kans helpt teams om eerder te signaleren en bewust te kiezen: wil je minder valse alarmen (hogere precision) of juist meer signaal oppikken (hogere recall)? De mens blijft aan het roer: de uitkomst is uitlegbaar met LIME, meetbaar met AUROC/AUPRC en bedoeld om het klinisch oordeel te ondersteunen.
"""

# Technical footnote (Markdown, Dutch UI text) summarizing the models, the projection
# used for visualization, and where the CSV is loaded from.
FOOTER = """
**Technische noot**
Modellen: TF-IDF → Logistic Regression; ClinicalBERT/DutchBERT → Logistic Regression  
Visualisatie: SVD(50) → t-SNE(2D/3D) op de gekozen tekstfeatures  
CSV-loader: lokaal (map van dit bestand) of via Hugging Face Hub
"""

# MLflow experiment: selects the experiment named "ggz-agressie" when this module is imported.
# NOTE(review): this is an import-time side effect — confirm that is intended for every importer.
mlflow.set_experiment("ggz-agressie")

# ============ Data loading ============
def _resolve_csv_path(uploaded=None):
    """Return the path of the CSV file that should be loaded.

    Resolution order:

    1. ``uploaded`` when it is not ``None``: its ``name`` attribute if it has
       one, otherwise the value itself.
    2. ``DEFAULT_CSV`` in the current working directory, next to this module,
       or as a bare relative path; the first one that exists wins.
    3. A download of ``DEFAULT_CSV`` from the Hugging Face repository named by
       the ``SPACE_ID`` environment variable.

    Raises:
        FileNotFoundError: if none of the above yields a file.
    """
    if uploaded is not None:
        return getattr(uploaded, "name", uploaded)

    local_candidates = (
        os.path.join(os.getcwd(), DEFAULT_CSV),
        os.path.join(os.path.dirname(__file__), DEFAULT_CSV),
        DEFAULT_CSV,
    )
    for candidate in local_candidates:
        if os.path.exists(candidate):
            return candidate

    repo_id = os.environ.get("SPACE_ID")
    if repo_id:
        # SPACE_ID identifies a Space repository. Without an explicit
        # repo_type, hf_hub_download looks for a *model* repo with that id
        # and cannot find the file.
        return hf_hub_download(
            repo_id=repo_id, filename=DEFAULT_CSV, repo_type="space"
        )

    raise FileNotFoundError(
        f"Kon {DEFAULT_CSV} niet vinden. Zet het bestand in de repo-root "
        "of upload een CSV met kolommen `rapportage` en `agressie_volgende30d`."
    )

def load_dataset(file_obj=None):
    """Read the dataset CSV and return a cleaned copy.

    The CSV location comes from ``_resolve_csv_path(file_obj)``. Rows with a
    missing ``rapportage`` or ``agressie_volgende30d`` value are dropped, and
    the target column is reduced to 0/1 (1 when the integer value is > 0).

    Raises:
        ValueError: if either required column is absent from the CSV.
    """
    frame = pd.read_csv(_resolve_csv_path(file_obj))

    required = {"rapportage", "agressie_volgende30d"}
    missing = required - set(frame.columns)
    if missing:
        raise ValueError(f"CSV mist verplichte kolommen: {missing}")

    cleaned = frame.dropna(subset=["rapportage", "agressie_volgende30d"]).copy()
    as_int = cleaned["agressie_volgende30d"].astype(int)
    cleaned["agressie_volgende30d"] = (as_int > 0).astype(int)
    return cleaned

# ============ HF Text Embedder ============
class HFTextEmbedder(BaseEstimator, TransformerMixin):
    """
    Sklearn-compatible transformer that builds sentence embeddings with a
    Hugging Face encoder.

    - Mean-pooling over token embeddings, weighted by the attention mask
    - Batched inference; device defaults to CUDA when available, else CPU

    Parameters:
        model_name: checkpoint passed to ``AutoTokenizer``/``AutoModel``.
        max_length: tokenizer truncation length.
        batch_size: number of texts encoded per forward pass.
        device: explicit torch device string; ``None`` selects automatically.
    """
    def __init__(self,
                 model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
                 max_length: int = 128,
                 batch_size: int = 16,
                 device: _t.Optional[str] = None):
        self.model_name = model_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        # Lazily created backend objects (see _ensure_backend).
        self._tokenizer = None
        self._model = None
        self._dev = None

    def _ensure_backend(self):
        """Load tokenizer and model on first use and put the model in eval mode.

        Raises:
            RuntimeError: if ``torch`` or ``transformers`` could not be imported.
        """
        if torch is None or AutoTokenizer is None or AutoModel is None:
            raise RuntimeError("BERT-embeddings vereisen 'torch' en 'transformers'.")
        self._dev = self.device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self._model is None:
            self._model = AutoModel.from_pretrained(self.model_name).to(self._dev)
        self._model.eval()

    def fit(self, X, y=None):
        """Nothing is learned; only makes sure the backend can be loaded."""
        self._ensure_backend()
        return self

    def transform(self, X):
        """Return one mean-pooled embedding row per input text as a NumPy array."""
        self._ensure_backend()
        texts = pd.Series(X).astype(str).tolist()
        if not texts:
            # Use the model's own width when it exposes one; 768 is the
            # previous hard-coded fallback.
            hidden = getattr(self._model.config, "hidden_size", 768)
            return np.zeros((0, hidden), dtype=np.float32)

        embs = []
        # no_grad is applied here rather than as a decorator: a decorator is
        # evaluated when the class is defined, which fails with AttributeError
        # if the optional torch import fell back to ``None``.
        with torch.no_grad():
            for start in range(0, len(texts), self.batch_size):
                batch = texts[start:start + self.batch_size]
                toks = self._tokenizer(
                    batch, padding=True, truncation=True,
                    max_length=self.max_length, return_tensors="pt"
                ).to(self._dev)
                outs = self._model(**toks).last_hidden_state  # (B, T, H)
                mask = toks.attention_mask.unsqueeze(-1)      # (B, T, 1)
                summed = (outs * mask).sum(dim=1)             # (B, H)
                counts = mask.sum(dim=1).clamp(min=1)         # (B, 1); avoids division by zero
                pooled = summed / counts                      # (B, H)
                embs.append(pooled.cpu().numpy())
        return np.vstack(embs)

# ============ Explainability helpers ============
def _clf_and_vectorizer_from_pipe(pipe):
    """Return ``(vectorizer, classifier)`` looked up under the step names
    ``"txt"`` and ``"clf"`` via ``pipe.named_steps.get``."""
    steps = pipe.named_steps
    return steps.get("txt"), steps.get("clf")

def tfidf_global_top_words(pipe, k=20):
    """Return the top-k 'pro-aggression' and 'anti-aggression' words.

    Words are ranked by the first row of the classifier's ``coef_``: the k
    largest coefficients (descending) and the k smallest (ascending).

    Parameters:
        pipe: pipeline whose ``named_steps`` holds ``"txt"`` and ``"clf"``.
        k: number of words per list; values <= 0 yield two empty lists.

    Returns:
        ``(positive_words, negative_words)``. Both lists are empty when the
        ``"txt"`` step has no ``get_feature_names_out`` (i.e. it is not a
        vocabulary-based vectorizer).
    """
    vec = pipe.named_steps.get("txt")
    clf = pipe.named_steps.get("clf")
    # Guard k <= 0 explicitly: ``order[-0:]`` would select *every* feature.
    if k <= 0 or not hasattr(vec, "get_feature_names_out"):
        return [], []

    feature_names = np.array(vec.get_feature_names_out())
    order = np.argsort(clf.coef_[0])       # ascending by coefficient
    top_pos_idx = order[::-1][:k]          # largest coefficients first
    top_neg_idx = order[:k]                # smallest coefficients first
    return list(feature_names[top_pos_idx]), list(feature_names[top_neg_idx])

# Module-wide LIME explainer; class order matches label 0 / label 1.
_lime_explainer = LimeTextExplainer(class_names=["geen agressie", "agressie"])
def lime_explain_text(pipe, text, num_features=8):
    """Explain the pipeline's prediction for ``text`` with LIME and return the
    explanation as an HTML string.

    ``num_features`` is forwarded to ``explain_instance``.
    """
    def _two_class_proba(batch):
        # Column 1 of predict_proba is the positive class; column 0 is
        # rebuilt as its complement.
        positive = pipe.predict_proba(batch)[:, 1]
        return np.column_stack((1 - positive, positive))

    explanation = _lime_explainer.explain_instance(
        text, _two_class_proba, num_features=num_features
    )
    return explanation.as_html()

# ============ Metrics helpers ============
def _format_confusion_df(cm: np.ndarray) -> pd.DataFrame:
    """
    Build a confusion-matrix dataframe with an explanation per cell (TN/FP/FN/TP).
    Classes: 0 = 'geen agressie', 1 = 'agressie'.

    For a 2x2 matrix every cell is a string "<count> — <explanation>".
    Any other shape is returned as plain counts with generic
    "True i" / "Pred j" labels sized to the matrix. The previous fallback
    always supplied two labels per axis and therefore raised for every
    non-2x2 input.
    """
    cm = np.asarray(cm)
    if cm.shape != (2, 2):
        grid = np.atleast_2d(cm)
        n_rows, n_cols = grid.shape[:2]
        return pd.DataFrame(
            grid,
            index=[f"True {i}" for i in range(n_rows)],
            columns=[f"Pred {j}" for j in range(n_cols)],
        )

    tn, fp, fn, tp = cm.ravel()
    data = [
        [f"{tn} — TN (True Negatives: echte negatieven)",
         f"{fp} — FP (False Positives: fout-positieven)"],
        [f"{fn} — FN (False Negatives: fout-negatieven)",
         f"{tp} — TP (True Positives: echte positieven)"]
    ]
    idx = ["True 0 (geen agressie)", "True 1 (agressie)"]
    cols = ["Pred 0 (geen agressie)", "Pred 1 (agressie)"]
    return pd.DataFrame(data, index=idx, columns=cols)

def _build_report_markdown(rep: dict, thr: float) -> str:
    acc = rep.get("accuracy", 0)
    macro = rep.get("macro avg", {})
    weighted = rep.get("weighted avg", {})
    s0 = int(rep.get("0", {}).get("support", 0))
    s1 = int(rep.get("1", {}).get("support", 0))
    md = f"""
### ℹ️ Uitleg bij het classification report (drempel = {thr:.2f})
Klasselabels  
0 = geen agressie, 1 = agressie.  
De drempel bepaalt wanneer de kans wordt omgezet naar label 1 (≥ drempel) of 0 (< drempel).
Velden in het rapport  
Precision: van alle voorspelde positieven (label 1), welk deel was echt positief?  
Recall (sensitiviteit): van alle werkelijk positieven (label 1), welk deel hebben we gevonden?  
F1-score: harmonisch gemiddelde van precision en recall.  
Support: aantal voorbeelden per klasse.  
Accuracy: (TP + TN) / totaal — gevoelig voor class imbalance.  
Macro avg: ongewogen gemiddelde over klassen.  
Weighted avg: gewogen gemiddelde (weging = support).
Huidige set (support/accuracy)  
Support klasse 0: {s0}, klasse 1: {s1}.  
Accuracy (totaal): {acc:.3f}.  
Macro avg F1: {macro.get('f1-score', 0):.3f}, Weighted avg F1: {weighted.get('f1-score', 0):.3f}.
Drempel-tips  
Drempel omhoog → vaak hogere precision maar lagere recall.  
Drempel omlaag → vaak hogere recall maar lagere precision.
"""
    return md

# Visuele confusion-matrix (heatmap)
def make_confusion_heatmap(y_true, y_score, thr=0.5):
    """Plot the 2x2 confusion matrix at threshold ``thr`` as an annotated heatmap.

    Scores ``>= thr`` count as predicted class 1. Each cell carries a
    TN/FP/FN/TP annotation with its count.
    """
    predicted = (y_score >= thr).astype(int)
    counts = confusion_matrix(y_true, predicted, labels=[0, 1]).astype(int)
    pred_axis = ["Pred 0", "Pred 1"]
    true_axis = ["True 0", "True 1"]

    heatmap = go.Heatmap(
        z=counts, x=pred_axis, y=true_axis,
        colorscale="Blues", showscale=True
    )
    fig = go.Figure(data=heatmap)

    # Cell names laid out like the matrix itself: rows = truth, columns = prediction.
    cell_names = (("TN", "FP"), ("FN", "TP"))
    for row, true_label in enumerate(true_axis):
        for col, pred_label in enumerate(pred_axis):
            fig.add_annotation(
                x=pred_label, y=true_label,
                text=f"{cell_names[row][col]}: {counts[row, col]}",
                showarrow=False,
            )

    fig.update_layout(
        title=f"Confusion matrix (drempel = {thr:.2f})",
        xaxis_title="Voorspelling",
        yaxis_title="Werkelijkheid",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10)
    )
    return fig

# -------- Eval-plots --------
def make_roc_fig(y_true, y_score, auroc=None):
    """Return an area plot of the ROC curve with a dashed diagonal reference.

    When ``auroc`` is given it is shown in the title with three decimals.
    """
    false_pos_rate, true_pos_rate = roc_curve(y_true, y_score)[:2]

    if auroc is None:
        title = "ROC-curve"
    else:
        title = f"ROC-curve (AUROC={auroc:.3f})"

    axis_labels = {"x": "False Positive Rate", "y": "True Positive Rate"}
    fig = px.area(x=false_pos_rate, y=true_pos_rate, title=title, labels=axis_labels)
    # Diagonal from (0, 0) to (1, 1) as visual reference.
    fig.add_shape(type="line", x0=0, x1=1, y0=0, y1=1, line=dict(dash="dash"))
    fig.update_layout(template="simple_white", margin=dict(l=10, r=10, t=40, b=10))
    return fig

def make_pr_fig(y_true, y_score, auprc=None):
    """Return an area plot of the precision–recall curve.

    When ``auprc`` is given it is shown in the title with three decimals.
    """
    precision, recall = precision_recall_curve(y_true, y_score)[:2]

    if auprc is None:
        title = "Precision–Recall"
    else:
        title = f"Precision–Recall (AUPRC={auprc:.3f})"

    axis_labels = {"x": "Recall", "y": "Precision"}
    fig = px.area(x=recall, y=precision, title=title, labels=axis_labels)
    fig.update_layout(template="simple_white", margin=dict(l=10, r=10, t=40, b=10))
    return fig

def make_prob_hist(y_true, y_score):
    """Return overlaid histograms of the predicted probabilities, one per true class."""
    class_names = np.where(
        y_true == 1,
        "Werkelijk: agressie (1)",
        "Werkelijk: geen agressie (0)",
    )
    frame = pd.DataFrame({"kans": y_score, "label": class_names})

    fig = px.histogram(
        frame, x="kans", color="label", barmode="overlay", nbins=40,
        title="Verdeling voorspelde kansen per werkelijke klasse",
        labels={"kans": "Voorspelde kans"},
    )
    # Semi-transparent bars so both overlaid classes stay visible.
    fig.update_traces(opacity=0.6)
    fig.update_layout(template="simple_white", margin=dict(l=10, r=10, t=40, b=10))
    return fig

def make_threshold_metrics_fig(y_true, y_score, thr_line=0.5):
    """Plot precision, recall and F1 against the decision threshold.

    The metrics are evaluated at 101 evenly spaced thresholds in [0, 1];
    ``thr_line`` is marked with a dashed vertical line.
    """
    def _scores_at(cutoff):
        labels = (y_score >= cutoff).astype(int)
        return {
            "threshold": cutoff,
            "precision": precision_score(y_true, labels, zero_division=0),
            "recall": recall_score(y_true, labels, zero_division=0),
            "f1": f1_score(y_true, labels, zero_division=0),
        }

    table = pd.DataFrame([_scores_at(t) for t in np.linspace(0.0, 1.0, 101)])
    long_form = table.melt(
        id_vars="threshold",
        value_vars=["precision", "recall", "f1"],
        var_name="metric",
        value_name="score",
    )

    fig = px.line(
        long_form, x="threshold", y="score", color="metric",
        title="Metrics vs. drempel (precision/recall/F1)",
        labels={"threshold": "Drempel", "score": "Score"},
    )
    fig.add_vline(
        x=float(thr_line),
        line_dash="dash",
        annotation_text=f"drempel={thr_line:.2f}",
        annotation_position="top",
    )
    fig.update_layout(
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(range=[0, 1]),
    )
    return fig

# -------- Extra evaluaties: Kalibratie / Gains / Lift / KS --------
def make_calibration_fig(y_true, y_score, n_bins=10):
    """Return a reliability diagram: observed positive fraction versus mean
    predicted probability per quantile bin, next to the perfect-calibration
    diagonal."""
    observed, predicted = calibration_curve(
        y_true, y_score, n_bins=n_bins, strategy="quantile"
    )

    diagonal = go.Scatter(
        x=[0, 1], y=[0, 1], mode="lines",
        name="Perfect gekalibreerd", line=dict(dash="dash"),
    )
    model_curve = go.Scatter(
        x=predicted, y=observed, mode="lines+markers", name="Model"
    )

    fig = go.Figure()
    for trace in (diagonal, model_curve):
        fig.add_trace(trace)
    fig.update_layout(
        title="Kalibratie (Reliability Diagram)",
        xaxis_title="Gemiddelde voorspelde kans",
        yaxis_title="Werkelijk aandeel positieven",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10)
    )
    return fig

def _gains_data(y_true, y_score):
    """Compute the cumulative-gains curve.

    Rows are ranked by score, highest first. Returns
    ``(pct_samples, cum_gain)``: the fraction of the population seen after
    each row (NumPy array) and the fraction of all positives captured up to
    that row (pandas Series). With no positives the gain stays at 0.
    """
    ranked = pd.DataFrame({"y": y_true, "p": y_score})
    ranked = ranked.sort_values("p", ascending=False).reset_index(drop=True)

    n_rows = len(ranked)
    n_positive = ranked["y"].sum()
    # Divide by 1 instead of 0 when there are no positives at all.
    denominator = n_positive if n_positive > 0 else 1

    ranked["cum_pos"] = ranked["y"].cumsum()
    cum_gain = ranked["cum_pos"] / denominator
    pct_samples = np.arange(1, n_rows + 1) / n_rows
    return pct_samples, cum_gain

def make_gains_fig(y_true, y_score):
    """Return the cumulative-gains chart with the random-selection diagonal as baseline."""
    population, captured = _gains_data(y_true, y_score)

    baseline = go.Scatter(
        x=population, y=population, mode="lines",
        name="Baseline (random)", line=dict(dash="dash"),
    )
    gains_curve = go.Scatter(
        x=population, y=captured, mode="lines", name="Cumulative Gains"
    )

    fig = go.Figure()
    for trace in (baseline, gains_curve):
        fig.add_trace(trace)
    fig.update_layout(
        title="Cumulative Gains",
        xaxis_title="Percentage van populatie (gesorteerd op kans)",
        yaxis_title="Percentage van positieven gedekt",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10),
        xaxis=dict(range=[0, 1]),
        yaxis=dict(range=[0, 1]),
    )
    return fig

def make_lift_fig(y_true, y_score):
    """Return the lift curve: cumulative gain divided by the population
    fraction, with a flat baseline at lift = 1."""
    population, captured = _gains_data(y_true, y_score)
    # Lower-bound the divisor so it can never be zero.
    lift = captured / np.clip(population, 1e-9, None)

    baseline = go.Scatter(
        x=population, y=np.ones_like(population), mode="lines",
        name="Baseline (lift=1)", line=dict(dash="dash"),
    )
    lift_curve = go.Scatter(x=population, y=lift, mode="lines", name="Lift")

    fig = go.Figure()
    for trace in (baseline, lift_curve):
        fig.add_trace(trace)
    fig.update_layout(
        title="Lift-curve",
        xaxis_title="Percentage van populatie (gesorteerd op kans)",
        yaxis_title="Lift",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10)
    )
    return fig

def make_ks_fig(y_true, y_score):
    """Return the KS chart: cumulative positive and negative rates over the
    population ranked by score (highest first), with a dashed line where
    their absolute gap is largest."""
    ranked = pd.DataFrame({"y": y_true, "p": y_score})
    ranked = ranked.sort_values("p", ascending=False).reset_index(drop=True)

    n_rows = len(ranked)
    n_pos = ranked["y"].sum()
    n_neg = n_rows - n_pos
    # Fall back to a divisor of 1 when a class is absent.
    ranked["tp_cum"] = ranked["y"].cumsum() / (n_pos if n_pos > 0 else 1)
    ranked["fp_cum"] = ((1 - ranked["y"]).cumsum()) / (n_neg if n_neg > 0 else 1)

    gap = (ranked["tp_cum"] - ranked["fp_cum"]).abs()
    if n_rows:
        peak = int(gap.values.argmax())
        ks_value = float(gap.iloc[peak])
        x = np.arange(1, n_rows + 1) / n_rows
    else:
        peak, ks_value = 0, 0.0
        x = np.array([0])

    fig = go.Figure()
    for column, trace_name in (("tp_cum", "TPR cumulatief"), ("fp_cum", "FPR cumulatief")):
        fig.add_trace(go.Scatter(x=x, y=ranked[column], mode="lines", name=trace_name))
    if len(x):
        fig.add_vline(x=float(x[peak]), line_dash="dash",
                      annotation_text=f"KS={ks_value:.3f}", annotation_position="top")
    fig.update_layout(
        title="KS-curve",
        xaxis_title="Percentage van populatie (gesorteerd op kans)",
        yaxis_title="Cumulatieve ratio",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10),
        xaxis=dict(range=[0, 1]),
        yaxis=dict(range=[0, 1]),
    )
    return fig

def make_dataset_profile(df):
    """Summarize the dataset in a two-column table (``kenmerk``, ``waarde``).

    Reports the row count, positive/negative counts, positive ratio and
    character-length statistics (mean, median, 10th and 90th percentile) of
    the ``rapportage`` texts.
    """
    lengths = df["rapportage"].astype(str).str.len()
    pos = df["agressie_volgende30d"].astype(int)

    rows = [
        ("Aantal rijen", int(len(df))),
        ("Aantal positieven (1)", int(pos.sum())),
        ("Aantal negatieven (0)", int((1 - pos).sum())),
        ("Positiefratio", f"{(pos.mean()*100):.1f}%"),
        ("Tekstlengte — gemiddeld", f"{lengths.mean():.1f}"),
        ("Tekstlengte — mediaan", int(lengths.median())),
        ("Tekstlengte — p10", int(np.percentile(lengths, 10))),
        ("Tekstlengte — p90", int(np.percentile(lengths, 90))),
    ]
    return pd.DataFrame({
        "kenmerk": [name for name, _ in rows],
        "waarde": [value for _, value in rows],
    })

# ============ Model & Viz ============
def _make_text_pipeline(featurizer, max_features, ngram_max, bert_maxlen, bert_batch):
    """Build the unfitted ``txt`` -> ``clf`` pipeline for the chosen featurizer."""
    bert_checkpoints = {
        "ClinicalBERT": "emilyalsentzer/Bio_ClinicalBERT",
        "DutchBERT": "wietsedv/bert-base-dutch-cased",
    }
    if featurizer == "TF-IDF":
        txt = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
    elif isinstance(featurizer, str) and featurizer in bert_checkpoints:
        txt = HFTextEmbedder(model_name=bert_checkpoints[featurizer],
                             max_length=bert_maxlen, batch_size=bert_batch)
    else:
        raise ValueError("Onbekende featurizer. Kies 'TF-IDF', 'ClinicalBERT' of 'DutchBERT'.")
    return Pipeline([("txt", txt), ("clf", LogisticRegression(max_iter=3000))])


def _unit_scale(column):
    """Min-max scale a 1-D array to [0, 1); the epsilon guards a zero range."""
    return (column - np.min(column)) / (np.ptp(column) + 1e-9)


def _tsne_unit_coords(features, n_components, random_state):
    """Run t-SNE on ``features`` and return one unit-scaled array per output axis."""
    n_samples, n_features = features.shape
    # t-SNE requires perplexity < n_samples; 30 is kept whenever the data allows it.
    perplexity = min(30, max(1, n_samples - 1))
    # PCA initialisation needs at least n_components input dimensions and samples.
    init = "pca" if min(n_samples, n_features) >= n_components else "random"
    tsne = TSNE(n_components=n_components, random_state=random_state,
                perplexity=perplexity, learning_rate="auto", init=init)
    coords = tsne.fit_transform(features)
    return [_unit_scale(coords[:, axis]) for axis in range(n_components)]


def build_and_train(
    df,
    test_size=0.2,
    random_state=42,
    featurizer="TF-IDF",
    max_features=4000,
    ngram_max=2,
    bert_maxlen=128,
    bert_batch=16
):
    """Train a text classifier and prepare the data for the projection plots.

    Steps: stratified train/test split, fit ``featurizer`` -> logistic
    regression on the training part, score the test part, then project the
    features of *all* rows with SVD followed by t-SNE (2D and 3D).

    Parameters:
        df: frame with ``rapportage`` (text) and ``agressie_volgende30d`` (label).
        test_size, random_state: forwarded to ``train_test_split``;
            ``random_state`` also seeds SVD and t-SNE.
        featurizer: ``"TF-IDF"``, ``"ClinicalBERT"`` or ``"DutchBERT"``.
        max_features, ngram_max: TF-IDF settings.
        bert_maxlen, bert_batch: settings for the BERT embedders.

    Returns:
        ``(pipe, (X_test, y_test, y_score), plot_df, auroc, auprc)``.
        ``plot_df`` holds unit-scaled coordinates (``x``, ``y``, ``x3``,
        ``y3``, ``z3``), the label, the predicted probability, a truncated
        text, any available context columns and the train/test ``split``.

    Raises:
        ValueError: for an unknown ``featurizer``.
    """
    X = df["rapportage"].astype(str).values
    y = df["agressie_volgende30d"].values
    X_train, X_test, y_train, y_test, _idx_train, idx_test = train_test_split(
        X, y, np.arange(len(X)),
        test_size=test_size, random_state=random_state, stratify=y
    )

    pipe = _make_text_pipeline(featurizer, max_features, ngram_max, bert_maxlen, bert_batch)
    pipe.fit(X_train, y_train)
    y_score = pipe.predict_proba(X_test)[:, 1]
    txt_all = pipe.named_steps["txt"].transform(X)  # sparse for TF-IDF, dense for BERT

    auroc = float(roc_auc_score(y_test, y_score))
    auprc = float(average_precision_score(y_test, y_score))

    # 2D/3D embedding: SVD (up to 50 components) -> t-SNE (2D and 3D).
    # The component count is capped below the feature count so a small
    # vocabulary does not make TruncatedSVD fail.
    n_svd = max(1, min(50, txt_all.shape[1] - 1))
    svd = TruncatedSVD(n_components=n_svd, random_state=random_state)
    X50 = svd.fit_transform(txt_all)

    x2, y2 = _tsne_unit_coords(X50, 2, random_state)
    x3, y3, z3 = _tsne_unit_coords(X50, 3, random_state)

    proba_all = pipe.predict_proba(X)[:, 1]
    plot_df = pd.DataFrame({
        "x": x2, "y": y2,
        "x3": x3, "y3": y3, "z3": z3,
        "label": df["agressie_volgende30d"].values,
        "kans": proba_all,
        "rapportage": df["rapportage"].str.slice(0, 180) + "..."
    })
    # Carry optional context columns along when the dataset provides them.
    for col in ["PHQ9_baseline","GAD7_baseline","stress_niveau_1_5","slaap_uren","sociale_steun_0_10","zorgsetting"]:
        if col in df.columns:
            plot_df[col] = df[col]

    # idx_test holds positional row numbers, hence the positional boolean mask.
    test_mask = np.zeros(len(plot_df), dtype=bool)
    test_mask[idx_test] = True
    plot_df["split"] = np.where(test_mask, "test", "train")

    return pipe, (X_test, y_test, y_score), plot_df, auroc, auprc

def make_scatter(plot_df, color_mode="label", dim="2D"):
    """
    Build a t-SNE scatter plot of the reports.

    Parameters
    - plot_df: frame with the projection columns ('x', 'y' for 2D; 'x3', 'y3',
      'z3' for 3D) plus 'label', 'kans', 'rapportage' and 'split'.
    - color_mode: 'label' colors by the true label; any other value colors by
      the predicted probability ('kans').
    - dim: '2D' gives a flat scatter; any other value gives a 3D scatter.

    Returns the styled Plotly figure.
    """
    flat = dim == "2D"
    by_label = color_mode == "label"

    # Title lookup keyed on (colored by label?, 2D?).
    titles = {
        (True, True): "2D projectie (t-SNE) — kleur = werkelijk label",
        (True, False): "3D projectie (t-SNE) — kleur = werkelijk label",
        (False, True): "2D projectie (t-SNE) — kleur = voorspelde kans",
        (False, False): "3D projectie (t-SNE) — kleur = voorspelde kans",
    }

    shared = {
        "hover_data": ["rapportage", "kans", "split"],
        "title": titles[(by_label, flat)],
    }
    if by_label:
        # Discrete legend: translate 0/1 into readable class names.
        shared["color"] = plot_df["label"].map({0: "geen agressie", 1: "agressie"})
        # Only the 2D label plot is drawn slightly more transparent.
        shared["opacity"] = 0.85 if flat else 0.9
    else:
        shared["color"] = "kans"
        shared["color_continuous_scale"] = "Turbo"
        shared["opacity"] = 0.9

    frame_margin = dict(l=10, r=10, t=40, b=10)

    # Styling + axis titles
    if flat:
        fig = px.scatter(plot_df, x="x", y="y", **shared)
        fig.update_traces(marker=dict(size=8, line=dict(width=0)))
        fig.update_layout(
            margin=frame_margin,
            template="simple_white",
            xaxis_title="x (t-SNE)",
            yaxis_title="y (t-SNE)"
        )
        return fig

    fig = px.scatter_3d(plot_df, x="x3", y="y3", z="z3", **shared)
    fig.update_traces(marker=dict(size=4))
    fig.update_layout(
        margin=frame_margin,
        template="simple_white",
        scene=dict(
            xaxis_title="x (t-SNE)",
            yaxis_title="y (t-SNE)",
            zaxis_title="z (t-SNE)"
        )
    )
    return fig

# --- (No longer used) Decision-landscape overlay ---
def make_prob_with_decision_landscape(plot_df, grid_n=150):
    """
    Draw the 2D projection on top of a probability "landscape".

    Background: a logistic regression fitted on (x, y) -> label is evaluated on
    a grid_n x grid_n grid over the unit square, giving P(class=1) per cell.
    Foreground: the points themselves, colored by the model probability in
    plot_df['kans'].

    Kept for reference; the UI no longer calls it.
    """
    coords = plot_df[["x", "y"]].values
    labels = plot_df["label"].values.astype(int)

    surface_model = LogisticRegression(max_iter=2000)
    surface_model.fit(coords, labels)

    # Regular grid over [0, 1] x [0, 1].
    gx = np.linspace(0.0, 1.0, grid_n)
    gy = np.linspace(0.0, 1.0, grid_n)
    mesh_x, mesh_y = np.meshgrid(gx, gy)
    cells = np.column_stack((mesh_x.ravel(), mesh_y.ravel()))
    surface = surface_model.predict_proba(cells)[:, 1].reshape(mesh_x.shape)

    fig = go.Figure(data=[
        go.Heatmap(
            x=gx, y=gy, z=surface,
            zmin=0, zmax=1,
            colorscale="Turbo",
            showscale=True,
            colorbar=dict(title="kans (landschap)")
        )
    ])
    fig.update_layout(
        title="2D projectie (t-SNE) — kleur = voorspelde kans (met beslissingslandschap)",
        template="simple_white",
        margin=dict(l=10, r=10, t=40, b=10),
        xaxis_title="x (t-SNE)", yaxis_title="y (t-SNE)"
    )

    # Per-point hover text: probability and train/test membership.
    hover_text = (
        "kans=" + plot_df["kans"].round(3).astype(str) +
        " | split=" + plot_df["split"].astype(str)
    )
    point_style = dict(
        size=8,
        opacity=0.85,
        color=plot_df["kans"],
        colorscale="Turbo",
        showscale=False,
        line=dict(width=0)
    )
    fig.add_trace(go.Scatter(
        x=plot_df["x"], y=plot_df["y"],
        mode="markers",
        marker=point_style,
        text=hover_text,
        hovertemplate="x=%{x:.3f}, y=%{y:.3f}<br>%{text}<extra></extra>",
        name="punten"
    ))

    # Pin both axes to the unit square covered by the grid.
    fig.update_xaxes(range=[0, 1])
    fig.update_yaxes(range=[0, 1])
    return fig

def metrics_table(y_true, y_score, thr):
    """
    Build the classification report at a given threshold, formatted for display.

    Parameters
    - y_true: true binary labels.
    - y_score: predicted scores (array, Series or plain list); a sample is
      predicted positive when its score is >= thr.
    - thr: decision threshold.

    Returns (rep_df_disp, cm_df, rep_md):
    - rep_df_disp: report as a DataFrame of strings
        * precision/recall/f1-score: percentages with one decimal (e.g. 87.5%)
        * support: integer counts
        * 'accuracy_%': extra column holding the accuracy percentage on the
          'accuracy' row only
    - cm_df: confusion matrix formatted by `_format_confusion_df`
    - rep_md: markdown produced by `_build_report_markdown`
    """
    # np.asarray so that a plain list of scores can be compared element-wise.
    y_pred = (np.asarray(y_score) >= thr).astype(int)
    rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    rep_df_disp = pd.DataFrame(rep).T

    metric_cols = ("precision", "recall", "f1-score")
    for col in metric_cols:
        if col in rep_df_disp:
            rep_df_disp[col] = (rep_df_disp[col] * 100).round(1).map(
                lambda v: f"{v:.1f}%" if pd.notnull(v) else ""
            )

    if "support" in rep_df_disp:
        rep_df_disp["support"] = rep_df_disp["support"].map(
            lambda v: f"{int(v)}" if pd.notnull(v) else ""
        )

    if "accuracy" in rep:
        acc_pct = f"{rep['accuracy'] * 100:.1f}%"
        rep_df_disp["accuracy_%"] = ""
        if "accuracy" in rep_df_disp.index:
            rep_df_disp.loc["accuracy", "accuracy_%"] = acc_pct
            # In the dict report 'accuracy' is a single scalar, which the
            # DataFrame constructor broadcasts into every column of that row.
            # Left as-is, the row would show the accuracy under
            # precision/recall/f1 and a truncated support of "0".
            for col in metric_cols:
                if col in rep_df_disp:
                    rep_df_disp.loc["accuracy", col] = ""
            if "support" in rep_df_disp:
                rep_df_disp.loc["accuracy", "support"] = f"{len(y_true)}"

    rep_df_disp = rep_df_disp.fillna("")

    cm = confusion_matrix(y_true, y_pred)
    cm_df = _format_confusion_df(cm)
    rep_md = _build_report_markdown(rep, thr)

    return rep_df_disp, cm_df, rep_md

# ============ State & Train ============
# Module-level state shared between the training routine and the UI callbacks.
GLOBAL = dict(
    pipe=None,              # fitted pipeline
    plot_df=None,           # projection frame used by the scatter plots
    eval=None,              # (X_test, y_test, y_score) of the last training run
    auroc=None,
    auprc=None,
    featurizer="TF-IDF",
    df=None,                # keep the dataset for the data preview
)

def do_train(file_obj=None, test_size=0.2, seed=42,
             featurizer="TF-IDF", max_features=4000, ngram_max=2,
             bert_maxlen=128, bert_batch=16):
    """
    Train a model end to end and return everything the UI displays.

    Steps: load the dataset via `load_dataset(file_obj)`, fit and evaluate via
    `build_and_train`, log the run to MLflow, store the results in `GLOBAL`,
    then build all tables and figures at the default threshold of 0.5.

    Returns a 19-tuple in the order of the Gradio output lists:
    status message, AUROC, AUPRC, data preview, label scatter, probability
    scatter, report table, confusion table, confusion heatmap, report
    markdown, ROC, PR, probability histogram, threshold metrics, calibration,
    gains, lift, KS and the dataset profile.
    """
    df = load_dataset(file_obj)
    trained = build_and_train(
        df, test_size, seed, featurizer, max_features, ngram_max, bert_maxlen, bert_batch
    )
    pipe, eval_pack, plot_df, auroc, auprc = trained

    # Only the hyper-parameters of the active featurizer are logged.
    if featurizer == "TF-IDF":
        featurizer_params = (
            ("tfidf_max_features", max_features),
            ("tfidf_ngram_max", ngram_max),
        )
    else:
        featurizer_params = (
            ("bert_maxlen", bert_maxlen),
            ("bert_batch", bert_batch),
        )
    run_params = (("featurizer", featurizer), ("test_size", test_size)) + featurizer_params

    # MLflow logging
    with mlflow.start_run(run_name=f"{featurizer}"):
        for key, value in run_params:
            mlflow.log_param(key, value)
        for key, value in (("auroc", auroc), ("auprc", auprc)):
            mlflow.log_metric(key, value)
        mlflow.sklearn.log_model(pipe, artifact_path="model")

    GLOBAL.update({
        "pipe": pipe,
        "plot_df": plot_df,
        "eval": eval_pack,
        "auroc": auroc,
        "auprc": auprc,
        "featurizer": featurizer,
        "df": df,
    })

    y_true, y_score = eval_pack[1], eval_pack[2]

    # Table + explanation
    rep_df, cm_df, rep_md = metrics_table(y_true, y_score, thr=0.5)

    # Basic plots
    roc_fig = make_roc_fig(y_true, y_score, auroc)
    pr_fig = make_pr_fig(y_true, y_score, auprc)
    hist_fig = make_prob_hist(y_true, y_score)
    thr_fig = make_threshold_metrics_fig(y_true, y_score, thr_line=0.5)

    # Default visualisations: 2D
    fig_label, fig_prob = (
        make_scatter(plot_df, color_mode=mode, dim="2D") for mode in ("label", "kans")
    )

    # Extra evaluations
    cal_fig = make_calibration_fig(y_true, y_score, n_bins=10)
    gains_fig = make_gains_fig(y_true, y_score)
    lift_fig = make_lift_fig(y_true, y_score)
    ks_fig = make_ks_fig(y_true, y_score)
    profile_df = make_dataset_profile(df)

    # Confusion heatmap at the default threshold
    cm_plot = make_confusion_heatmap(y_true, y_score, thr=0.5)

    # Data preview (default: first 10 rows)
    preview_df = df.head(10)

    status_msg = f"✅ Model getraind met {featurizer}. AUROC: {auroc:.3f} | AUPRC: {auprc:.3f}"
    headline = (status_msg, auroc, auprc, preview_df)
    projections = (fig_label, fig_prob)
    tables = (rep_df, cm_df, cm_plot, rep_md)
    curves = (roc_fig, pr_fig, hist_fig, thr_fig)
    extras = (cal_fig, gains_fig, lift_fig, ks_fig, profile_df)
    return headline + projections + tables + curves + extras

def predict_one(text):
    """
    Score one free-text report with the currently trained pipeline.

    Returns (message, probability). The probability is None when no model has
    been trained yet or when the text is empty or whitespace only. The
    predicted class uses a fixed threshold of 0.5.
    """
    if GLOBAL["pipe"] is None:
        return "Nog geen model getraind.", None
    if not text or not text.strip():
        return "Voer een rapportage in.", None

    # Column 1 of predict_proba holds P(class=1); one input row -> first entry.
    positive_col = GLOBAL["pipe"].predict_proba([text])[:, 1]
    proba = float(positive_col[0])
    label = int(proba >= 0.5)

    md = (
        f"Kans op agressie (30d): {proba:.3f} — "
        f"voorspelde klasse: {label} (drempel 0.50)\n"
        f"Featurizer: {GLOBAL.get('featurizer','?')}"
    )
    return md, proba

# ============ UI ============
# Root of the Gradio UI; every component and event binding below lives in this context.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as demo:
    # Full-width heading (h1)
    gr.Markdown(f"# {SLOGAN}")

    # --- Eye-catching styling for the buttons + scrollable data preview ---
    # The CSS selectors match the elem_id values assigned to components further
    # down (train-btn, retrain-btn, predict-btn, data-preview, viz-img).
    gr.HTML("""
    <style>
      /* Zelfde gradient-stijl voor alle 3 knoppen */
      #train-btn, #retrain-btn, #predict-btn {
        background: linear-gradient(90deg, #ef4444 0%, #f97316 100%);
        color: white !important;
        font-weight: 700;
        border: none !important;
      }
      #train-btn:hover, #retrain-btn:hover, #predict-btn:hover {
        filter: brightness(0.95);
      }
      /* Scrollbare DataFrame container */
      #data-preview {
        max-height: 320px;
        overflow: auto;
      }
      #data-preview table {
        width: 100%;
      }
      /* Afbeelding direct onder projectie zonder top-ruimte */
      #viz-img { margin-top: 0 !important; padding-top: 0 !important; }
      #viz-img img { display: block; margin-top: 0 !important; }
    </style>
    """)

    # Introduction and overview side by side (two equal-width columns)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(INTRO)
        with gr.Column(scale=1):
            gr.Markdown(WHAT_YOU_SEE)
    # ---- Manual training (without CSV upload) ----
    gr.Markdown("## 🛠️ Handmatig trainen (zonder CSV upload)")
    with gr.Row():
        featur_quick = gr.Radio(
            choices=["TF-IDF", "ClinicalBERT", "DutchBERT"],
            value="TF-IDF",
            label="Kies featurizer"
        )
    # TF-IDF parameters: visible by default because TF-IDF is the initial choice.
    with gr.Row(visible=True) as tfidf_quick_row:
        max_features_q = gr.Slider(1000, 12000, value=4000, step=1000, label="TF-IDF max_features")
        ngram_max_q    = gr.Radio(choices=[1, 2], value=2, label="n-gram max")
    # BERT parameters: hidden until a BERT featurizer is selected (see _toggle_quick).
    with gr.Row(visible=False) as bert_quick_row:
        bert_maxlen_q = gr.Slider(64, 256, value=128, step=8, label="BERT max_length")
        bert_batch_q  = gr.Slider(4, 64, value=16, step=4, label="BERT batch_size")
    train_quick_btn = gr.Button("Train algoritme", variant="primary", elem_id="train-btn")

    # Status line and headline metrics; these are outputs of every training handler.
    status = gr.Markdown()
    with gr.Row():
        auroc_box = gr.Number(label="AUROC", precision=3)
        auprc_box = gr.Number(label="AUPRC", precision=3)

    # Visualisation + evaluation tables
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("### 🔍 Visualisatie")
            # Shared dimension toggle: one radio switches both projections (see _update_projection)
            proj_dim = gr.Radio(choices=["2D", "3D"], value="2D", label="Projectiedimensie (geldt voor beide projecties)")
            with gr.Column():
                fig_out_label = gr.Plot(label="Projectie — kleur = werkelijk label")
                fig_out_prob  = gr.Plot(label="Projectie — kleur = voorspelde kans")
                # Static explanatory image shown directly under the projections
                viz_img = gr.Image(value=INFO_IMAGE, show_label=False, interactive=False, elem_id="viz-img")
                gr.Markdown(ML_STORY)
        with gr.Column(scale=2):
            gr.Markdown("### 📄 Datavoorbeeld")
            # _refresh_preview checks whether the selected choice starts with "Eerste"
            data_preview_mode = gr.Radio(
                choices=["Eerste 10 rijen", "Gehele dataset (scrollbaar)"],
                value="Eerste 10 rijen",
                label="Weergave"
            )
            data_preview = gr.Dataframe(label="Dataset", interactive=False, elem_id="data-preview")

            gr.Markdown("### ⚙️ Evaluatie (tabellen & drempel)")
            # Releasing this slider recomputes the report, confusion matrix and
            # threshold plot (see _update_eval).
            thr = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Drempel (threshold)")
            rep_df = gr.Dataframe(label="Classification report")
            cm_df = gr.Dataframe(label="Confusion matrix (met uitleg)")
            cm_plot = gr.Plot(label="Confusion matrix (heatmap)")
            rep_md = gr.Markdown(label="Uitleg classification report")

    # === Two columns — plots (in tabs) on the left, free-text prediction on the right ===
    with gr.Row():
        with gr.Column(scale=3):
            # One tab per evaluation plot; all are filled by the training handlers.
            with gr.Tabs():
                with gr.TabItem("Metrics vs. drempel"):
                    thr_plot = gr.Plot(label="Precision/Recall/F1 over drempel")
                with gr.TabItem("Kansverdeling"):
                    hist_plot = gr.Plot(label="Verdeling voorspelde kansen")
                with gr.TabItem("ROC"):
                    roc_plot = gr.Plot(label="ROC-curve")
                with gr.TabItem("Precision–Recall"):
                    pr_plot = gr.Plot(label="PR-curve")
                # ---- New: extra tabs ----
                with gr.TabItem("Kalibratie"):
                    cal_plot = gr.Plot(label="Kalibratie (Reliability Diagram)")
                with gr.TabItem("Cumulative Gains"):
                    gains_plot = gr.Plot(label="Cumulative Gains")
                with gr.TabItem("Lift"):
                    lift_plot = gr.Plot(label="Lift-curve")
                with gr.TabItem("KS-curve"):
                    ks_plot = gr.Plot(label="KS-curve")
                with gr.TabItem("Dataset-profiel"):
                    profile_df_out = gr.Dataframe(label="Dataset-profiel", interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("### 🗣️ Predict (vrije tekst)")
            with gr.Row():
                txt = gr.Textbox(
                    lines=12, label="Rapportage (NL)",
                    placeholder="Bijv.: Patiënt oogt geagiteerd, slaapt slecht, weigert medicatie..."
                )
            # Clicking the button calls predict_one, which fills md_out and proba_out.
            btn = gr.Button("Voorspel", elem_id="predict-btn")
            md_out = gr.Markdown()
            proba_out = gr.Number(label="Kans", precision=3)

    # ===== Retrain with your own CSV — ALWAYS VISIBLE =====
    gr.Markdown("## 🔁 Hertrain met eigen CSV")
    gr.Markdown(
        "Upload een CSV met kolommen `rapportage` (tekst) en `agressie_volgende30d` (0/1). "
        "Kies je parameters en klik Train opnieuw (met upload)."
    )
    csv_in = gr.File(label="Upload CSV (kolommen: rapportage, agressie_volgende30d)")
    # Split settings for the retrain run
    with gr.Row():
        test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test set grootte")
        seed = gr.Slider(1, 999, value=42, step=1, label="Random seed")
    with gr.Row():
        featur = gr.Radio(choices=["TF-IDF", "ClinicalBERT", "DutchBERT"], value="TF-IDF", label="Tekst-featurizer")
    # As in the quick-train section: the TF-IDF row is shown first, the BERT row
    # is hidden until a BERT featurizer is chosen (see _toggle_rows).
    with gr.Row(visible=True) as tfidf_row:
        max_features = gr.Slider(1000, 12000, value=4000, step=1000, label="TF-IDF max_features")
        ngram_max = gr.Radio(choices=[1, 2], value=2, label="n-gram max")
    with gr.Row(visible=False) as bert_row:
        bert_maxlen = gr.Slider(64, 256, value=128, step=8, label="BERT max_length")
        bert_batch = gr.Slider(4, 64, value=16, step=4, label="BERT batch_size")
    retrain_btn = gr.Button("Train opnieuw (met upload)", elem_id="retrain-btn")

    # << MOVED: explanation of the evaluation plots — lower in the same column >>
    with gr.Row():
        with gr.Column(scale=2, min_width=0):
            gr.Markdown(
                "### ℹ️ Over de evaluatieplots\n\n"
                "De onderstaande grafieken laten zien hoe het model presteert bij verschillende drempels en uitkomsten:\n\n"
                "- Metrics vs. drempel — toont hoe precision, recall en F1-score veranderen als je de drempel aanpast.\n"
                "- Kansverdeling — laat zien hoe voorspelde kansen verdeeld zijn over de echte klassen (0/1).\n"
                "- ROC-curve — vergelijkt True Positive Rate met False Positive Rate (AUROC = scheidingskracht).\n"
                "- Precision–Recall-curve — nuttig bij ongebalanceerde data; focust op de positieve klasse.\n\n"
                "Gebruik ze samen om te bepalen waar je drempel moet liggen en hoe betrouwbaar het model is."
            )

    # Toggle visibility of the parameter rows (quick-train section)
    def _toggle_quick(choice):
        """Return visibility updates for (TF-IDF row, BERT row) given the chosen featurizer."""
        show_tfidf = choice == "TF-IDF"
        show_bert = choice in ("ClinicalBERT", "DutchBERT")
        return gr.update(visible=show_tfidf), gr.update(visible=show_bert)

    featur_quick.change(
        _toggle_quick,
        inputs=featur_quick,
        outputs=[tfidf_quick_row, bert_quick_row],
    )

    def _toggle_rows(choice):
        """Return visibility updates for (TF-IDF row, BERT row) of the CSV-retrain section."""
        show_tfidf = choice == "TF-IDF"
        show_bert = choice in ("ClinicalBERT", "DutchBERT")
        return gr.update(visible=show_tfidf), gr.update(visible=show_bert)

    featur.change(
        _toggle_rows,
        inputs=featur,
        outputs=[tfidf_row, bert_row],
    )

    # ===== Interaction functions =====
    def _update_eval(t):
        """
        Recompute the threshold-dependent outputs for threshold `t`.

        Returns (report table, confusion table, confusion heatmap, report
        markdown, threshold-metrics figure); all None when nothing has been
        trained yet.
        """
        eval_pack = GLOBAL["eval"]
        if eval_pack is None:
            return (None,) * 5

        y_true, y_score = eval_pack[1], eval_pack[2]
        rep, cm, rep_md_text = metrics_table(y_true, y_score, t)
        cutoff = float(t)
        thr_fig_new = make_threshold_metrics_fig(y_true, y_score, thr_line=cutoff)
        cm_plot_new = make_confusion_heatmap(y_true, y_score, thr=cutoff)
        return rep, cm, cm_plot_new, rep_md_text, thr_fig_new

    # Fires when the slider is released.
    thr.release(
        _update_eval,
        inputs=thr,
        outputs=[rep_df, cm_df, cm_plot, rep_md, thr_plot],
    )

    # Switch the data preview between the first rows and the full dataset
    def _refresh_preview(mode):
        """Return the first 10 rows or the whole stored dataset; None when no DataFrame is stored."""
        stored = GLOBAL.get("df")
        # isinstance is False for None as well, so this also covers "nothing stored yet".
        if not isinstance(stored, pd.DataFrame):
            return None
        return stored.head(10) if mode.startswith("Eerste") else stored

    data_preview_mode.change(
        _refresh_preview,
        inputs=data_preview_mode,
        outputs=data_preview,
    )

    # Free-text prediction button
    btn.click(predict_one, inputs=txt, outputs=[md_out, proba_out])

    # Manual training (without CSV upload)
    def _train_quick(featur, max_features_q, ngram_max_q, bert_maxlen_q, bert_batch_q):
        """
        Train with the quick-train controls: no uploaded file, test size 0.2, seed 42.

        Returns the 19 outputs of `do_train`. If training raises, the error is
        reported in the status output, in the same format `_auto_train` uses,
        instead of surfacing as a generic Gradio error. The other 18 outputs
        are left unchanged so earlier results stay on screen.
        """
        import logging

        try:
            return do_train(None, 0.2, 42, featur, int(max_features_q), int(ngram_max_q),
                            int(bert_maxlen_q), int(bert_batch_q))
        except Exception as e:
            # UI boundary: keep the traceback in the server log, show the message to the user.
            logging.getLogger(__name__).exception("Quick training failed")
            # An argument-less gr.update() leaves a component as it is; 18 = outputs after `status`.
            return (f"❌ Fout bij laden/trainen: `{e}`", *(gr.update() for _ in range(18)))

    train_quick_btn.click(
        _train_quick,
        inputs=[featur_quick, max_features_q, ngram_max_q, bert_maxlen_q, bert_batch_q],
        outputs=[status, auroc_box, auprc_box, data_preview,
                 fig_out_label, fig_out_prob,
                 rep_df, cm_df, cm_plot, rep_md,
                 roc_plot, pr_plot, hist_plot, thr_plot,
                 cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out]
    )

    # Retrain from an uploaded CSV
    def _retrain(csv_in, test_size, seed, featur, max_features, ngram_max, bert_maxlen, bert_batch):
        """
        Retrain with the uploaded file and the parameters of the retrain section.

        Returns the 19 outputs of `do_train`. The upload is user-supplied, so
        failures are expected here; the error is reported in the status output,
        in the same format `_auto_train` uses, instead of surfacing as a
        generic Gradio error. The other 18 outputs are left unchanged so
        earlier results stay on screen.
        """
        import logging

        try:
            return do_train(csv_in, test_size, int(seed), featur, int(max_features), int(ngram_max),
                            int(bert_maxlen), int(bert_batch))
        except Exception as e:
            # UI boundary: keep the traceback in the server log, show the message to the user.
            logging.getLogger(__name__).exception("Retraining from upload failed")
            # An argument-less gr.update() leaves a component as it is; 18 = outputs after `status`.
            return (f"❌ Fout bij laden/trainen: `{e}`", *(gr.update() for _ in range(18)))

    retrain_btn.click(
        _retrain,
        inputs=[csv_in, test_size, seed, featur, max_features, ngram_max, bert_maxlen, bert_batch],
        outputs=[status, auroc_box, auprc_box, data_preview,
                 fig_out_label, fig_out_prob,
                 rep_df, cm_df, cm_plot, rep_md,
                 roc_plot, pr_plot, hist_plot, thr_plot,
                 cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out]
    )

    # ---- The dimension toggle applies to both projections ----
    def _update_projection(dim):
        """Rebuild the label-colored and probability-colored scatter for `dim`; (None, None) before training."""
        frame = GLOBAL.get("plot_df")
        if frame is None:
            return None, None
        return tuple(
            make_scatter(frame, color_mode=mode, dim=dim) for mode in ("label", "kans")
        )

    proj_dim.change(
        _update_projection,
        inputs=proj_dim,
        outputs=[fig_out_label, fig_out_prob],
    )

    # ---- Auto-train with TF-IDF when the page is opened ----
    def _auto_train():
        """
        Train with the default settings (no file, test size 0.2, seed 42, TF-IDF).

        Returns the 19 outputs of `do_train`; on any failure returns the error
        message for the status output followed by None for the other 18.
        """
        try:
            outputs = do_train(None, 0.2, 42, "TF-IDF", 4000, 2, 128, 16)
        except Exception as e:
            outputs = (f"❌ Fout bij laden/trainen: `{e}`",) + (None,) * 18
        return outputs

    demo.load(
        _auto_train,
        inputs=None,
        outputs=[
            status, auroc_box, auprc_box, data_preview,
            fig_out_label, fig_out_prob,
            rep_df, cm_df, cm_plot, rep_md,
            roc_plot, pr_plot, hist_plot, thr_plot,
            cal_plot, gains_plot, lift_plot, ks_plot, profile_df_out,
        ],
    )

    # --- Explainability accordion ---
    with gr.Accordion("🪄 Uitleg (Explainability)", open=False):
        gr.Markdown("Leg uit waarom het model een voorspelling maakt (LIME).")
        with gr.Row():
            txt_explain = gr.Textbox(lines=4, label="Tekst om uit te leggen",
                                     placeholder="Plak hier een rapportage voor uitleg")
            btn_explain = gr.Button("Genereer uitleg")
        lime_html = gr.HTML(label="LIME uitleg (per voorbeeld)")

        # Optional: global top words (TF-IDF only)
        top_pos_df = gr.Dataframe(headers=["Top pro-agressie woorden"], row_count=5)
        top_neg_df = gr.Dataframe(headers=["Top anti-agressie woorden"], row_count=5)

        def _do_explain(text):
            """
            Explain the current model's prediction for `text`.

            Returns (html, positive words, negative words): the HTML from
            `lime_explain_text` and the global word lists from
            `tfidf_global_top_words`, each word wrapped in its own row for the
            Dataframe outputs (None when a list is empty or unavailable).
            When no model is trained or the text is blank, a message is
            returned in place of the HTML and both word lists are None.
            """
            if GLOBAL["pipe"] is None:
                return "Train eerst een model.", None, None
            # Same blank-input guard as predict_one: there is nothing to explain
            # in empty text, so do not hand it to the explainer.
            if not text or not text.strip():
                return "Voer een tekst in om uit te leggen.", None, None
            html = lime_explain_text(GLOBAL["pipe"], text, num_features=8)
            pos, neg = tfidf_global_top_words(GLOBAL["pipe"], k=15)
            # One word per row for the single-column Dataframes.
            pos = [[w] for w in pos] if pos else None
            neg = [[w] for w in neg] if neg else None
            return html, pos, neg

        btn_explain.click(_do_explain, inputs=txt_explain, outputs=[lime_html, top_pos_df, top_neg_df])

    gr.Markdown(FOOTER)

# Launch the Gradio app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()