Marcel0123 commited on
Commit
6a40b52
·
verified ·
1 Parent(s): 39e478b

Upload 3 files

Browse files
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import roc_auc_score, average_precision_score, classification_report, confusion_matrix
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.linear_model import LogisticRegression
9
+ from sklearn.pipeline import Pipeline
10
+ from sklearn.decomposition import TruncatedSVD
11
+ from sklearn.manifold import TSNE
12
+ import plotly.express as px
13
+
14
+ DEFAULT_CSV = "synthetische_ggz_agressie_dataset_1000.csv"
15
+
16
+ DESCRIPTION = \"\"\"
17
+ # GGZ Agressie (synthetisch) — Auto-train + 2D visualisatie
18
+
19
+ Deze Space **traint automatisch** bij het opstarten op een **synthetische Nederlandstalige GGZ-dataset** en toont
20
+ een **2D "bolletjes" plot** (interactief) waarin iedere punt een patiëntvoorval representeert.
21
+
22
+ - **Kleur**: op basis van het *werkelijke label* (0/1) of de *voorspelde kans*.
23
+ - **Hover**: toont een korte snippet uit de *rapportage* en relevante features.
24
+ - **Model**: TF‑IDF ➜ Logistic Regression (probabilistisch), standaard train/test split (stratified).
25
+
26
+ > ⚠️ **Belangrijk**: dit is **synthetische data** en uitsluitend voor **educatieve doeleinden**.
27
+ > Niet gebruiken voor klinische beslissingen.
28
+ \"\"\"
29
+
30
+ FOOTER = \"\"\"
31
+ **Tips**
32
+ - Upload een eigen CSV met minimaal kolommen `rapportage` en `agressie_volgende30d` om opnieuw te trainen.
33
+ - Pas de *threshold* aan om de confusion matrix en metrics live te zien.
34
+ - De 2D-plot gebruikt **TruncatedSVD (50D)** gevolgd door **t-SNE (2D)** op TF‑IDF features (sneller & expressief).
35
+ \"\"\"
36
+
37
+ def load_dataset(file_obj=None):
38
+ if file_obj is None:
39
+ df = pd.read_csv(DEFAULT_CSV)
40
+ else:
41
+ df = pd.read_csv(file_obj.name if hasattr(file_obj, "name") else file_obj)
42
+ # Basiseisen
43
+ req = {"rapportage", "agressie_volgende30d"}
44
+ missing = req - set(df.columns)
45
+ if missing:
46
+ raise ValueError(f"CSV mist kolommen: {missing}")
47
+ df = df.dropna(subset=["rapportage", "agressie_volgende30d"]).copy()
48
+ df["agressie_volgende30d"] = (df["agressie_volgende30d"].astype(int) > 0).astype(int)
49
+ return df
50
+
51
+ def build_and_train(df, test_size=0.2, random_state=42, max_features=4000, ngram_max=2):
52
+ X = df["rapportage"].astype(str).values
53
+ y = df["agressie_volgende30d"].values
54
+ X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
55
+ X, y, np.arange(len(X)), test_size=test_size, random_state=random_state, stratify=y
56
+ )
57
+ vect = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
58
+ clf = LogisticRegression(max_iter=3000)
59
+ pipe = Pipeline([("tfidf", vect), ("clf", clf)])
60
+ pipe.fit(X_train, y_train)
61
+
62
+ # Probabilities
63
+ y_score = pipe.predict_proba(X_test)[:, 1]
64
+ auroc = float(roc_auc_score(y_test, y_score))
65
+ auprc = float(average_precision_score(y_test, y_score))
66
+
67
+ # for visualization: compute 2D embedding on ALL data to show full cloud
68
+ tfidf_all = pipe.named_steps["tfidf"].fit_transform(X) # fit on all text for viz only
69
+ svd = TruncatedSVD(n_components=50, random_state=random_state)
70
+ X50 = svd.fit_transform(tfidf_all)
71
+ tsne = TSNE(n_components=2, random_state=random_state, perplexity=30, learning_rate="auto", init="pca")
72
+ X2 = tsne.fit_transform(X50)
73
+ # scale to 0-1 for nicer plotting sizes
74
+ x = (X2[:,0] - X2[:,0].min()) / (X2[:,0].ptp() + 1e-9)
75
+ y2 = (X2[:,1] - X2[:,1].min()) / (X2[:,1].ptp() + 1e-9)
76
+
77
+ # Pred proba on all
78
+ proba_all = pipe.predict_proba(X)[:, 1]
79
+
80
+ # Build DataFrame for plotting
81
+ plot_df = pd.DataFrame({
82
+ "x": x, "y": y2,
83
+ "label": df["agressie_volgende30d"].values,
84
+ "kans": proba_all,
85
+ "rapportage": df["rapportage"].str.slice(0, 180) + "..."
86
+ })
87
+ # annotate some optional features if present
88
+ for col in ["PHQ9_baseline","GAD7_baseline","stress_niveau_1_5","slaap_uren","sociale_steun_0_10","zorgsetting"]:
89
+ if col in df.columns:
90
+ plot_df[col] = df[col]
91
+
92
+ # Test set indices mask for highlighting (optional)
93
+ test_mask = np.zeros(len(plot_df), dtype=bool)
94
+ test_mask[idx_test] = True
95
+ plot_df["split"] = np.where(test_mask, "test", "train")
96
+
97
+ return pipe, (X_test, y_test, y_score), plot_df, auroc, auprc
98
+
99
+ def make_scatter(plot_df, color_mode="label"):
100
+ if color_mode == "label":
101
+ color = plot_df["label"].map({0:"geen agressie", 1:"agressie"})
102
+ fig = px.scatter(
103
+ plot_df, x="x", y="y", color=color, hover_data=["rapportage","kans","split"],
104
+ title="2D projectie van teksten (t‑SNE) — kleur = werkelijk label",
105
+ opacity=0.8
106
+ )
107
+ else:
108
+ fig = px.scatter(
109
+ plot_df, x="x", y="y", color="kans", hover_data=["rapportage","kans","split"],
110
+ color_continuous_scale="Turbo",
111
+ title="2D projectie van teksten (t‑SNE) — kleur = voorspelde kans",
112
+ opacity=0.85
113
+ )
114
+ fig.update_traces(marker=dict(size=8, line=dict(width=0)))
115
+ fig.update_layout(margin=dict(l=10,r=10,t=40,b=10), template="simple_white")
116
+ return fig
117
+
118
+ def metrics_table(y_true, y_score, thr):
119
+ y_pred = (y_score >= thr).astype(int)
120
+ rep = classification_report(y_true, y_pred, output_dict=True)
121
+ rep_df = pd.DataFrame(rep).T.round(3)
122
+ cm = confusion_matrix(y_true, y_pred)
123
+ cm_df = pd.DataFrame(cm, index=["True 0","True 1"], columns=["Pred 0","Pred 1"])
124
+ return rep_df, cm_df
125
+
126
+ # Global state for auto-training on load
127
+ GLOBAL = {"pipe": None, "plot_df": None, "eval": None, "auroc": None, "auprc": None}
128
+
129
+ def do_train(file_obj=None, test_size=0.2, seed=42, max_features=4000, ngram_max=2):
130
+ df = load_dataset(file_obj)
131
+ pipe, eval_pack, plot_df, auroc, auprc = build_and_train(df, test_size, seed, max_features, ngram_max)
132
+ GLOBAL["pipe"] = pipe
133
+ GLOBAL["plot_df"] = plot_df
134
+ GLOBAL["eval"] = eval_pack
135
+ GLOBAL["auroc"] = auroc
136
+ GLOBAL["auprc"] = auprc
137
+ fig_label = make_scatter(plot_df, color_mode="label")
138
+ fig_prob = make_scatter(plot_df, color_mode="prob")
139
+ rep_df, cm_df = metrics_table(eval_pack[1], eval_pack[2], thr=0.5)
140
+ return (
141
+ float(auroc), float(auprc),
142
+ fig_label, fig_prob,
143
+ rep_df, cm_df
144
+ )
145
+
146
+ def predict_one(text):
147
+ if GLOBAL["pipe"] is None:
148
+ return "Nog geen model getraind.", None
149
+ if not text or text.strip() == "":
150
+ return "Voer een rapportage in.", None
151
+ proba = float(GLOBAL["pipe"].predict_proba([text])[:,1][0])
152
+ label = int(proba >= 0.5)
153
+ md = f"**Kans op agressie (30d)**: **{proba:.3f}** — voorspelde klasse: **{label}** (drempel 0.50)"
154
+ return md, proba
155
+
156
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as demo:
157
+ gr.Markdown(DESCRIPTION)
158
+
159
+ with gr.Row():
160
+ with gr.Column(scale=2):
161
+ auroc_box = gr.Number(label="AUROC", precision=3)
162
+ with gr.Column(scale=2):
163
+ auprc_box = gr.Number(label="AUPRC", precision=3)
164
+
165
+ with gr.Tabs():
166
+ with gr.Tab("Visualisatie"):
167
+ color_mode = gr.Radio(choices=["label","prob"], value="label", label="Kleurmodus (label of kans)")
168
+ fig_out = gr.Plot(label="2D bolletjes-plot")
169
+ def _switch_color(mode):
170
+ if GLOBAL["plot_df"] is None:
171
+ return None
172
+ return make_scatter(GLOBAL["plot_df"], color_mode=mode)
173
+ color_mode.change(_switch_color, inputs=color_mode, outputs=fig_out)
174
+
175
+ # Also show both plots on load
176
+ fig_label_out = gr.Plot(visible=False)
177
+ fig_prob_out = gr.Plot(visible=False)
178
+
179
+ with gr.Tab("Evaluatie"):
180
+ thr = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Drempel voor classificatie")
181
+ rep_df = gr.Dataframe(label="Classification report")
182
+ cm_df = gr.Dataframe(label="Confusion matrix")
183
+
184
+ def _update_eval(t):
185
+ if GLOBAL["eval"] is None:
186
+ return None, None
187
+ y_true, y_score = GLOBAL["eval"][1], GLOBAL["eval"][2]
188
+ rep, cm = metrics_table(y_true, y_score, t)
189
+ return rep, cm
190
+ thr.release(_update_eval, inputs=thr, outputs=[rep_df, cm_df])
191
+
192
+ with gr.Tab("Predict (vrije tekst)"):
193
+ txt = gr.Textbox(lines=6, label="Rapportage (NL)")
194
+ btn = gr.Button("Voorspel")
195
+ md_out = gr.Markdown()
196
+ proba_out = gr.Number(label="Kans", precision=3)
197
+ btn.click(predict_one, inputs=txt, outputs=[md_out, proba_out])
198
+
199
+ with gr.Tab("(Optioneel) Hertrain"):
200
+ csv_in = gr.File(label="Upload eigen CSV")
201
+ test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test set grootte")
202
+ seed = gr.Slider(1, 999, value=42, step=1, label="Random seed")
203
+ max_features = gr.Slider(1000, 12000, value=4000, step=1000, label="TF‑IDF max_features")
204
+ ngram_max = gr.Radio(choices=[1,2], value=2, label="n‑gram max")
205
+
206
+ train_btn = gr.Button("Train opnieuw")
207
+ def _train(csv_in, test_size, seed, max_features, ngram_max):
208
+ return do_train(csv_in, test_size, int(seed), int(max_features), int(ngram_max))
209
+ train_btn.click(_train, inputs=[csv_in, test_size, seed, max_features, ngram_max],
210
+ outputs=[auroc_box, auprc_box, fig_label_out, fig_prob_out, rep_df, cm_df]).then(
211
+ lambda: _switch_color(color_mode.value), None, fig_out
212
+ )
213
+
214
+ # Auto-train on load using default CSV
215
+ def _auto_train():
216
+ return do_train(None, 0.2, 42, 4000, 2)
217
+
218
+ demo.load(_auto_train, inputs=None, outputs=[auroc_box, auprc_box, fig_label_out, fig_prob_out, rep_df, cm_df]).then(
219
+ lambda: _switch_color("label"), None, fig_out
220
+ )
221
+
222
+ gr.Markdown(FOOTER)
223
+
224
+ if __name__ == "__main__":
225
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ gradio>=4.16.0
3
+ pandas>=2.0.0
4
+ numpy>=1.24.0
5
+ scikit-learn>=1.3.0
6
+ plotly>=5.20.0
synthetische_ggz_agressie_dataset_1000.csv ADDED
The diff for this file is too large to render. See raw diff