halil21 committed on
Commit
f2cc132
·
verified ·
1 Parent(s): aefa7b6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +517 -0
app.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ iPVC Treatment Non-response Prediction — Clinical Calculator
3
+ =============================================================
4
+ Gradio web app supporting 4 models:
5
+ Logistic Regression, XGBoost, TabTransformer, KAN
6
+
7
+ Model weights and scaler.pkl are expected in the model_weights/ subdirectory.
8
+ """
9
+
10
+ import os
11
+ import numpy as np
12
+ import joblib
13
+ import torch
14
+ import torch.nn as nn
15
+ import gradio as gr
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Paths
19
+ # ---------------------------------------------------------------------------
20
# Everything is resolved relative to this file so the app does not depend on
# the current working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(APP_DIR, "model_weights")

# ---------------------------------------------------------------------------
# Feature definitions (must match notebook order exactly)
# ---------------------------------------------------------------------------
# Continuous inputs. predict() assembles its numeric row in this same order
# before handing it to the scaler.
numeric_features = [
    "PVCyüzdesi", "PVCQRS", "LVEF", "Yaş", "PVCPrematurındex",
    "QRSratio", "OrtalamaHR", "SemptomSüresi", "QTCsinus",
    "PVCCouplingIntervaldispersiyon", "CIvariability",
    "PVCPeakQRSduration", "PVCCouplingInterval", "PVCCompansatuarInterval",
]

# Binary inputs. predict() appends these, unscaled, after the numeric columns.
categorical_features = [
    "MultifokalPVC", "Non_susteinedVT", "Cins", "HT", "DM", "Fullcompansasion",
]

all_features = [*numeric_features, *categorical_features]  # total = 20

# Slider label -> internal feature name (same order as numeric_features)
SLIDER_LABELS = [
    "PVC Burden (%)", "PVC QRS Duration (ms)", "LVEF (%)", "Age (years)",
    "PVC Prematurity Index", "QRS Ratio", "Mean Heart Rate (bpm)",
    "Symptom Duration (months)", "QTc Sinus (ms)", "PVC CI Dispersion (ms)",
    "CI Variability", "PVC Peak QRS Duration (ms)",
    "PVC Coupling Interval (ms)", "PVC Compensatory Interval (ms)",
]

# Radio label -> internal feature name (same order as categorical_features)
RADIO_LABELS = [
    "Multifocal PVC", "Non-sustained VT", "Gender",
    "Hypertension", "Diabetes Mellitus", "Full Compensation",
]
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # PyTorch model architectures (identical to notebook)
83
+ # ---------------------------------------------------------------------------
84
+
85
+ # ---- TabTransformer ----
86
class TabTransformer(nn.Module):
    """Transformer-encoder classifier over one tabular feature vector.

    The full feature vector is projected to a single ``d_model``-wide token,
    run through a stack of encoder layers, and classified by a two-layer head.

    The attribute names (``embedding``, ``transformer_encoder``, ``fc``) and
    the layout of ``fc`` define the ``state_dict`` keys, so they must not be
    renamed if previously saved weights are to load.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output logits.
        d_model: Width of the embedded token.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used in the encoder and the head.
    """

    def __init__(self, input_dim=20, num_classes=2, d_model=64, nhead=4,
                 num_layers=3, dropout=0.1):
        super().__init__()
        head_width = d_model // 2

        self.embedding = nn.Linear(input_dim, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.fc = nn.Sequential(
            nn.Linear(d_model, head_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of feature vectors."""
        # Insert a length-1 sequence axis so the encoder sees one token per
        # sample, then drop it again before the classification head.
        token = self.embedding(x).unsqueeze(1)
        encoded = self.transformer_encoder(token)
        return self.fc(encoded.squeeze(1))
115
+
116
+
117
+ # ---- KAN (Kolmogorov-Arnold Network) ----
118
class KolmogorovArnoldLayer(nn.Module):
    """One layer of the KAN-style network.

    Each input feature goes through its own small scalar-to-scalar MLP
    (``inner_functions``); the resulting scalars are concatenated and mixed
    by a shared MLP (``outer_function``).

    The attribute names define the ``state_dict`` keys and must stay as they
    are for saved weights to load.

    Args:
        input_dim: Number of input features (one inner MLP per feature).
        inner_dim: Hidden width of both the inner and the outer MLPs.
        output_dim: Number of outputs produced by the layer.
    """

    def __init__(self, input_dim, inner_dim, output_dim):
        super().__init__()

        def univariate_mlp():
            # Maps a single scalar feature to a single scalar.
            return nn.Sequential(
                nn.Linear(1, inner_dim), nn.ReLU(), nn.Linear(inner_dim, 1)
            )

        self.inner_functions = nn.ModuleList(
            [univariate_mlp() for _ in range(input_dim)]
        )
        self.outer_function = nn.Sequential(
            nn.Linear(input_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, output_dim),
        )

    def forward(self, x):
        """Apply the per-feature MLPs column by column, then the mixing MLP."""
        transformed = []
        for col, branch in enumerate(self.inner_functions):
            # Slice with a range so the column keeps its trailing axis.
            transformed.append(branch(x[:, col:col + 1]))
        return self.outer_function(torch.cat(transformed, dim=1))
136
+
137
+
138
class KolmogorovArnoldNetwork(nn.Module):
    """Stack of ``KolmogorovArnoldLayer`` blocks followed by a linear classifier.

    Dropout is applied after every KAN layer. The attribute names
    (``kan_layers``, ``dropout``, ``output_layer``) define the ``state_dict``
    keys and must stay as they are for saved weights to load.

    Args:
        input_dim: Number of input features per sample.
        hidden_dims: Output width of each KAN layer, in order. ``None`` selects
            the default ``[94, 55]``. An empty sequence is allowed and yields a
            plain linear classifier on the raw inputs.
        inner_dim: Hidden width used inside every KAN layer.
        dropout: Dropout probability applied after each KAN layer.
        num_classes: Number of output logits (defaults to 2, the previous
            hard-coded value).
    """

    # NOTE(review): the defaults (37, 0.467, [94, 55]) presumably mirror the
    # configuration used when the shipped weights were trained -- confirm
    # against the training notebook before changing them.
    def __init__(self, input_dim=20, hidden_dims=None, inner_dim=37,
                 dropout=0.467, num_classes=2):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [94, 55]

        layers = []
        prev_dim = input_dim
        for hd in hidden_dims:
            layers.append(KolmogorovArnoldLayer(prev_dim, inner_dim, hd))
            prev_dim = hd

        self.kan_layers = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        # prev_dim is the last hidden width, or input_dim when there are no
        # hidden layers; indexing hidden_dims[-1] would fail on an empty list.
        self.output_layer = nn.Linear(prev_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x`` of feature vectors."""
        for layer in self.kan_layers:
            x = self.dropout(layer(x))
        return self.output_layer(x)
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # Load artefacts
160
+ # ---------------------------------------------------------------------------
161
+
162
def _load_scaler():
    """Load the fitted scaler stored as ``scaler.pkl`` in ``WEIGHTS_DIR``.

    Returns:
        Whatever object ``joblib.load`` deserializes from the file.

    Raises:
        FileNotFoundError: If ``scaler.pkl`` is not present in ``WEIGHTS_DIR``.
    """
    scaler_path = os.path.join(WEIGHTS_DIR, "scaler.pkl")
    if os.path.exists(scaler_path):
        # joblib.load unpickles the file, so it must come from a trusted source.
        return joblib.load(scaler_path)
    raise FileNotFoundError(
        f"scaler.pkl not found in {WEIGHTS_DIR}. "
        "Copy scaler.pkl from the training outputs into model_weights/."
    )
170
+
171
+
172
def _load_sklearn_model(filename):
    """Load a joblib-serialized model from ``WEIGHTS_DIR``.

    Args:
        filename: File name (not a full path) of the model inside
            ``WEIGHTS_DIR``.

    Returns:
        Whatever object ``joblib.load`` deserializes from the file.

    Raises:
        FileNotFoundError: If the file is not present in ``WEIGHTS_DIR``.
    """
    model_path = os.path.join(WEIGHTS_DIR, filename)
    if os.path.exists(model_path):
        # joblib.load unpickles the file, so it must come from a trusted source.
        return joblib.load(model_path)
    raise FileNotFoundError(f"{filename} not found in {WEIGHTS_DIR}.")
177
+
178
+
179
def _load_tabtransformer():
    """Build a ``TabTransformer`` and load its weights from ``WEIGHTS_DIR``.

    Returns:
        The model on CPU, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``tabtransformer_model.pth`` is missing.
    """
    weights_path = os.path.join(WEIGHTS_DIR, "tabtransformer_model.pth")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"tabtransformer_model.pth not found in {WEIGHTS_DIR}.")

    # The architecture arguments must match the checkpoint, otherwise
    # load_state_dict fails on mismatched tensor shapes.
    network = TabTransformer(
        input_dim=20, num_classes=2, d_model=64, nhead=4,
        num_layers=3, dropout=0.1
    )
    # The file is expected to hold a bare state_dict.
    network.load_state_dict(
        torch.load(weights_path, map_location="cpu", weights_only=True)
    )
    # eval() disables dropout for inference and returns the module itself.
    return network.eval()
191
+
192
+
193
def _load_kan():
    """Build a ``KolmogorovArnoldNetwork`` and load its weights from ``WEIGHTS_DIR``.

    The checkpoint may be either a bare state_dict or a dict that stores the
    state_dict under the ``"model_state_dict"`` key.

    Returns:
        The model on CPU, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``kan_model.pth`` is missing.
    """
    weights_path = os.path.join(WEIGHTS_DIR, "kan_model.pth")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"kan_model.pth not found in {WEIGHTS_DIR}.")

    loaded = torch.load(weights_path, map_location="cpu", weights_only=True)
    # Fall back to the loaded object itself when it is not a wrapped checkpoint.
    weights = loaded.get("model_state_dict", loaded)

    # The architecture arguments must match the checkpoint, otherwise
    # load_state_dict fails on mismatched tensor shapes.
    network = KolmogorovArnoldNetwork(
        input_dim=20, hidden_dims=[94, 55], inner_dim=37, dropout=0.467
    )
    network.load_state_dict(weights)
    # eval() disables dropout for inference and returns the module itself.
    return network.eval()
205
+
206
+
207
# Lazy-loaded cache so the models are only read once
_cache = {}


def _get(key, loader, *args):
    """Return the cached object for ``key``, creating it on first use.

    Args:
        key: Cache key identifying the artefact.
        loader: Callable invoked as ``loader(*args)`` the first time ``key``
            is requested.
        *args: Positional arguments forwarded to ``loader``.

    Returns:
        The object stored under ``key``. If ``loader`` raises, nothing is
        cached and the exception propagates.
    """
    try:
        return _cache[key]
    except KeyError:
        pass
    # Call the loader outside the except block so an error it raises is not
    # chained to the cache-miss KeyError.
    loaded = loader(*args)
    _cache[key] = loaded
    return loaded
215
+
216
+
217
+ # ---------------------------------------------------------------------------
218
+ # Categorical encoding helper
219
+ # ---------------------------------------------------------------------------
220
+
221
def _encode_categorical(value: str) -> int:
    """Encode a radio-button value as an integer.

    Mapping (intended to match the LabelEncoder fit on the training data):
        'No'     -> 0, 'Yes'  -> 1
        'Female' -> 0, 'Male' -> 1

    NOTE(review): this assumes the training labels were encoded in exactly
    this order -- verify against the training notebook.

    Args:
        value: One of ``"No"``, ``"Yes"``, ``"Female"``, ``"Male"``.

    Returns:
        ``0`` or ``1`` according to the mapping above.

    Raises:
        KeyError: If ``value`` is not one of the supported labels, for example
            ``None`` when a radio button has no selection. The message names
            the offending value and the accepted labels.
    """
    mapping = {"No": 0, "Yes": 1, "Female": 0, "Male": 1}
    try:
        return mapping[value]
    except KeyError:
        # Re-raise with a descriptive message; the bare KeyError would reach
        # the user as just "None" or the raw value.
        expected = ", ".join(mapping)
        raise KeyError(
            f"Unsupported categorical value {value!r}; expected one of: {expected}"
        ) from None
230
+
231
+
232
+ # ---------------------------------------------------------------------------
233
+ # Prediction function
234
+ # ---------------------------------------------------------------------------
235
+
236
def predict(
    model_choice,
    pvc_burden, pvc_qrs, lvef, age, pvc_prematur_index,
    qrs_ratio, mean_hr, symptom_duration, qtc_sinus,
    pvc_ci_dispersion, ci_variability, pvc_peak_qrs,
    pvc_coupling_interval, pvc_compensatory_interval,
    multifocal_pvc, nonsustained_vt, gender,
    hypertension, diabetes, full_compensation,
):
    """Run the selected model on one patient and format the result for the UI.

    The 14 numeric inputs are scaled with the stored scaler, the 6 categorical
    inputs are integer-encoded without scaling, and both are concatenated into
    a single ``(1, 20)`` row that is fed to the chosen model.

    Args:
        model_choice: One of ``"Logistic Regression"``, ``"XGBoost"``,
            ``"TabTransformer"`` or ``"KAN"``.
        pvc_burden ... pvc_compensatory_interval: Numeric inputs, in the same
            order as ``numeric_features``.
        multifocal_pvc ... full_compensation: Radio-button strings accepted by
            ``_encode_categorical``.

    Returns:
        A 3-tuple of strings ``(probability_text, risk_text, interpretation)``.
        On any failure the first element carries the error message and the
        other two are empty strings; no exception propagates to the caller.
    """
    try:
        scaler = _get("scaler", _load_scaler)

        # -- Numeric row (14 features), in the order the scaler receives it --
        numeric_row = [
            pvc_burden, pvc_qrs, lvef, age, pvc_prematur_index,
            qrs_ratio, mean_hr, symptom_duration, qtc_sinus,
            pvc_ci_dispersion, ci_variability, pvc_peak_qrs,
            pvc_coupling_interval, pvc_compensatory_interval,
        ]
        numeric_scaled = scaler.transform(
            np.array([numeric_row], dtype=np.float64)
        )

        # -- Categorical row (6 features), integer-encoded and left unscaled --
        categorical_inputs = (
            multifocal_pvc, nonsustained_vt, gender,
            hypertension, diabetes, full_compensation,
        )
        cat_values = np.array(
            [[_encode_categorical(choice) for choice in categorical_inputs]],
            dtype=np.float64,
        )

        # Scaled numeric columns first, then the categorical ones -> (1, 20)
        x = np.hstack([numeric_scaled, cat_values])

        # -- Probability of the positive class (column / logit index 1) --
        if model_choice in ("Logistic Regression", "XGBoost"):
            if model_choice == "Logistic Regression":
                cache_key, filename = "lr", "logistic_regression_model.pkl"
            else:
                cache_key, filename = "xgb", "xgboost_model.pkl"
            model = _get(cache_key, _load_sklearn_model, filename)
            prob = float(model.predict_proba(x)[0, 1])
        elif model_choice in ("TabTransformer", "KAN"):
            if model_choice == "TabTransformer":
                model = _get("tt", _load_tabtransformer)
            else:
                model = _get("kan", _load_kan)
            with torch.no_grad():
                logits = model(torch.FloatTensor(x))
                prob = float(torch.softmax(logits, dim=1)[0, 1].item())
        else:
            return "Error: Unknown model selected.", "", ""

        # -- Risk stratification: < 20 low, 20-40 moderate, > 40 high --
        pct = prob * 100.0
        risk = (
            "LOW RISK" if pct < 20.0
            else "MODERATE RISK" if pct <= 40.0
            else "HIGH RISK"
        )

        return (
            f"{pct:.1f}%",
            f"{risk} (< 20% Low | 20-40% Moderate | > 40% High)",
            _build_interpretation(model_choice, pct, risk),
        )

    except FileNotFoundError as e:
        # Missing artefacts already carry a user-readable message.
        return str(e), "", ""
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"Prediction error: {e}", "", ""
329
+
330
+
331
def _build_interpretation(model_name: str, pct: float, risk: str) -> str:
    """Compose the clinical interpretation paragraph shown in the UI.

    Args:
        model_name: Display name of the model that produced the prediction.
        pct: Predicted probability as a percentage; rendered with one decimal.
        risk: Risk label. ``"LOW RISK"`` and ``"MODERATE RISK"`` select their
            own wording; any other value gets the high-risk wording.

    Returns:
        Three sentences groups (summary, risk-specific advice, disclaimer)
        joined by single spaces.
    """
    summary = (
        f"Using the {model_name} model, the predicted probability of "
        f"treatment non-response (iPVC persistence) is {pct:.1f}%."
    )

    if risk == "LOW RISK":
        advice = (
            "This patient falls in the LOW risk category (< 20%). "
            "The model suggests a favorable response to anti-arrhythmic "
            "or ablation therapy is likely. Standard follow-up is recommended."
        )
    elif risk == "MODERATE RISK":
        advice = (
            "This patient falls in the MODERATE risk category (20-40%). "
            "There is an intermediate likelihood of treatment non-response. "
            "Close monitoring and potential therapy optimization should be considered."
        )
    else:
        advice = (
            "This patient falls in the HIGH risk category (> 40%). "
            "The model indicates a substantial probability of treatment "
            "non-response. Intensified management strategies, combination "
            "therapy, or early referral for catheter ablation may be warranted."
        )

    disclaimer = (
        "Note: This calculator is intended for research and clinical "
        "decision support only. It should not replace clinical judgment."
    )
    return " ".join((summary, advice, disclaimer))
361
+
362
+
363
+ # ---------------------------------------------------------------------------
364
+ # Gradio interface
365
+ # ---------------------------------------------------------------------------
366
+
367
def build_app():
    """Assemble the Gradio Blocks UI and wire the Predict button to ``predict``.

    Components are created from small spec tables. Their creation order is the
    positional order in which their values are passed to ``predict``, so the
    tables must stay aligned with that function's parameter list.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    # One inner list per UI row: (label, minimum, maximum, step, default).
    slider_rows = [
        [
            ("PVC Burden (%)", 0, 100, 0.1, 15.0),
            ("PVC QRS Duration (ms)", 80, 300, 1, 140),
            ("LVEF (%)", 10, 80, 1, 55),
        ],
        [
            ("Age (years)", 18, 100, 1, 50),
            ("PVC Prematurity Index", 0.0, 2.0, 0.01, 0.75),
            ("QRS Ratio", 0.5, 3.0, 0.01, 1.2),
        ],
        [
            ("Mean Heart Rate (bpm)", 40, 200, 1, 75),
            ("Symptom Duration (months)", 0, 360, 1, 12),
            ("QTc Sinus (ms)", 300, 600, 1, 420),
        ],
        [
            ("PVC CI Dispersion (ms)", 0, 300, 1, 50),
            ("CI Variability", 0.0, 1.0, 0.01, 0.10),
            ("PVC Peak QRS Duration (ms)", 80, 300, 1, 140),
        ],
        [
            ("PVC Coupling Interval (ms)", 200, 800, 1, 450),
            ("PVC Compensatory Interval (ms)", 400, 1500, 1, 900),
        ],
    ]

    # One inner list per UI row: (label, choices, default).
    radio_rows = [
        [
            ("Multifocal PVC", ["No", "Yes"], "No"),
            ("Non-sustained VT", ["No", "Yes"], "No"),
            ("Gender", ["Female", "Male"], "Male"),
        ],
        [
            ("Hypertension", ["No", "Yes"], "No"),
            ("Diabetes Mellitus", ["No", "Yes"], "No"),
            ("Full Compensation", ["No", "Yes"], "No"),
        ],
    ]

    with gr.Blocks(
        title="iPVC Non-response Predictor",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# iPVC Treatment Non-response Prediction Calculator\n"
            "Enter patient parameters below and select a prediction model. "
            "The tool estimates the probability that the patient will **not respond** "
            "to iPVC treatment (anti-arrhythmic / ablation therapy)."
        )

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=[
                    "Logistic Regression",
                    "XGBoost",
                    "TabTransformer",
                    "KAN",
                ],
                value="Logistic Regression",
                label="Prediction Model",
            )

        gr.Markdown("## Numeric Parameters")

        sliders = []
        for row_spec in slider_rows:
            with gr.Row():
                for label, low, high, step, default in row_spec:
                    sliders.append(
                        gr.Slider(
                            minimum=low, maximum=high, step=step, value=default,
                            label=label,
                        )
                    )

        gr.Markdown("## Categorical Parameters")

        radios = []
        for row_spec in radio_rows:
            with gr.Row():
                for label, choices, default in row_spec:
                    radios.append(
                        gr.Radio(choices=choices, value=default, label=label)
                    )

        gr.Markdown("## Prediction Results")

        with gr.Row():
            out_prob = gr.Textbox(label="Predicted Probability", interactive=False)
            out_risk = gr.Textbox(label="Risk Category", interactive=False)
        out_interp = gr.Textbox(
            label="Clinical Interpretation", interactive=False, lines=5
        )

        predict_btn = gr.Button("Predict", variant="primary")

        # Input order: model choice, the 14 sliders, then the 6 radios --
        # exactly the positional parameter order of predict().
        predict_btn.click(
            fn=predict,
            inputs=[model_dropdown, *sliders, *radios],
            outputs=[out_prob, out_risk, out_interp],
        )

        gr.Markdown(
            "---\n"
            "*This tool is for research and clinical decision support purposes only. "
            "Predictions should be interpreted in the context of the full clinical picture.*"
        )

    return demo
510
+
511
+
512
+ # ---------------------------------------------------------------------------
513
+ # Entry point
514
+ # ---------------------------------------------------------------------------
515
if __name__ == "__main__":
    # Build the UI only when run as a script, so importing this module has no
    # side effects beyond definitions. share=False requests no public link.
    app = build_app()
    app.launch(share=False)