Upload 3 files
Browse files- app.py +172 -226
- ggz_depressie_synth_1000_modeling.csv +0 -0
app.py
CHANGED
|
@@ -16,80 +16,67 @@ from sklearn.model_selection import train_test_split
|
|
| 16 |
DESCRIPTION = """
|
| 17 |
# Interactieve Scatter (2D/3D) — Gradio + Plotly + Live Train
|
| 18 |
|
| 19 |
-
- **
|
|
|
|
|
|
|
| 20 |
- **Kies** x, y (en z) kolommen voor de scatter
|
| 21 |
- **Kleur** op cluster/categorie of continue variabele
|
| 22 |
- **Hover** toont gekozen kenmerken
|
| 23 |
-
- **Train (live)**: KMeans
|
| 24 |
-
- **Auto-train bij start**: model traint automatisch wanneer de Space start
|
| 25 |
"""
|
| 26 |
|
| 27 |
MODEL_PATH = Path("model.joblib")
|
|
|
|
| 28 |
|
| 29 |
# -----------------------------
|
| 30 |
-
#
|
| 31 |
# -----------------------------
|
| 32 |
-
def
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
| 43 |
df["cluster"] = pd.Categorical(["A" if l == 0 else ("B" if l == 1 else "C") for l in labels])
|
| 44 |
-
df["age"] = rng.integers(20, 90, size=
|
| 45 |
-
df["sex"] = pd.Categorical(rng.choice(["F",
|
| 46 |
-
df["diagnosis"] = pd.Categorical(rng.choice(["Type I",
|
| 47 |
-
df["patient_id"] = [f"P{1000+i}" for i in range(
|
| 48 |
return df
|
| 49 |
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# -----------------------------
|
| 53 |
-
#
|
| 54 |
# -----------------------------
|
| 55 |
-
def parse_file(file_obj):
    """Load an uploaded table, or the demo dataset when nothing was uploaded.

    Args:
        file_obj: ``None``, a file path string, or an object exposing the
            path through a ``name`` attribute.

    Returns:
        tuple: ``(dataframe, status_message)``.
    """
    if file_obj is None:
        return DEMO_DF.copy(), "Demo dataset geladen (geen upload)."

    name = getattr(file_obj, "name", str(file_obj))
    lowered = name.lower()
    if lowered.endswith(".parquet"):
        df = pd.read_parquet(name)
    elif lowered.endswith(".tsv"):
        df = pd.read_csv(name, sep="\t")
    else:
        # ".csv" and every unrecognised extension are read as comma-separated text.
        df = pd.read_csv(name)

    return df, f"Bestand geladen: {Path(name).name} — {df.shape[0]} rijen, {df.shape[1]} kolommen."
|
| 71 |
-
|
| 72 |
-
def detect_columns(df):
    """Suggest default axis and colour columns for the scatter plot.

    Each axis prefers a conventional embedding column name and otherwise falls
    back to the first numeric column that is not already assigned to another
    axis, so two axes never default to the same column.

    Args:
        df: DataFrame to inspect.

    Returns:
        tuple: ``(numeric_cols, all_cols, x_default, y_default, z_default,
        color_default)``; any default may be ``None`` when no column fits.
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    all_cols = df.columns.tolist()

    def _pick_axis(preferred, taken):
        """Return the first preferred name present, else the first free numeric column."""
        for name in preferred:
            if name in df.columns and name not in taken:
                return name
        for name in numeric_cols:
            if name not in taken:
                return name
        return None

    x_default = _pick_axis(["x", "X", "dim1", "pc1", "tsne1", "umap1"], set())
    y_default = _pick_axis(["y", "Y", "dim2", "pc2", "tsne2", "umap2"], {x_default})
    z_default = _pick_axis(["z", "Z", "dim3", "pc3", "tsne3", "umap3"], {x_default, y_default})

    # Colour: a well-known grouping column, else the first object/category
    # column that is not an axis, else the first numeric column.
    axes = {x_default, y_default, z_default}
    cat_candidates = [
        c for c in df.columns
        if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category')) and c not in axes
    ]
    color_default = next(
        (c for c in ["cluster", "label", "group", "diagnosis", "category"] if c in df.columns),
        (cat_candidates[0] if cat_candidates else (numeric_cols[0] if numeric_cols else None)),
    )
    return numeric_cols, all_cols, x_default, y_default, z_default, color_default
|
| 88 |
-
|
| 89 |
def build_hovertemplate(hover_cols):
|
| 90 |
if not hover_cols:
|
| 91 |
return "%{x}, %{y}<extra></extra>"
|
| 92 |
-
|
| 93 |
lines = []
|
| 94 |
for i, col in enumerate(hover_cols):
|
| 95 |
lines.append(f"<b>{col}</b>: %{{customdata[{i}]}}")
|
|
@@ -108,30 +95,22 @@ def make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, point_s
|
|
| 108 |
fig = go.Figure()
|
| 109 |
if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
|
| 110 |
for cat_val, dsub in df.groupby(color_col):
|
| 111 |
-
fig.add_trace(
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
)
|
| 118 |
-
)
|
| 119 |
else:
|
| 120 |
-
fig.add_trace(
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
)
|
| 126 |
-
)
|
| 127 |
fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
|
| 128 |
-
fig.update_layout(
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
legend=dict(itemsizing='trace', title=color_col if color_col else None),
|
| 132 |
-
xaxis_title=x_col, yaxis_title=y_col,
|
| 133 |
-
dragmode='pan',
|
| 134 |
-
)
|
| 135 |
return fig
|
| 136 |
|
| 137 |
# 3D
|
|
@@ -140,66 +119,63 @@ def make_figure(df, x_col, y_col, z_col, color_col, hover_cols, mode_3d, point_s
|
|
| 140 |
if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
|
| 141 |
fig = go.Figure()
|
| 142 |
for cat_val, dsub in df.groupby(color_col):
|
| 143 |
-
fig.add_trace(
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
)
|
| 150 |
-
)
|
| 151 |
else:
|
| 152 |
-
fig = go.Figure(
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
|
| 158 |
-
)
|
| 159 |
-
]
|
| 160 |
-
)
|
| 161 |
fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
|
| 162 |
-
fig.update_layout(
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
legend=dict(itemsizing='trace', title=color_col if color_col else None),
|
| 166 |
-
scene=dict(xaxis_title=x_col, yaxis_title=y_col, zaxis_title=z_col),
|
| 167 |
-
)
|
| 168 |
return fig
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def _plot_confusion_matrix(y_true, y_pred, title="Confusion matrix"):
|
| 171 |
labels = sorted(pd.Series(y_true).unique())
|
| 172 |
cm = confusion_matrix(y_true, y_pred, labels=labels)
|
| 173 |
-
fig = go.Figure(
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
hovertemplate="Pred=%{x}<br>True=%{y}<br>Count=%{z}<extra></extra>"
|
| 177 |
-
)
|
| 178 |
-
)
|
| 179 |
-
fig.update_layout(title=title, xaxis_title="Predicted", yaxis_title="True",
|
| 180 |
-
template="plotly_white", margin=dict(l=10,r=10,t=40,b=10))
|
| 181 |
return fig
|
| 182 |
|
| 183 |
def _fmt_metrics(metrics: dict) -> str:
|
| 184 |
lines = []
|
| 185 |
for k,v in metrics.items():
|
| 186 |
-
if isinstance(v,
|
| 187 |
-
lines.append(f"{k}: {v:.4f}")
|
| 188 |
-
else:
|
| 189 |
-
lines.append(f"{k}: {v}")
|
| 190 |
return "\n".join(lines)
|
| 191 |
|
| 192 |
-
# -----------------------------
|
| 193 |
-
# Training
|
| 194 |
-
# -----------------------------
|
| 195 |
def train_live(df, task, feature_cols, label_col, k_clusters, seed):
|
| 196 |
-
if df is None or len(df)
|
| 197 |
raise gr.Error("Geen data om te trainen.")
|
| 198 |
if not feature_cols:
|
| 199 |
raise gr.Error("Selecteer minimaal één featurekolom.")
|
| 200 |
-
|
| 201 |
X = df[feature_cols].select_dtypes(include=[np.number])
|
| 202 |
-
if X.shape[1]
|
| 203 |
raise gr.Error("De gekozen features bevatten geen numerieke kolommen.")
|
| 204 |
|
| 205 |
log_lines = []
|
|
@@ -207,68 +183,41 @@ def train_live(df, task, feature_cols, label_col, k_clusters, seed):
|
|
| 207 |
eval_fig = None
|
| 208 |
|
| 209 |
if task == "Clustering (KMeans)":
|
| 210 |
-
pipe = Pipeline([
|
| 211 |
-
("scaler", StandardScaler()),
|
| 212 |
-
("kmeans", KMeans(n_clusters=k_clusters, random_state=seed, n_init="auto")),
|
| 213 |
-
])
|
| 214 |
pipe.fit(X)
|
| 215 |
labels = pipe["kmeans"].labels_
|
| 216 |
-
|
| 217 |
-
df["cluster_model"] = pd.Categorical(alpha)
|
| 218 |
color_col_suggestion = "cluster_model"
|
| 219 |
dump(pipe, MODEL_PATH)
|
| 220 |
-
log_lines
|
| 221 |
-
log_lines.append(f"Model opgeslagen: {MODEL_PATH.resolve()}")
|
| 222 |
|
| 223 |
elif task == "Classificatie (Logistic Regression)":
|
| 224 |
if not label_col:
|
| 225 |
raise gr.Error("Kies een labelkolom voor classificatie.")
|
| 226 |
y = df[label_col].astype(str)
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
Xtr, Xva, ytr, yva = train_test_split(
|
| 230 |
-
X, y, test_size=0.2, random_state=seed, stratify=y if y.nunique()>1 else None
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
pipe = Pipeline([
|
| 234 |
-
("scaler", StandardScaler()),
|
| 235 |
-
("logreg", LogisticRegression(max_iter=1000, random_state=seed, class_weight="balanced"))
|
| 236 |
-
])
|
| 237 |
pipe.fit(Xtr, ytr)
|
| 238 |
-
|
| 239 |
-
# Validatie
|
| 240 |
yhat = pipe.predict(Xva)
|
| 241 |
-
metrics = {
|
| 242 |
-
|
| 243 |
-
"f1_weighted": f1_score(yva, yhat, average="weighted"),
|
| 244 |
-
}
|
| 245 |
-
if y.nunique() == 2:
|
| 246 |
try:
|
| 247 |
-
proba = pipe.predict_proba(Xva)[:,
|
| 248 |
uniq = list(pd.Series(y).unique())
|
| 249 |
mapping = {uniq[0]:0, uniq[1]:1}
|
| 250 |
metrics["roc_auc"] = roc_auc_score(yva.map(mapping), proba)
|
| 251 |
except Exception:
|
| 252 |
pass
|
| 253 |
-
|
| 254 |
eval_fig = _plot_confusion_matrix(yva, yhat, title="Confusion matrix (validatie)")
|
| 255 |
-
|
| 256 |
-
# Train op alle data en voorspel voor visualisatie
|
| 257 |
pipe.fit(X, y)
|
| 258 |
preds = pipe.predict(X)
|
| 259 |
df["pred_model"] = pd.Categorical(preds)
|
| 260 |
-
if y.nunique()
|
| 261 |
-
try:
|
| 262 |
-
|
| 263 |
-
except Exception:
|
| 264 |
-
pass
|
| 265 |
-
|
| 266 |
color_col_suggestion = "pred_model"
|
| 267 |
dump(pipe, MODEL_PATH)
|
| 268 |
-
log_lines
|
| 269 |
-
log_lines.append(f"Model opgeslagen: {MODEL_PATH.resolve()}")
|
| 270 |
-
log_lines.append("Metrics (validatie):\n" + _fmt_metrics(metrics))
|
| 271 |
-
|
| 272 |
else:
|
| 273 |
raise gr.Error("Onbekende taak.")
|
| 274 |
|
|
@@ -286,29 +235,49 @@ def try_load_model():
|
|
| 286 |
# -----------------------------
|
| 287 |
# Gradio callbacks
|
| 288 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
def init_from_file(file_obj):
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
numeric_cols, all_cols, x_d, y_d, z_d, color_d = detect_columns(df)
|
| 292 |
-
hover_default = [c for c in ["patient_id",
|
| 293 |
-
feat_default =
|
|
|
|
| 294 |
return (
|
| 295 |
gr.update(choices=all_cols, value=x_d),
|
| 296 |
gr.update(choices=all_cols, value=y_d),
|
| 297 |
gr.update(choices=all_cols, value=z_d),
|
| 298 |
-
gr.update(choices=all_cols, value=color_d),
|
| 299 |
gr.update(choices=all_cols, value=hover_default),
|
| 300 |
status,
|
| 301 |
df,
|
| 302 |
-
gr.update(choices=
|
| 303 |
-
gr.update(choices=all_cols, value=
|
| 304 |
)
|
| 305 |
|
| 306 |
def update_plot(df, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
|
| 307 |
-
if df is None or (isinstance(df,
|
| 308 |
-
df =
|
| 309 |
mode_3d = (mode_dim == "3D")
|
| 310 |
-
|
| 311 |
-
return fig
|
| 312 |
|
| 313 |
def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
|
| 314 |
df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), task, feature_cols, label_col, k_clusters, seed)
|
|
@@ -317,8 +286,11 @@ def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y
|
|
| 317 |
return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
|
| 318 |
|
| 319 |
def startup_auto_train(df, task_default, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
|
|
|
|
| 320 |
try:
|
| 321 |
-
|
|
|
|
|
|
|
| 322 |
new_color = color_suggestion if color_suggestion else color_col
|
| 323 |
fig = update_plot(df2, x_col, y_col, z_col, new_color, hover_cols, mode_dim, size, opacity)
|
| 324 |
return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
|
|
@@ -334,87 +306,61 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px !important}") as demo:
|
|
| 334 |
with gr.Row():
|
| 335 |
with gr.Column(scale=1):
|
| 336 |
data_file = gr.File(label="Upload CSV/TSV/Parquet", file_count="single", type="filepath")
|
| 337 |
-
status_box = gr.Markdown("
|
| 338 |
|
| 339 |
with gr.Accordion("Assen & kleur", open=True):
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
)
|
| 350 |
|
| 351 |
with gr.Accordion("Weergave", open=True):
|
| 352 |
-
mode_dim = gr.Radio(["2D",
|
| 353 |
size_slider = gr.Slider(3, 18, value=8, step=1, label="Puntgrootte")
|
| 354 |
opacity_slider = gr.Slider(0.1, 1.0, value=0.8, step=0.05, label="Transparantie (opacity)")
|
| 355 |
|
| 356 |
with gr.Accordion("Training (live)", open=True):
|
| 357 |
-
task_radio = gr.Radio(
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
label="Taak"
|
| 361 |
-
)
|
| 362 |
-
feat_ms = gr.Dropdown(choices=DEMO_DF.select_dtypes(include=[np.number]).columns.tolist(),
|
| 363 |
-
value=["x", "y", "z", "age"],
|
| 364 |
-
multiselect=True,
|
| 365 |
-
label="Feature kolommen (numeriek)")
|
| 366 |
-
label_dd = gr.Dropdown(choices=DEMO_DF.columns.tolist(), value=None, label="Label kolom (alleen voor classificatie)")
|
| 367 |
k_slider = gr.Slider(2, 12, value=3, step=1, label="K (clusters) — KMeans")
|
| 368 |
seed_slider = gr.Slider(0, 10_000, value=7, step=1, label="Random seed")
|
| 369 |
train_btn = gr.Button("🚀 Train (live)")
|
| 370 |
train_log = gr.Textbox(label="Train log", lines=6, interactive=False)
|
| 371 |
|
| 372 |
-
hidden_df = gr.State(
|
| 373 |
|
| 374 |
with gr.Column(scale=2):
|
| 375 |
-
# Zelfde visualisatie (bolletjes) en layout behouden
|
| 376 |
plot = gr.Plot(label="Scatterplot")
|
| 377 |
with gr.Accordion("Evaluatie (validatie)", open=False):
|
| 378 |
cm_plot = gr.Plot(label="Confusion Matrix (validatie)")
|
| 379 |
|
| 380 |
# ===== Events =====
|
| 381 |
-
data_file.change(
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
outputs=[x_dd, y_dd, z_dd, color_dd, hover_ms, status_box, hidden_df, feat_ms, label_dd],
|
| 385 |
-
show_progress=False,
|
| 386 |
-
)
|
| 387 |
|
| 388 |
for comp in [x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider]:
|
| 389 |
-
comp.change(
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
outputs=plot,
|
| 393 |
-
show_progress=False,
|
| 394 |
-
)
|
| 395 |
-
|
| 396 |
-
train_btn.click(
|
| 397 |
-
fn=on_train_click,
|
| 398 |
-
inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider,
|
| 399 |
-
x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
|
| 400 |
-
outputs=[hidden_df, train_log, color_dd, plot, cm_plot],
|
| 401 |
-
show_progress=True,
|
| 402 |
-
)
|
| 403 |
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
outputs=plot,
|
| 408 |
-
show_progress=False,
|
| 409 |
-
)
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
|
| 419 |
if __name__ == "__main__":
|
| 420 |
demo.launch()
|
|
|
|
| 16 |
DESCRIPTION = """
|
| 17 |
# Interactieve Scatter (2D/3D) — Gradio + Plotly + Live Train
|
| 18 |
|
| 19 |
+
- **Bundled dataset**: ggz_depressie_synth_1000_modeling.csv (laadt automatisch)
|
| 20 |
+
- **Auto-train (supervised)** bij start met label: `target_respons50`
|
| 21 |
+
- **Upload** een eigen CSV/TSV/Parquet indien gewenst
|
| 22 |
- **Kies** x, y (en z) kolommen voor de scatter
|
| 23 |
- **Kleur** op cluster/categorie of continue variabele
|
| 24 |
- **Hover** toont gekozen kenmerken
|
| 25 |
+
- **Train (live)**: KMeans of Logistic Regression
|
|
|
|
| 26 |
"""
|
| 27 |
|
| 28 |
MODEL_PATH = Path("model.joblib")
|
| 29 |
+
DATA_PATH = Path("data/ggz_depressie_synth_1000_modeling.csv")
|
| 30 |
|
| 31 |
# -----------------------------
|
| 32 |
+
# Data loading
|
| 33 |
# -----------------------------
|
| 34 |
+
def load_default_df():
    """Return the bundled dataset, or a synthetic demo frame as a fallback.

    Reads ``DATA_PATH`` as CSV when that path exists. When the file is missing
    or reading it raises, a fixed-seed 400-row demo frame is generated instead:
    normally distributed points around three centres (x/y/z) plus cluster,
    age, sex, diagnosis and patient_id columns.

    Returns:
        pandas.DataFrame: the bundled data or the synthetic fallback.
    """
    # NOTE(review): confirm DATA_PATH matches where the CSV is actually stored;
    # if it does not, this function always returns the demo frame.
    if DATA_PATH.exists():
        try:
            return pd.read_csv(DATA_PATH)
        except Exception as e:
            # Best effort: report the problem and fall through to the demo data.
            print("Kon bundled dataset niet laden:", e)

    # Fallback demo. The order of the rng calls below determines the generated
    # values, so it must not be changed.
    n_rows = 400
    cluster_names = ("A", "B", "C")
    rng = np.random.default_rng(7)
    centers = np.array([[0, 0, 0], [5, 5, 2], [-4, 3, -3]])
    labels = rng.integers(0, len(centers), size=n_rows)
    points = centers[labels] + rng.normal(0, 1.1, size=(n_rows, 3))

    df = pd.DataFrame(points, columns=["x", "y", "z"])
    df["cluster"] = pd.Categorical([cluster_names[label] for label in labels])
    df["age"] = rng.integers(20, 90, size=n_rows)
    df["sex"] = pd.Categorical(rng.choice(["F", "M"], size=n_rows))
    df["diagnosis"] = pd.Categorical(
        rng.choice(["Type I", "Type II", "Control"], size=n_rows, p=[0.35, 0.35, 0.30])
    )
    df["patient_id"] = [f"P{1000+i}" for i in range(n_rows)]
    return df
|
| 52 |
|
| 53 |
+
BASE_DF = load_default_df()
|
| 54 |
+
|
| 55 |
+
# Heuristics for label/features
|
| 56 |
+
def pick_default_label(df):
    """Guess which column of ``df`` should be used as the classification label.

    Search order: the known ``target_*`` outcome names (``target_respons50``
    first), then any column whose lower-cased name is a generic label name,
    then the first column whose name starts with ``target_``.

    Args:
        df: DataFrame whose columns are inspected.

    Returns:
        The matching column label, or ``None`` when nothing matches.
    """
    known_targets = (
        "target_respons50",
        "target_remissie",
        "target_uitval",
        "target_opname6mnd",
        "target_rtw3mnd",
    )
    generic_names = ("label", "target", "y", "class", "diagnosis", "outcome")

    # The generators are lazy, so later tiers are only scanned when the
    # earlier ones yield nothing.
    tiers = (
        (name for name in known_targets if name in df.columns),
        (col for col in df.columns if str(col).lower() in generic_names),
        (col for col in df.columns if str(col).startswith("target_")),
    )
    for tier in tiers:
        for match in tier:
            return match
    return None
|
| 68 |
+
|
| 69 |
+
def default_features(df):
    """Return the numeric columns of ``df`` suitable as default model features.

    The ``patient_id`` column and every column whose name starts with
    ``target_`` are left out.

    Args:
        df: DataFrame to inspect.

    Returns:
        list: feature column labels, in DataFrame column order.
    """
    numeric = df.select_dtypes(include=[np.number]).columns.tolist()
    return [
        col for col in numeric
        if col != "patient_id" and not str(col).startswith("target_")
    ]
|
| 73 |
|
| 74 |
# -----------------------------
|
| 75 |
+
# Plot utils (identiek aan eerdere bolletjesvis)
|
| 76 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def build_hovertemplate(hover_cols):
|
| 78 |
if not hover_cols:
|
| 79 |
return "%{x}, %{y}<extra></extra>"
|
|
|
|
| 80 |
lines = []
|
| 81 |
for i, col in enumerate(hover_cols):
|
| 82 |
lines.append(f"<b>{col}</b>: %{{customdata[{i}]}}")
|
|
|
|
| 95 |
fig = go.Figure()
|
| 96 |
if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
|
| 97 |
for cat_val, dsub in df.groupby(color_col):
|
| 98 |
+
fig.add_trace(go.Scattergl(
|
| 99 |
+
x=dsub[x_col], y=dsub[y_col], mode='markers', name=str(cat_val),
|
| 100 |
+
marker=dict(size=point_size), opacity=opacity,
|
| 101 |
+
customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
|
| 102 |
+
hovertemplate=hovertemplate,
|
| 103 |
+
))
|
|
|
|
|
|
|
| 104 |
else:
|
| 105 |
+
fig.add_trace(go.Scattergl(
|
| 106 |
+
x=df[x_col], y=df[y_col], mode='markers', name=color_col if color_col else "data",
|
| 107 |
+
marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
|
| 108 |
+
opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
|
| 109 |
+
))
|
|
|
|
|
|
|
| 110 |
fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
|
| 111 |
+
fig.update_layout(template="plotly_white", margin=dict(l=10,r=10,t=30,b=10),
|
| 112 |
+
legend=dict(itemsizing='trace', title=color_col if color_col else None),
|
| 113 |
+
xaxis_title=x_col, yaxis_title=y_col, dragmode='pan')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
return fig
|
| 115 |
|
| 116 |
# 3D
|
|
|
|
| 119 |
if color_series is not None and (color_series.dtype == 'object' or str(color_series.dtype).startswith('category')):
|
| 120 |
fig = go.Figure()
|
| 121 |
for cat_val, dsub in df.groupby(color_col):
|
| 122 |
+
fig.add_trace(go.Scatter3d(
|
| 123 |
+
x=dsub[x_col], y=dsub[y_col], z=dsub[z_col], mode='markers', name=str(cat_val),
|
| 124 |
+
marker=dict(size=point_size), opacity=opacity,
|
| 125 |
+
customdata=(dsub[hover_cols].to_numpy() if hover_cols else None),
|
| 126 |
+
hovertemplate=hovertemplate,
|
| 127 |
+
))
|
|
|
|
|
|
|
| 128 |
else:
|
| 129 |
+
fig = go.Figure(data=[go.Scatter3d(
|
| 130 |
+
x=df[x_col], y=df[y_col], z=df[z_col], mode='markers', name=color_col if color_col else "data",
|
| 131 |
+
marker=dict(size=point_size, color=(color_series if color_series is not None else None), coloraxis='coloraxis'),
|
| 132 |
+
opacity=opacity, customdata=customdata, hovertemplate=hovertemplate,
|
| 133 |
+
)])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
fig.update_layout(coloraxis=dict(colorbar=dict(title=color_col)))
|
| 135 |
+
fig.update_layout(template="plotly_white", margin=dict(l=10,r=10,t=30,b=10),
|
| 136 |
+
legend=dict(itemsizing='trace', title=color_col if color_col else None),
|
| 137 |
+
scene=dict(xaxis_title=x_col, yaxis_title=y_col, zaxis_title=z_col))
|
|
|
|
|
|
|
|
|
|
| 138 |
return fig
|
| 139 |
|
| 140 |
+
# -----------------------------
|
| 141 |
+
# App state & defaults
|
| 142 |
+
# -----------------------------
|
| 143 |
+
def detect_columns(df):
    """Suggest default axis and colour columns for the scatter plot.

    Each axis prefers a conventional embedding column name and otherwise falls
    back to the first numeric column that is not already assigned to another
    axis, so two axes never default to the same column.

    Args:
        df: DataFrame to inspect.

    Returns:
        tuple: ``(numeric_cols, all_cols, x_default, y_default, z_default,
        color_default)``; any default may be ``None`` when no column fits.
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    all_cols = df.columns.tolist()

    def _pick_axis(preferred, taken):
        """Return the first preferred name present, else the first free numeric column."""
        for name in preferred:
            if name in df.columns and name not in taken:
                return name
        for name in numeric_cols:
            if name not in taken:
                return name
        return None

    # Try common embedding column names first.
    x_default = _pick_axis(["x", "X", "dim1", "pc1", "tsne1", "umap1"], set())
    y_default = _pick_axis(["y", "Y", "dim2", "pc2", "tsne2", "umap2"], {x_default})
    z_default = _pick_axis(["z", "Z", "dim3", "pc3", "tsne3", "umap3"], {x_default, y_default})

    # Colour: a well-known grouping column, else the first object/category
    # column that is not an axis, else the first numeric column.
    axes = {x_default, y_default, z_default}
    cat_candidates = [
        c for c in df.columns
        if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category')) and c not in axes
    ]
    color_default = next(
        (c for c in ["cluster", "label", "group", "diagnosis", "category", "pred_model"] if c in df.columns),
        (cat_candidates[0] if cat_candidates else (numeric_cols[0] if numeric_cols else None)),
    )
    return numeric_cols, all_cols, x_default, y_default, z_default, color_default
|
| 154 |
+
|
| 155 |
+
# -----------------------------
|
| 156 |
+
# Training
|
| 157 |
+
# -----------------------------
|
| 158 |
def _plot_confusion_matrix(y_true, y_pred, title="Confusion matrix"):
    """Build a Plotly heatmap of the confusion matrix for the given labels.

    Args:
        y_true: true labels.
        y_pred: predicted labels, aligned with ``y_true``.
        title: figure title.

    Returns:
        plotly.graph_objects.Figure: heatmap with predicted labels on the
        x-axis and true labels on the y-axis; each cell shows its count.
    """
    # Use the union of true and predicted labels: when an explicit label list
    # is passed, confusion_matrix ignores samples whose label is not in it, so
    # predictions of a class absent from y_true would otherwise vanish.
    # NOTE(review): assumes the labels are mutually sortable (callers pass strings) — confirm.
    labels = sorted(set(pd.Series(y_true).unique()) | set(pd.Series(y_pred).unique()))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    heatmap = go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        text=cm,
        texttemplate="%{text}",
        hovertemplate="Pred=%{x}<br>True=%{y}<br>Count=%{z}<extra></extra>",
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=title,
        xaxis_title="Predicted",
        yaxis_title="True",
        template="plotly_white",
        margin=dict(l=10, r=10, t=40, b=10),
    )
    return fig
|
| 165 |
|
| 166 |
def _fmt_metrics(metrics: dict) -> str:
|
| 167 |
lines = []
|
| 168 |
for k,v in metrics.items():
|
| 169 |
+
lines.append(f"{k}: {v:.4f}" if isinstance(v,float) else f"{k}: {v}")
|
|
|
|
|
|
|
|
|
|
| 170 |
return "\n".join(lines)
|
| 171 |
|
|
|
|
|
|
|
|
|
|
| 172 |
def train_live(df, task, feature_cols, label_col, k_clusters, seed):
|
| 173 |
+
if df is None or len(df)==0:
|
| 174 |
raise gr.Error("Geen data om te trainen.")
|
| 175 |
if not feature_cols:
|
| 176 |
raise gr.Error("Selecteer minimaal één featurekolom.")
|
|
|
|
| 177 |
X = df[feature_cols].select_dtypes(include=[np.number])
|
| 178 |
+
if X.shape[1]==0:
|
| 179 |
raise gr.Error("De gekozen features bevatten geen numerieke kolommen.")
|
| 180 |
|
| 181 |
log_lines = []
|
|
|
|
| 183 |
eval_fig = None
|
| 184 |
|
| 185 |
if task == "Clustering (KMeans)":
|
| 186 |
+
pipe = Pipeline([("scaler", StandardScaler()), ("kmeans", KMeans(n_clusters=k_clusters, random_state=seed, n_init="auto"))])
|
|
|
|
|
|
|
|
|
|
| 187 |
pipe.fit(X)
|
| 188 |
labels = pipe["kmeans"].labels_
|
| 189 |
+
df["cluster_model"] = pd.Categorical([chr(ord('A') + (i % 26)) for i in labels])
|
|
|
|
| 190 |
color_col_suggestion = "cluster_model"
|
| 191 |
dump(pipe, MODEL_PATH)
|
| 192 |
+
log_lines += [f"✅ KMeans getraind met k={k_clusters} op {X.shape[0]} rijen.", f"Model opgeslagen: {MODEL_PATH.resolve()}"]
|
|
|
|
| 193 |
|
| 194 |
elif task == "Classificatie (Logistic Regression)":
|
| 195 |
if not label_col:
|
| 196 |
raise gr.Error("Kies een labelkolom voor classificatie.")
|
| 197 |
y = df[label_col].astype(str)
|
| 198 |
+
Xtr, Xva, ytr, yva = train_test_split(X, y, test_size=0.2, random_state=seed, stratify=y if y.nunique()>1 else None)
|
| 199 |
+
pipe = Pipeline([("scaler", StandardScaler()), ("logreg", LogisticRegression(max_iter=1000, random_state=seed, class_weight="balanced"))])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
pipe.fit(Xtr, ytr)
|
|
|
|
|
|
|
| 201 |
yhat = pipe.predict(Xva)
|
| 202 |
+
metrics = {"accuracy": accuracy_score(yva, yhat), "f1_weighted": f1_score(yva, yhat, average="weighted")}
|
| 203 |
+
if y.nunique()==2:
|
|
|
|
|
|
|
|
|
|
| 204 |
try:
|
| 205 |
+
proba = pipe.predict_proba(Xva)[:,1]
|
| 206 |
uniq = list(pd.Series(y).unique())
|
| 207 |
mapping = {uniq[0]:0, uniq[1]:1}
|
| 208 |
metrics["roc_auc"] = roc_auc_score(yva.map(mapping), proba)
|
| 209 |
except Exception:
|
| 210 |
pass
|
|
|
|
| 211 |
eval_fig = _plot_confusion_matrix(yva, yhat, title="Confusion matrix (validatie)")
|
|
|
|
|
|
|
| 212 |
pipe.fit(X, y)
|
| 213 |
preds = pipe.predict(X)
|
| 214 |
df["pred_model"] = pd.Categorical(preds)
|
| 215 |
+
if y.nunique()==2:
|
| 216 |
+
try: df["pred_proba"] = pipe.predict_proba(X)[:,1]
|
| 217 |
+
except Exception: pass
|
|
|
|
|
|
|
|
|
|
| 218 |
color_col_suggestion = "pred_model"
|
| 219 |
dump(pipe, MODEL_PATH)
|
| 220 |
+
log_lines += ["✅ LogisticRegression getraind (split 80/20).", f"Model opgeslagen: {MODEL_PATH.resolve()}", "Metrics (validatie):\n"+_fmt_metrics(metrics)]
|
|
|
|
|
|
|
|
|
|
| 221 |
else:
|
| 222 |
raise gr.Error("Onbekende taak.")
|
| 223 |
|
|
|
|
| 235 |
# -----------------------------
|
| 236 |
# Gradio callbacks
|
| 237 |
# -----------------------------
|
| 238 |
+
def parse_file(file_obj):
    """Load an uploaded table, or the default dataset when nothing was uploaded.

    Args:
        file_obj: ``None``, a file path string, or an object exposing the
            path through a ``name`` attribute.

    Returns:
        tuple: ``(dataframe, status_message)``.
    """
    if file_obj is None:
        if DATA_PATH.exists():
            status = "Bundled dataset geladen."
        else:
            status = "Demo dataset geladen."
        return BASE_DF.copy(), status

    name = getattr(file_obj, "name", str(file_obj))
    lowered = name.lower()
    if lowered.endswith(".parquet"):
        df = pd.read_parquet(name)
    elif lowered.endswith(".tsv"):
        df = pd.read_csv(name, sep="\t")
    else:
        # ".csv" and every unrecognised extension are read as comma-separated text.
        df = pd.read_csv(name)

    return df, f"Bestand geladen: {Path(name).name} — {df.shape[0]} rijen, {df.shape[1]} kolommen."
|
| 252 |
+
|
| 253 |
def init_from_file(file_obj):
    """Load a dataset and compute fresh choices/defaults for the UI controls.

    Args:
        file_obj: the uploaded file, or ``None`` to use the default dataset.

    Returns:
        tuple: updates for the x, y, z, colour and hover dropdowns, the status
        text, the loaded DataFrame, and updates for the feature and label
        dropdowns, in that order.
    """
    # parse_file handles a missing upload itself, returning a copy of BASE_DF
    # with the matching status text.
    df, status = parse_file(file_obj)

    numeric_cols, all_cols, x_d, y_d, z_d, color_d = detect_columns(df)
    hover_default = [c for c in ["patient_id", "age", "sex", "diagnosis", "cluster"] if c in df.columns]
    feat_default = default_features(df)
    label_default = pick_default_label(df)
    # Prefer model predictions for colouring when the data already has them.
    color_value = "pred_model" if "pred_model" in df.columns else color_d

    x_update = gr.update(choices=all_cols, value=x_d)
    y_update = gr.update(choices=all_cols, value=y_d)
    z_update = gr.update(choices=all_cols, value=z_d)
    color_update = gr.update(choices=all_cols, value=color_value)
    hover_update = gr.update(choices=all_cols, value=hover_default)
    feat_update = gr.update(choices=numeric_cols, value=feat_default)
    label_update = gr.update(choices=all_cols, value=label_default)
    return (
        x_update,
        y_update,
        z_update,
        color_update,
        hover_update,
        status,
        df,
        feat_update,
        label_update,
    )
|
| 275 |
|
| 276 |
def update_plot(df, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
    """Build the scatter figure for the current control values.

    Args:
        df: data to plot; ``None`` or an empty list/tuple selects a copy of
            ``BASE_DF`` instead.
        x_col, y_col, z_col: axis columns.
        color_col: column used for colouring.
        hover_cols: columns shown on hover.
        mode_dim: ``"3D"`` requests a 3D plot; any other value gives 2D.
        size: marker size.
        opacity: marker opacity.

    Returns:
        The figure returned by ``make_figure``.
    """
    no_data = df is None or (isinstance(df, (list, tuple)) and len(df) == 0)
    frame = BASE_DF.copy() if no_data else df
    want_3d = mode_dim == "3D"
    return make_figure(frame, x_col, y_col, z_col, color_col, hover_cols, want_3d, size, opacity)
|
|
|
|
| 281 |
|
| 282 |
def on_train_click(df, task, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
|
| 283 |
df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), task, feature_cols, label_col, k_clusters, seed)
|
|
|
|
| 286 |
return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
|
| 287 |
|
| 288 |
def startup_auto_train(df, task_default, feature_cols, label_col, k_clusters, seed, x_col, y_col, z_col, color_col, hover_cols, mode_dim, size, opacity):
|
| 289 |
+
# Forceer supervised classificatie bij start
|
| 290 |
try:
|
| 291 |
+
chosen_label = label_col or pick_default_label(df)
|
| 292 |
+
chosen_feats = feature_cols or default_features(df)
|
| 293 |
+
df2, log_text, color_suggestion, eval_fig = train_live(df.copy(), "Classificatie (Logistic Regression)", chosen_feats, chosen_label, k_clusters, seed)
|
| 294 |
new_color = color_suggestion if color_suggestion else color_col
|
| 295 |
fig = update_plot(df2, x_col, y_col, z_col, new_color, hover_cols, mode_dim, size, opacity)
|
| 296 |
return df2, log_text, gr.update(value=new_color, choices=df2.columns.tolist()), fig, eval_fig
|
|
|
|
| 306 |
with gr.Row():
|
| 307 |
with gr.Column(scale=1):
|
| 308 |
data_file = gr.File(label="Upload CSV/TSV/Parquet", file_count="single", type="filepath")
|
| 309 |
+
status_box = gr.Markdown("Bundled dataset wordt standaard geladen en getraind bij start.")
|
| 310 |
|
| 311 |
with gr.Accordion("Assen & kleur", open=True):
|
| 312 |
+
base_numeric = BASE_DF.select_dtypes(include=[np.number]).columns.tolist()
|
| 313 |
+
x_default = base_numeric[0] if len(base_numeric)>0 else None
|
| 314 |
+
y_default = base_numeric[1] if len(base_numeric)>1 else None
|
| 315 |
+
z_default = base_numeric[2] if len(base_numeric)>2 else None
|
| 316 |
+
x_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=x_default, label="X kolom")
|
| 317 |
+
y_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=y_default, label="Y kolom")
|
| 318 |
+
z_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=z_default, label="Z kolom (voor 3D)")
|
| 319 |
+
color_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value="pred_model" if "pred_model" in BASE_DF.columns else None, label="Kleur op kolom")
|
| 320 |
+
hover_ms = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=[c for c in ["patient_id","age","sex","diagnosis","cluster"] if c in BASE_DF.columns], multiselect=True, label="Hover info kolommen")
|
|
|
|
| 321 |
|
| 322 |
with gr.Accordion("Weergave", open=True):
|
| 323 |
+
mode_dim = gr.Radio(["2D","3D"], value="2D", label="Dimensie")
|
| 324 |
size_slider = gr.Slider(3, 18, value=8, step=1, label="Puntgrootte")
|
| 325 |
opacity_slider = gr.Slider(0.1, 1.0, value=0.8, step=0.05, label="Transparantie (opacity)")
|
| 326 |
|
| 327 |
with gr.Accordion("Training (live)", open=True):
|
| 328 |
+
task_radio = gr.Radio(["Clustering (KMeans)","Classificatie (Logistic Regression)"], value="Classificatie (Logistic Regression)", label="Taak")
|
| 329 |
+
feat_ms = gr.Dropdown(choices=BASE_DF.select_dtypes(include=[np.number]).columns.tolist(), value=default_features(BASE_DF), multiselect=True, label="Feature kolommen (numeriek)")
|
| 330 |
+
label_dd = gr.Dropdown(choices=BASE_DF.columns.tolist(), value=pick_default_label(BASE_DF), label="Label kolom (alleen voor classificatie)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
k_slider = gr.Slider(2, 12, value=3, step=1, label="K (clusters) — KMeans")
|
| 332 |
seed_slider = gr.Slider(0, 10_000, value=7, step=1, label="Random seed")
|
| 333 |
train_btn = gr.Button("🚀 Train (live)")
|
| 334 |
train_log = gr.Textbox(label="Train log", lines=6, interactive=False)
|
| 335 |
|
| 336 |
+
hidden_df = gr.State(BASE_DF.copy())
|
| 337 |
|
| 338 |
with gr.Column(scale=2):
|
|
|
|
| 339 |
plot = gr.Plot(label="Scatterplot")
|
| 340 |
with gr.Accordion("Evaluatie (validatie)", open=False):
|
| 341 |
cm_plot = gr.Plot(label="Confusion Matrix (validatie)")
|
| 342 |
|
| 343 |
# ===== Events =====
|
| 344 |
+
data_file.change(fn=init_from_file, inputs=[data_file],
|
| 345 |
+
outputs=[x_dd, y_dd, z_dd, color_dd, hover_ms, status_box, hidden_df, feat_ms, label_dd],
|
| 346 |
+
show_progress=False)
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
for comp in [x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider]:
|
| 349 |
+
comp.change(fn=update_plot,
|
| 350 |
+
inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
|
| 351 |
+
outputs=plot, show_progress=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
+
train_btn.click(fn=on_train_click,
|
| 354 |
+
inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
|
| 355 |
+
outputs=[hidden_df, train_log, color_dd, plot, cm_plot], show_progress=True)
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
+
# Initial plot
|
| 358 |
+
demo.load(fn=update_plot, inputs=[hidden_df, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider], outputs=plot, show_progress=False)
|
| 359 |
+
|
| 360 |
+
# Auto-train bij start (forceer supervised classificatie met default label)
|
| 361 |
+
demo.load(fn=startup_auto_train,
|
| 362 |
+
inputs=[hidden_df, task_radio, feat_ms, label_dd, k_slider, seed_slider, x_dd, y_dd, z_dd, color_dd, hover_ms, mode_dim, size_slider, opacity_slider],
|
| 363 |
+
outputs=[hidden_df, train_log, color_dd, plot, cm_plot], show_progress=True)
|
| 364 |
|
| 365 |
if __name__ == "__main__":
|
| 366 |
demo.launch()
|
ggz_depressie_synth_1000_modeling.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|