Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app.py +225 -0
- requirements.txt +6 -0
- synthetische_ggz_agressie_dataset_1000.csv +0 -0
app.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.model_selection import train_test_split
|
| 6 |
+
from sklearn.metrics import roc_auc_score, average_precision_score, classification_report, confusion_matrix
|
| 7 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 8 |
+
from sklearn.linear_model import LogisticRegression
|
| 9 |
+
from sklearn.pipeline import Pipeline
|
| 10 |
+
from sklearn.decomposition import TruncatedSVD
|
| 11 |
+
from sklearn.manifold import TSNE
|
| 12 |
+
import plotly.express as px
|
| 13 |
+
|
| 14 |
+
# CSV read by load_dataset() when no file is supplied (the automatic training run).
DEFAULT_CSV = "synthetische_ggz_agressie_dataset_1000.csv"

# Markdown rendered at the top of the app (user-facing, Dutch).
# NOTE: the delimiters must be plain triple quotes; backslash-escaped quotes
# outside a string literal are a syntax error.
DESCRIPTION = """
# GGZ Agressie (synthetisch) — Auto-train + 2D visualisatie

Deze Space **traint automatisch** bij het opstarten op een **synthetische Nederlandstalige GGZ-dataset** en toont
een **2D "bolletjes" plot** (interactief) waarin iedere punt een patiëntvoorval representeert.

- **Kleur**: op basis van het *werkelijke label* (0/1) of de *voorspelde kans*.
- **Hover**: toont een korte snippet uit de *rapportage* en relevante features.
- **Model**: TF‑IDF ➜ Logistic Regression (probabilistisch), standaard train/test split (stratified).

> ⚠️ **Belangrijk**: dit is **synthetische data** en uitsluitend voor **educatieve doeleinden**.
> Niet gebruiken voor klinische beslissingen.
"""

# Markdown rendered at the bottom of the app (user-facing, Dutch).
FOOTER = """
**Tips**
- Upload een eigen CSV met minimaal kolommen `rapportage` en `agressie_volgende30d` om opnieuw te trainen.
- Pas de *threshold* aan om de confusion matrix en metrics live te zien.
- De 2D-plot gebruikt **TruncatedSVD (50D)** gevolgd door **t-SNE (2D)** op TF‑IDF features (sneller & expressief).
"""
|
| 36 |
+
|
| 37 |
+
def load_dataset(file_obj=None):
    """Load the incident CSV and normalise it for training.

    Args:
        file_obj: ``None`` to read ``DEFAULT_CSV``. Otherwise anything
            ``pandas.read_csv`` accepts (path or file-like object), or an
            object exposing the path in a ``name`` attribute, in which case
            that attribute is what gets read.

    Returns:
        A copy of the data without rows that lack ``rapportage`` or
        ``agressie_volgende30d``. The label column is binarised to 0/1:
        any value strictly greater than zero becomes 1.

    Raises:
        ValueError: if a required column is missing, or if the label column
            contains values that cannot be interpreted as numbers.
    """
    if file_obj is None:
        source = DEFAULT_CSV
    else:
        # Prefer the ``name`` attribute (a path) when the object has one.
        source = getattr(file_obj, "name", file_obj)
    df = pd.read_csv(source)

    # Minimum schema required by the training code.
    required = {"rapportage", "agressie_volgende30d"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV mist kolommen: {missing}")

    df = df.dropna(subset=["rapportage", "agressie_volgende30d"]).copy()

    # Compare the numeric value itself against zero. Casting to int first
    # would truncate fractional positives (e.g. 0.5 -> 0) and mislabel them.
    labels = pd.to_numeric(df["agressie_volgende30d"], errors="coerce")
    if labels.isna().any():
        raise ValueError("Kolom 'agressie_volgende30d' bevat niet-numerieke waarden.")
    df["agressie_volgende30d"] = (labels > 0).astype(int)
    return df
|
| 50 |
+
|
| 51 |
+
def build_and_train(df, test_size=0.2, random_state=42, max_features=4000, ngram_max=2):
    """Train the text classifier and compute the 2D projection used for plotting.

    Args:
        df: DataFrame with a ``rapportage`` text column and a binary
            ``agressie_volgende30d`` label column.
        test_size: Fraction of rows held out for evaluation.
        random_state: Seed for the split, the SVD and t-SNE.
        max_features: Vocabulary cap for the TF-IDF vectorizers.
        ngram_max: Upper bound of the n-gram range (lower bound is 1).

    Returns:
        A tuple ``(pipe, (X_test, y_test, y_score), plot_df, auroc, auprc)``:
        the fitted pipeline, the held-out texts/labels/scores, a DataFrame
        with one row per input row (``x``, ``y``, ``label``, ``kans``,
        ``rapportage``, ``split`` plus any optional feature columns present
        in ``df``), and the held-out AUROC and average precision as floats.
    """
    X = df["rapportage"].astype(str).values
    y = df["agressie_volgende30d"].values
    X_train, X_test, y_train, y_test, _idx_train, idx_test = train_test_split(
        X, y, np.arange(len(X)), test_size=test_size, random_state=random_state, stratify=y
    )
    vect = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
    clf = LogisticRegression(max_iter=3000)
    pipe = Pipeline([("tfidf", vect), ("clf", clf)])
    pipe.fit(X_train, y_train)

    # Held-out probabilities and ranking metrics.
    y_score = pipe.predict_proba(X_test)[:, 1]
    auroc = float(roc_auc_score(y_test, y_score))
    auprc = float(average_precision_score(y_test, y_score))

    # The plot shows the full point cloud, so it needs TF-IDF features fitted
    # on ALL texts. Use a separate vectorizer for that: refitting the one
    # inside ``pipe`` would change its vocabulary/IDF weights under the
    # already-trained classifier and corrupt every later prediction.
    viz_vect = TfidfVectorizer(max_features=max_features, ngram_range=(1, ngram_max))
    tfidf_all = viz_vect.fit_transform(X)
    n_samples, n_features = tfidf_all.shape

    # Clamp so tiny datasets do not exceed what SVD / t-SNE accept.
    n_svd = max(2, min(50, n_features - 1))
    svd = TruncatedSVD(n_components=n_svd, random_state=random_state)
    X50 = svd.fit_transform(tfidf_all)
    perplexity = min(30, max(1, n_samples - 1))  # t-SNE needs perplexity < n_samples
    tsne = TSNE(
        n_components=2,
        random_state=random_state,
        perplexity=perplexity,
        learning_rate="auto",
        init="pca",
    )
    X2 = tsne.fit_transform(X50)

    # Rescale both axes to [0, 1]. np.ptp() is used because the ndarray.ptp
    # method no longer exists in NumPy 2.x; the epsilon avoids division by zero.
    x = (X2[:, 0] - X2[:, 0].min()) / (np.ptp(X2[:, 0]) + 1e-9)
    y2 = (X2[:, 1] - X2[:, 1].min()) / (np.ptp(X2[:, 1]) + 1e-9)

    # Predicted probability for every row (in-sample for the training rows).
    proba_all = pipe.predict_proba(X)[:, 1]

    # One row per report for plotting; the hover text is a 180-char snippet.
    plot_df = pd.DataFrame({
        "x": x, "y": y2,
        "label": df["agressie_volgende30d"].values,
        "kans": proba_all,
        "rapportage": df["rapportage"].str.slice(0, 180) + "..."
    })
    # Carry over optional feature columns when the CSV provides them.
    for col in ["PHQ9_baseline", "GAD7_baseline", "stress_niveau_1_5", "slaap_uren", "sociale_steun_0_10", "zorgsetting"]:
        if col in df.columns:
            plot_df[col] = df[col]

    # Mark which rows belonged to the held-out split (positional indices).
    test_mask = np.zeros(len(plot_df), dtype=bool)
    test_mask[idx_test] = True
    plot_df["split"] = np.where(test_mask, "test", "train")

    return pipe, (X_test, y_test, y_score), plot_df, auroc, auprc
|
| 98 |
+
|
| 99 |
+
def make_scatter(plot_df, color_mode="label"):
    """Build the interactive 2D scatter figure.

    Args:
        plot_df: DataFrame with at least the columns ``x``, ``y``, ``label``,
            ``kans``, ``rapportage`` and ``split``.
        color_mode: ``"label"`` colours points by the true class; any other
            value colours them by the predicted probability ``kans``.

    Returns:
        A Plotly figure.
    """
    # Always show the snippet, probability and split; add the optional
    # feature columns to the hover box when they are present in plot_df.
    optional_cols = (
        "PHQ9_baseline", "GAD7_baseline", "stress_niveau_1_5",
        "slaap_uren", "sociale_steun_0_10", "zorgsetting",
    )
    hover_cols = ["rapportage", "kans", "split"]
    hover_cols += [col for col in optional_cols if col in plot_df.columns]
    shared = dict(x="x", y="y", hover_data=hover_cols)

    if color_mode == "label":
        color = plot_df["label"].map({0: "geen agressie", 1: "agressie"})
        fig = px.scatter(
            plot_df,
            color=color,
            title="2D projectie van teksten (t‑SNE) — kleur = werkelijk label",
            opacity=0.8,
            **shared,
        )
    else:
        fig = px.scatter(
            plot_df,
            color="kans",
            color_continuous_scale="Turbo",
            title="2D projectie van teksten (t‑SNE) — kleur = voorspelde kans",
            opacity=0.85,
            **shared,
        )
    fig.update_traces(marker=dict(size=8, line=dict(width=0)))
    fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), template="simple_white")
    return fig
|
| 117 |
+
|
| 118 |
+
def metrics_table(y_true, y_score, thr):
    """Compute the classification report and confusion matrix at a threshold.

    Args:
        y_true: True binary labels (0/1).
        y_score: Predicted probabilities for class 1 (array-like).
        thr: Scores greater than or equal to this value are predicted as 1.

    Returns:
        ``(rep_df, cm_df)``: the classification report as a DataFrame rounded
        to three decimals, and a 2x2 confusion-matrix DataFrame with rows
        ``True 0``/``True 1`` and columns ``Pred 0``/``Pred 1``.
    """
    # asarray lets plain lists be thresholded as well.
    y_pred = (np.asarray(y_score) >= thr).astype(int)
    # zero_division=0 reports 0 (without warnings) when a class gets no
    # predictions, which happens routinely at extreme thresholds.
    rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    rep_df = pd.DataFrame(rep).T.round(3)
    # Fix the label order so the matrix is always 2x2 and matches the
    # hard-coded row/column names below.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    cm_df = pd.DataFrame(cm, index=["True 0", "True 1"], columns=["Pred 0", "Pred 1"])
    return rep_df, cm_df
|
| 125 |
+
|
| 126 |
+
# Module-level cache of the most recent training run. do_train() writes every
# key; the UI callbacks read them. All entries are None until the first run.
GLOBAL = dict.fromkeys(("pipe", "plot_df", "eval", "auroc", "auprc"))
|
| 128 |
+
|
| 129 |
+
def do_train(file_obj=None, test_size=0.2, seed=42, max_features=4000, ngram_max=2):
    """Load data, train the model, cache the results and return the UI outputs.

    Args:
        file_obj: Optional CSV source forwarded to ``load_dataset``.
        test_size: Held-out fraction forwarded to ``build_and_train``.
        seed: Random seed forwarded to ``build_and_train``.
        max_features: TF-IDF vocabulary cap.
        ngram_max: Upper bound of the n-gram range.

    Returns:
        ``(auroc, auprc, fig_label, fig_prob, rep_df, cm_df)`` where the two
        figures are coloured by true label and by predicted probability, and
        the report / confusion matrix are computed at a threshold of 0.5.
    """
    dataset = load_dataset(file_obj)
    model, holdout, points, roc, pr = build_and_train(
        dataset, test_size, seed, max_features, ngram_max
    )

    # Publish the run so the other callbacks can reuse it.
    GLOBAL.update(pipe=model, plot_df=points, eval=holdout, auroc=roc, auprc=pr)

    by_label = make_scatter(points, color_mode="label")
    by_prob = make_scatter(points, color_mode="prob")

    # holdout is (X_test, y_test, y_score); only labels and scores are needed.
    _, y_true, y_score = holdout
    report, confusion = metrics_table(y_true, y_score, thr=0.5)
    return float(roc), float(pr), by_label, by_prob, report, confusion
|
| 145 |
+
|
| 146 |
+
def predict_one(text):
    """Score a single free-text report with the cached model.

    Args:
        text: The report text entered by the user.

    Returns:
        ``(markdown, probability)``. ``probability`` is ``None`` when no
        model has been trained yet or when ``text`` is empty or whitespace.
    """
    model = GLOBAL["pipe"]
    if model is None:
        return "Nog geen model getraind.", None
    if not text or not text.strip():
        return "Voer een rapportage in.", None

    # Probability of the positive class for the single input row.
    proba = float(model.predict_proba([text])[0, 1])
    label = 1 if proba >= 0.5 else 0
    md = f"**Kans op agressie (30d)**: **{proba:.3f}** — voorspelde klasse: **{label}** (drempel 0.50)"
    return md, proba
|
| 155 |
+
|
| 156 |
+
# Gradio UI. Layout: metric boxes on top, then tabs for the plot, evaluation,
# free-text prediction and optional retraining, followed by the footer.
# NOTE: all callbacks share the module-level GLOBAL cache, so every visitor
# sees (and a retrain replaces) the same model.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="red", neutral_hue="slate")) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            auroc_box = gr.Number(label="AUROC", precision=3)
        with gr.Column(scale=2):
            auprc_box = gr.Number(label="AUPRC", precision=3)

    with gr.Tabs():
        with gr.Tab("Visualisatie"):
            color_mode = gr.Radio(choices=["label", "prob"], value="label", label="Kleurmodus (label of kans)")
            fig_out = gr.Plot(label="2D bolletjes-plot")

            def _switch_color(mode):
                """Redraw the scatter for the selected colour mode (None before training)."""
                if GLOBAL["plot_df"] is None:
                    return None
                return make_scatter(GLOBAL["plot_df"], color_mode=mode)

            color_mode.change(_switch_color, inputs=color_mode, outputs=fig_out)

            # Hidden sinks for the two figures do_train() returns; the visible
            # plot is redrawn separately via _switch_color.
            fig_label_out = gr.Plot(visible=False)
            fig_prob_out = gr.Plot(visible=False)

        with gr.Tab("Evaluatie"):
            thr = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Drempel voor classificatie")
            rep_df = gr.Dataframe(label="Classification report")
            cm_df = gr.Dataframe(label="Confusion matrix")

            def _update_eval(t):
                """Recompute report and confusion matrix at threshold ``t``."""
                if GLOBAL["eval"] is None:
                    return None, None
                y_true, y_score = GLOBAL["eval"][1], GLOBAL["eval"][2]
                rep, cm = metrics_table(y_true, y_score, t)
                return rep, cm

            thr.release(_update_eval, inputs=thr, outputs=[rep_df, cm_df])

        with gr.Tab("Predict (vrije tekst)"):
            txt = gr.Textbox(lines=6, label="Rapportage (NL)")
            btn = gr.Button("Voorspel")
            md_out = gr.Markdown()
            proba_out = gr.Number(label="Kans", precision=3)
            btn.click(predict_one, inputs=txt, outputs=[md_out, proba_out])

        with gr.Tab("(Optioneel) Hertrain"):
            csv_in = gr.File(label="Upload eigen CSV")
            test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test set grootte")
            seed = gr.Slider(1, 999, value=42, step=1, label="Random seed")
            max_features = gr.Slider(1000, 12000, value=4000, step=1000, label="TF‑IDF max_features")
            ngram_max = gr.Radio(choices=[1, 2], value=2, label="n‑gram max")

            train_btn = gr.Button("Train opnieuw")

            def _train(csv_in, test_size, seed, max_features, ngram_max):
                """Retrain with the chosen settings; sliders deliver floats, so cast."""
                return do_train(csv_in, test_size, int(seed), int(max_features), int(ngram_max))

            # After retraining, redraw the visible plot with the user's CURRENT
            # colour mode and refresh the tables at the CURRENT threshold.
            # The components are passed as inputs because ``component.value``
            # only holds the value the component was constructed with.
            train_btn.click(
                _train,
                inputs=[csv_in, test_size, seed, max_features, ngram_max],
                outputs=[auroc_box, auprc_box, fig_label_out, fig_prob_out, rep_df, cm_df],
            ).then(
                _switch_color, inputs=color_mode, outputs=fig_out
            ).then(
                _update_eval, inputs=thr, outputs=[rep_df, cm_df]
            )

    def _auto_train():
        """Train on the bundled default CSV with the default settings."""
        return do_train(None, 0.2, 42, 4000, 2)

    # Train on page load, then draw the visible plot for the selected mode.
    demo.load(
        _auto_train,
        inputs=None,
        outputs=[auroc_box, auprc_box, fig_label_out, fig_prob_out, rep_df, cm_df],
    ).then(
        _switch_color, inputs=color_mode, outputs=fig_out
    )

    gr.Markdown(FOOTER)

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
gradio>=4.16.0
|
| 3 |
+
pandas>=2.0.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
scikit-learn>=1.3.0
|
| 6 |
+
plotly>=5.20.0
|
synthetische_ggz_agressie_dataset_1000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|